Files
Qwen/demo.py
2023-08-04 15:45:27 +08:00

84 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_utils import set_seed
def _load_model_tokenizer(args):
    """Load the tokenizer and causal-LM model for the checkpoint in ``args``.

    Args:
        args: Parsed CLI namespace. Reads ``args.checkpoint_path`` (passed to
            ``from_pretrained``) and ``args.cpu_only``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True,
    )
    print("load tokenizer")

    if args.cpu_only or not torch.cuda.is_available():
        if not args.cpu_only:
            # Without this fallback, torch.cuda.mem_get_info() below would
            # raise on a machine with no usable CUDA device.
            print("CUDA is not available, falling back to CPU")
        device_map = "cpu"
        max_memory = None
    else:
        device_map = "auto"
        # Budget every GPU by its *own* free memory (devices may differ),
        # leaving 2 GiB of headroom and never producing a negative budget.
        max_memory = {}
        for device_idx in range(torch.cuda.device_count()):
            free_bytes, _total_bytes = torch.cuda.mem_get_info(device_idx)
            budget_gib = max(free_bytes // 1024 ** 3 - 2, 0)
            max_memory[device_idx] = f"{budget_gib}GB"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        max_memory=max_memory,
        trust_remote_code=True,
    ).eval()
    return model, tokenizer
def demo_qwen_pretrain(args):
    """Run one few-shot completion with a base (non-chat) checkpoint.

    The prompt gives two country/capital pairs and leaves the third capital
    open for the model to complete. The decoded generation is printed.

    Args:
        args: Parsed CLI namespace, forwarded to ``_load_model_tokenizer``.

    Returns:
        The decoded generated sequence (for a causal LM this presumably
        includes the prompt), as a string.
    """
    model, tokenizer = _load_model_tokenizer(args)
    inputs = tokenizer(
        "蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是",
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    # Forward the tokenizer's attention mask together with the ids so that
    # generate() does not have to infer one; .get() tolerates its absence.
    pred = model.generate(
        inputs=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
    )
    text = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
    print(text)
    return text
def demo_qwen_chat(args):
    """Hold a short multi-turn conversation with a chat checkpoint.

    Each query goes through ``model.chat`` along with the history returned
    by the previous turn, so later turns are conditioned on earlier ones.
    The query and response of every turn are printed.

    Args:
        args: Parsed CLI namespace, forwarded to ``_load_model_tokenizer``.
    """
    model, tokenizer = _load_model_tokenizer(args)
    conversation = None  # no prior turns before the first query
    prompts = (
        "请问把大象关冰箱总共要几步?",
        "1+3=?",
        "请将下面这句话翻译为英文:在哪里跌倒就在哪里趴着",
    )
    turn = 0
    for prompt in prompts:
        turn += 1
        reply, conversation = model.chat(tokenizer, prompt, history=conversation)
        print(f"===== Turn {turn} ====")
        print("Query:", prompt)
        print("Response:", reply)
def main(argv=None):
    """Parse CLI arguments, seed RNGs, and run the matching demo.

    Checkpoints whose path contains "chat" (case-insensitive) go to the
    multi-turn chat demo; all others go to the plain completion demo.

    Args:
        argv: Optional argument list for argparse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    # required=True: the routing below calls .lower() on the path, so a
    # missing value must be rejected by argparse rather than crash later.
    parser.add_argument(
        "-c", "--checkpoint-path", type=str, required=True, help="Checkpoint path",
    )
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args(argv)

    set_seed(args.seed)

    if "chat" in args.checkpoint_path.lower():
        demo_qwen_chat(args)
    else:
        demo_qwen_pretrain(args)
if __name__ == "__main__":
    # Executed directly as a script (not imported): run the CLI demo.
    main()