mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
update web demo
This commit is contained in:
@@ -34,7 +34,6 @@ Qwen-7Bは、アリババクラウドが提唱する大規模言語モデルシ
|
||||
## ニュース
|
||||
|
||||
* 2023.8.21 Qwen-7B-Chat 用 Int4 量子化モデル(**Qwen-7B-Chat-Int4**)をリリースしました。メモリコストが低く、推論速度も向上しています。また、ベンチマーク評価において大きな性能劣化はありません。
|
||||
|
||||
* 2023.8.3 Qwen-7B と Qwen-7B-Chat を ModelScope と Hugging Face で公開しました。また、トレーニングの詳細やモデルの性能など、モデルの詳細についてはテクニカルメモを提供しています。
|
||||
|
||||
## パフォーマンス
|
||||
|
||||
38
web_demo.py
38
web_demo.py
@@ -4,7 +4,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""A simple web interactive chat demo based on gradio."""
|
||||
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import gradio as gr
|
||||
@@ -44,17 +44,29 @@ def _load_model_tokenizer(args):
|
||||
else:
|
||||
device_map = "auto"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.checkpoint_path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=True,
|
||||
resume_download=True,
|
||||
).eval()
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
|
||||
if os.path.exists(qconfig_path):
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
args.checkpoint_path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=True,
|
||||
resume_download=True,
|
||||
use_safetensors=True,
|
||||
).eval()
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.checkpoint_path,
|
||||
device_map=device_map,
|
||||
trust_remote_code=True,
|
||||
resume_download=True,
|
||||
).eval()
|
||||
|
||||
config = GenerationConfig.from_pretrained(
|
||||
args.checkpoint_path, trust_remote_code=True, resume_download=True,
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
return model, tokenizer, config
|
||||
|
||||
|
||||
def postprocess(self, y):
|
||||
@@ -103,14 +115,14 @@ def _parse_text(text):
|
||||
return text
|
||||
|
||||
|
||||
def _launch_demo(args, model, tokenizer):
|
||||
def _launch_demo(args, model, tokenizer, config):
|
||||
|
||||
def predict(_query, _chatbot, _task_history):
|
||||
print(f"User: {_parse_text(_query)}")
|
||||
_chatbot.append((_parse_text(_query), ""))
|
||||
full_response = ""
|
||||
|
||||
for response in model.chat_stream(tokenizer, _query, history=_task_history):
|
||||
for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
|
||||
_chatbot[-1] = (_parse_text(_query), _parse_text(response))
|
||||
|
||||
yield _chatbot
|
||||
@@ -183,9 +195,9 @@ including hate speech, violence, pornography, deception, etc. \
|
||||
def main():
|
||||
args = _get_args()
|
||||
|
||||
model, tokenizer = _load_model_tokenizer(args)
|
||||
model, tokenizer, config = _load_model_tokenizer(args)
|
||||
|
||||
_launch_demo(args, model, tokenizer)
|
||||
_launch_demo(args, model, tokenizer, config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user