update cli_demo.py and web_demo.py

This commit is contained in:
yangapku
2023-09-25 20:29:29 +08:00
parent e0c5e9f7f3
commit eb6e364fe7
2 changed files with 36 additions and 50 deletions

View File

@@ -18,8 +18,12 @@ from transformers.trainer_utils import set_seed
DEFAULT_CKPT_PATH = 'Qwen/Qwen-7B-Chat'
_WELCOME_MSG = '''\
Welcome to use Qwen-7B-Chat model, type text to start chat, type :h to show command help
欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
Welcome to use Qwen-Chat model, type text to start chat, type :h to show command help.
(欢迎使用 Qwen-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)
Note: This demo is governed by the original license of Qwen.
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
(注本演示受Qwen的许可协议限制。我们强烈建议用户不应传播及不应允许他人传播以下内容包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
'''
_HELP_MSG = '''\
Commands:
@@ -46,23 +50,12 @@ def _load_model_tokenizer(args):
else:
device_map = "auto"
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
if os.path.exists(qconfig_path):
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
use_safetensors=True,
).eval()
else:
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True,
@@ -103,7 +96,7 @@ def _get_input() -> str:
def main():
parser = argparse.ArgumentParser(
description='QWen-7B-Chat command-line interactive chat demo.')
description='QWen-Chat command-line interactive chat demo.')
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
help="Checkpoint name or path, default to %(default)r")
parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
@@ -195,7 +188,7 @@ def main():
for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
_clear_screen()
print(f"\nUser: {query}")
print(f"\nQwen-7B: {response}")
print(f"\nQwen-Chat: {response}")
except KeyboardInterrupt:
print('[WARNING] Generation interrupted')
continue

View File

@@ -47,23 +47,12 @@ def _load_model_tokenizer(args):
else:
device_map = "auto"
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
if os.path.exists(qconfig_path):
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
use_safetensors=True,
).eval()
else:
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True,
@@ -133,7 +122,7 @@ def _launch_demo(args, model, tokenizer, config):
print(f"History: {_task_history}")
_task_history.append((_query, full_response))
print(f"Qwen-7B-Chat: {_parse_text(full_response)}")
print(f"Qwen-Chat: {_parse_text(full_response)}")
def regenerate(_chatbot, _task_history):
if not _task_history:
@@ -156,21 +145,25 @@ def _launch_demo(args, model, tokenizer, config):
with gr.Blocks() as demo:
gr.Markdown("""\
<p align="center"><img src="https://modelscope.cn/api/v1/models/qwen/Qwen-7B-Chat/repo?
Revision=master&FilePath=assets/logo.jpeg&View=true" style="height: 80px"/><p>""")
gr.Markdown("""<center><font size=8>Qwen-7B-Chat Bot</center>""")
<p align="center"><img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/logo_qwen.jpg" style="height: 80px"/><p>""")
gr.Markdown("""<center><font size=8>Qwen-Chat Bot</center>""")
gr.Markdown(
"""\
<center><font size=3>This WebUI is based on Qwen-7B-Chat, developed by Alibaba Cloud. \
(本WebUI基于Qwen-7B-Chat打造实现聊天机器人功能。)</center>""")
<center><font size=3>This WebUI is based on Qwen-Chat, developed by Alibaba Cloud. \
(本WebUI基于Qwen-Chat打造实现聊天机器人功能。)</center>""")
gr.Markdown("""\
<center><font size=4>Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a>
| <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp
<center><font size=4>
Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp
Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp
&nbsp<a href="https://github.com/QwenLM/Qwen-7B">Github</a></center>""")
Qwen-14B <a href="https://modelscope.cn/models/qwen/Qwen-14B/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-14B">🤗</a>&nbsp
Qwen-14B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-14B-Chat">🤗</a>&nbsp
&nbsp<a href="https://github.com/QwenLM/Qwen">Github</a></center>""")
chatbot = gr.Chatbot(label='Qwen-7B-Chat', elem_classes="control-height")
chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height")
query = gr.Textbox(lines=2, label='Input')
task_history = gr.State([])
@@ -185,10 +178,10 @@ Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">
regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen-7B. \
<font size=2>Note: This demo is governed by the original license of Qwen. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注本演示受Qwen-7B的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
(注本演示受Qwen的许可协议限制。我们强烈建议用户不应传播及不应允许他人传播以下内容\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
demo.queue().launch(