mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-21 00:45:48 +08:00
Update demo.py
This commit is contained in:
7
demo.py
7
demo.py
@@ -5,8 +5,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +67,7 @@ def demo_qwen_chat(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Test HF checkpoint.")
|
parser = argparse.ArgumentParser(description="Test HF checkpoint.")
|
||||||
parser.add_argument("-c", "--checkpoint-path", type=Path, help="Checkpoint path")
|
parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path")
|
||||||
parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
|
parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
|
||||||
parser.add_argument("--gpu", type=int, default=0, help="gpu id")
|
parser.add_argument("--gpu", type=int, default=0, help="gpu id")
|
||||||
|
|
||||||
@@ -78,4 +77,4 @@ if __name__ == "__main__":
|
|||||||
if 'chat' in args.checkpoint_path.lower():
|
if 'chat' in args.checkpoint_path.lower():
|
||||||
demo_qwen_chat(args)
|
demo_qwen_chat(args)
|
||||||
else:
|
else:
|
||||||
demo_qwen_pretrain(args)
|
demo_qwen_pretrain(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user