mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-21 00:45:48 +08:00
update cli_demo.py and web_demo.py
This commit is contained in:
35
cli_demo.py
35
cli_demo.py
@@ -18,8 +18,12 @@ from transformers.trainer_utils import set_seed
|
||||
DEFAULT_CKPT_PATH = 'Qwen/Qwen-7B-Chat'
|
||||
|
||||
_WELCOME_MSG = '''\
|
||||
Welcome to use Qwen-7B-Chat model, type text to start chat, type :h to show command help
|
||||
欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
|
||||
Welcome to use Qwen-Chat model, type text to start chat, type :h to show command help.
|
||||
(欢迎使用 Qwen-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)
|
||||
|
||||
Note: This demo is governed by the original license of Qwen.
|
||||
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
|
||||
(注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
|
||||
'''
|
||||
_HELP_MSG = '''\
|
||||
Commands:
|
||||
@@ -46,23 +50,12 @@ def _load_model_tokenizer(args):
|
||||
else:
|
||||
device_map = "auto"
|
||||
|
||||
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
|
||||
if os.path.exists(qconfig_path):
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
args.checkpoint_path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=True,
|
||||
resume_download=True,
|
||||
use_safetensors=True,
|
||||
).eval()
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.checkpoint_path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=True,
|
||||
resume_download=True,
|
||||
).eval()
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.checkpoint_path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=True,
|
||||
resume_download=True,
|
||||
).eval()
|
||||
|
||||
config = GenerationConfig.from_pretrained(
|
||||
args.checkpoint_path, trust_remote_code=True, resume_download=True,
|
||||
@@ -103,7 +96,7 @@ def _get_input() -> str:
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='QWen-7B-Chat command-line interactive chat demo.')
|
||||
description='QWen-Chat command-line interactive chat demo.')
|
||||
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
|
||||
help="Checkpoint name or path, default to %(default)r")
|
||||
parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
|
||||
@@ -195,7 +188,7 @@ def main():
|
||||
for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
|
||||
_clear_screen()
|
||||
print(f"\nUser: {query}")
|
||||
print(f"\nQwen-7B: {response}")
|
||||
print(f"\nQwen-Chat: {response}")
|
||||
except KeyboardInterrupt:
|
||||
print('[WARNING] Generation interrupted')
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user