mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
Update README_CN.md, update batch infer
This commit is contained in:
33
README_CN.md
33
README_CN.md
@@ -364,15 +364,16 @@ import torch
|
|||||||
from tokenization_qwen import QWenTokenizer
|
from tokenization_qwen import QWenTokenizer
|
||||||
from modeling_qwen import QWenLMHeadModel
|
from modeling_qwen import QWenLMHeadModel
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig
|
||||||
from qwen_generation_utils import make_context
|
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
|
||||||
|
|
||||||
|
|
||||||
tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
|
tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
|
||||||
model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
|
model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
|
||||||
model.generation_config = GenerationConfig.from_pretrained('./')
|
model.generation_config = GenerationConfig.from_pretrained('./')
|
||||||
|
stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
|
||||||
|
|
||||||
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
|
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
|
||||||
batch_question = []
|
batch_raw_text = []
|
||||||
for q in all_raw_text:
|
for q in all_raw_text:
|
||||||
raw_text, _ = make_context(
|
raw_text, _ = make_context(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -381,17 +382,29 @@ for q in all_raw_text:
|
|||||||
max_window_size=model.generation_config.max_window_size,
|
max_window_size=model.generation_config.max_window_size,
|
||||||
chat_format=model.generation_config.chat_format,
|
chat_format=model.generation_config.chat_format,
|
||||||
)
|
)
|
||||||
batch_question.append(raw_text)
|
batch_raw_text.append(raw_text)
|
||||||
|
|
||||||
batch_input_ids = tokenizer(batch_question, padding='longest')
|
batch_input_ids = tokenizer(batch_raw_text, padding='longest')
|
||||||
print(batch_input_ids)
|
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
|
||||||
|
|
||||||
batch_input_ids1 = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
|
|
||||||
batch_out_ids = model.generate(
|
batch_out_ids = model.generate(
|
||||||
input_ids=batch_input_ids1
|
batch_input_ids,
|
||||||
,return_dict_in_generate=False
|
stop_words_ids=stop_words_ids,
|
||||||
|
return_dict_in_generate=False,
|
||||||
|
generation_config=model.generation_config
|
||||||
)
|
)
|
||||||
batch_response = [tokenizer.decode(o, skip_special_tokens=True) for o in batch_out_ids]
|
padding_lens = [batch_input_ids[i].eq(tokenizer.pad_token_id).sum().item() for i in range(batch_input_ids.size(0))]
|
||||||
|
|
||||||
|
batch_response = [
|
||||||
|
decode_tokens(
|
||||||
|
batch_out_ids[i][padding_lens[i]:],
|
||||||
|
tokenizer,
|
||||||
|
raw_text_len=len(batch_raw_text[i]),
|
||||||
|
context_length=batch_input_ids[i].size(0),
|
||||||
|
chat_format="chatml",
|
||||||
|
verbose=False,
|
||||||
|
errors='replace'
|
||||||
|
) for i in range(len(all_raw_text))
|
||||||
|
]
|
||||||
print(batch_response)
|
print(batch_response)
|
||||||
|
|
||||||
response, _ = model.chat(tokenizer, "我想听你说爱我。", history=None)
|
response, _ = model.chat(tokenizer, "我想听你说爱我。", history=None)
|
||||||
|
|||||||
Reference in New Issue
Block a user