Update deployment in README and cli_demo

This commit is contained in:
yangapku
2023-08-29 16:46:15 +08:00
parent 910700571d
commit f1402ce523
7 changed files with 138 additions and 12 deletions

View File

@@ -46,16 +46,29 @@ def _load_model_tokenizer(args):
else:
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
if os.path.exists(qconfig_path):
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
use_safetensors=True,
).eval()
else:
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
resume_download=True,
).eval()
config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True, resume_download=True,
)
return model, tokenizer
return model, tokenizer, config
def _clear_screen():
@@ -99,7 +112,7 @@ def main():
history, response = [], ''
model, tokenizer = _load_model_tokenizer(args)
model, tokenizer, config = _load_model_tokenizer(args)
orig_gen_config = deepcopy(model.generation_config)
_clear_screen()
@@ -179,7 +192,7 @@ def main():
# Run chat.
set_seed(seed)
try:
for response in model.chat_stream(tokenizer, query, history=history):
for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
_clear_screen()
print(f"\nUser: {query}")
print(f"\nQwen-7B: {response}")