mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
update cache GC in demo and add vocab expansion example
This commit is contained in:
@@ -21,12 +21,21 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
|
||||
def _gc(forced: bool = False):
|
||||
global args
|
||||
if args.disable_gc and not forced:
|
||||
return
|
||||
|
||||
import gc
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
_gc(forced=True)
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@@ -392,6 +401,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
**gen_kwargs
|
||||
)
|
||||
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
||||
_gc()
|
||||
|
||||
response = trim_stop_words(response, stop_words)
|
||||
if request.functions:
|
||||
choice_data = parse_response(response)
|
||||
@@ -453,6 +464,8 @@ async def predict(
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
yield "[DONE]"
|
||||
|
||||
_gc()
|
||||
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
@@ -476,6 +489,8 @@ def _get_args():
|
||||
help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
|
||||
" If you want other computers to access your server, use 0.0.0.0 instead.",
|
||||
)
|
||||
parser.add_argument("--disable-gc", action="store_true",
|
||||
help="Disable GC after each response generated.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
Reference in New Issue
Block a user