update cache GC in demo and add vocab expansion example

This commit is contained in:
yangapku
2023-10-09 14:59:57 +08:00
parent 343017c4ce
commit cbfaada8de
8 changed files with 530 additions and 9 deletions

View File

@@ -21,12 +21,21 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
def _gc(forced: bool = False):
global args
if args.disable_gc and not forced:
return
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
_gc(forced=True)
app = FastAPI(lifespan=lifespan)
@@ -392,6 +401,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
**gen_kwargs
)
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
_gc()
response = trim_stop_words(response, stop_words)
if request.functions:
choice_data = parse_response(response)
@@ -453,6 +464,8 @@ async def predict(
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield "[DONE]"
_gc()
def _get_args():
parser = ArgumentParser()
@@ -476,6 +489,8 @@ def _get_args():
help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
" If you want other computers to access your server, use 0.0.0.0 instead.",
)
parser.add_argument("--disable-gc", action="store_true",
help="Disable GC after each response generated.")
args = parser.parse_args()
return args