Update demo.py

This commit is contained in:
Junyang Lin
2023-08-03 13:43:07 +08:00
committed by GitHub
parent f88d73dc03
commit 1f96081b74

View File

@@ -5,8 +5,7 @@
import torch import torch
import argparse import argparse
from pathlib import Path from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.trainer_utils import set_seed from transformers.trainer_utils import set_seed
@@ -68,7 +67,7 @@ def demo_qwen_chat(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test HF checkpoint.") parser = argparse.ArgumentParser(description="Test HF checkpoint.")
parser.add_argument("-c", "--checkpoint-path", type=Path, help="Checkpoint path") parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path")
parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed") parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
parser.add_argument("--gpu", type=int, default=0, help="gpu id") parser.add_argument("--gpu", type=int, default=0, help="gpu id")
@@ -78,4 +77,4 @@ if __name__ == "__main__":
if 'chat' in args.checkpoint_path.lower(): if 'chat' in args.checkpoint_path.lower():
demo_qwen_chat(args) demo_qwen_chat(args)
else: else:
demo_qwen_pretrain(args) demo_qwen_pretrain(args)