diff --git a/README.md b/README.md index dec6cf4..98bab0a 100644 --- a/README.md +++ b/README.md @@ -257,7 +257,7 @@ python cli_demo.py We provide code for users to build a web UI demo (thanks to @wysiad). Before you start, make sure you install the following packages: ``` -pip install gradio mdtex2html +pip install -r requirements_web_demo.txt ``` Then run the command below and click on the generated link: diff --git a/README_CN.md b/README_CN.md index 78077a7..a034c90 100644 --- a/README_CN.md +++ b/README_CN.md @@ -259,7 +259,7 @@ python cli_demo.py 我们提供了Web UI的demo供用户使用 (感谢 @wysiad 支持)。在开始前,确保已经安装如下代码库: ``` -pip install gradio mdtex2html +pip install -r requirements_web_demo.txt ``` 随后运行如下命令,并点击生成链接: diff --git a/README_JA.md b/README_JA.md index bc51ab3..069d6da 100644 --- a/README_JA.md +++ b/README_JA.md @@ -264,7 +264,7 @@ python cli_demo.py ウェブUIデモを構築するためのコードを提供します(@wysiadに感謝)。始める前に、以下のパッケージがインストールされていることを確認してください: ``` -pip install gradio mdtex2html +pip install -r requirements_web_demo.txt ``` そして、以下のコマンドを実行し、生成されたリンクをクリックする: diff --git a/cli_demo.py b/cli_demo.py index 52a3a63..4a095ad 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -38,7 +38,7 @@ Commands: def _load_model_tokenizer(args): tokenizer = AutoTokenizer.from_pretrained( - args.checkpoint_path, trust_remote_code=True, + args.checkpoint_path, trust_remote_code=True, resume_download=True, ) if args.cpu_only: @@ -50,9 +50,11 @@ def _load_model_tokenizer(args): args.checkpoint_path, device_map=device_map, trust_remote_code=True, + resume_download=True, ).eval() - model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True) - + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) return model, tokenizer diff --git a/web_demo.py b/web_demo.py index 5b3a2af..89cb28e 100755 --- a/web_demo.py +++ b/web_demo.py @@ -1,52 +1,60 @@ -#!/usr/bin/env python3 +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""A simple web interactive chat demo based on gradio.""" + +from argparse import ArgumentParser -from transformers import AutoTokenizer import gradio as gr import mdtex2html from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig -from argparse import ArgumentParser -import sys -print("Call args:" + str(sys.argv)) -parser = ArgumentParser() -parser.add_argument("--share", action="store_true", default=False) -parser.add_argument("--inbrowser", action="store_true", default=False) -parser.add_argument("--server_port", type=int, default=80) -parser.add_argument("--server_name", type=str, default="0.0.0.0") -parser.add_argument("--exit", action="store_true", default=False) -parser.add_argument("--model_revision", type=str, default="") -args = parser.parse_args(sys.argv[1:]) -print("Args:" + str(args)) +DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat' -tokenizer = AutoTokenizer.from_pretrained( - "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True -) -model = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen-7B-Chat", - device_map="auto", - trust_remote_code=True, - resume_download=True, - **{"revision": args.model_revision} - if args.model_revision is not None - and args.model_revision != "" - and args.model_revision != "None" - else {}, -).eval() +def _get_args(): + parser = ArgumentParser() + parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, + help="Checkpoint name or path, default to %(default)r") + parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") -model.generation_config = GenerationConfig.from_pretrained( - "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True -) + parser.add_argument("--share", action="store_true", default=False, + help="Create a publicly shareable link for the interface.") + parser.add_argument("--inbrowser", action="store_true", default=False, + help="Automatically launch the interface in a new tab on the default browser.") + parser.add_argument("--server-port", type=int, default=8000, + help="Demo server port.") + parser.add_argument("--server-name", type=str, default="127.0.0.1", + help="Demo server name.") -if "exit" in args: - if args.exit: - sys.exit(0) + args = parser.parse_args() + return args + + +def _load_model_tokenizer(args): + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + if args.cpu_only: + device_map = "cpu" else: - del args.exit + device_map = "auto" -if "model_revision" in args: - del args.model_revision + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + device_map=device_map, + trust_remote_code=True, + resume_download=True, + ).eval() + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + return model, tokenizer def postprocess(self, y): @@ -54,7 +62,7 @@ def postprocess(self, y): return [] for i, (message, response) in enumerate(y): y[i] = ( - None if message is None else mdtex2html.convert((message)), + None if message is None else mdtex2html.convert(message), None if response is None else mdtex2html.convert(response), ) return y @@ -63,7 +71,7 @@ def postprocess(self, y): gr.Chatbot.postprocess = postprocess -def parse_text(text): +def _parse_text(text): lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 @@ -78,7 +86,7 @@ def parse_text(text): else: if i > 0: if count % 2 == 1: - line = line.replace("`", "\`") + line = line.replace("`", r"\`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") @@ -95,70 +103,89 @@ def parse_text(text): return text -task_history = [] +def _launch_demo(args, model, tokenizer): + task_history = [] + def predict(_query, _chatbot): + print("User: " + _parse_text(_query)) + _chatbot.append((_parse_text(_query), "")) + full_response = "" -def predict(query, chatbot): - print("User: " + parse_text(query)) - chatbot.append((parse_text(query), "")) - fullResponse = "" + for response in model.chat_stream(tokenizer, _query, history=task_history): + _chatbot[-1] = (_parse_text(_query), _parse_text(response)) - for response in model.chat_stream(tokenizer, query, history=task_history): - chatbot[-1] = (parse_text(query), parse_text(response)) + yield _chatbot + full_response = _parse_text(response) - yield chatbot - fullResponse = parse_text(response) + task_history.append((_query, full_response)) + print("Qwen-7B-Chat: " + _parse_text(full_response)) - task_history.append((query, fullResponse)) - print("Qwen-7B-Chat: " + parse_text(fullResponse)) + def regenerate(_chatbot): + if not task_history: + yield _chatbot + return + item = task_history.pop(-1) + _chatbot.pop(-1) + yield from predict(item[0], _chatbot) + def reset_user_input(): + return gr.update(value="") -def regenerate(chatbot): - if not task_history: - yield chatbot - return - item = task_history.pop(-1) - chatbot.pop(-1) - yield from predict(item[0], chatbot) + def reset_state(): + task_history.clear() + return [] + with gr.Blocks() as demo: + gr.Markdown("""\ +

""") + gr.Markdown("""

""") - gr.Markdown("""