mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
update readme about batch inference
This commit is contained in:
27
README.md
27
README.md
@@ -371,18 +371,26 @@ you can use the dequantization operation to convert the int8 key/value back to t
|
|||||||
|
|
||||||
## Batch Inference
|
## Batch Inference
|
||||||
Qwen supports batch inference. With flash-attention enabled, batch inference can bring a speedup of about 40%. The example code is shown below:
|
Qwen supports batch inference. With flash-attention enabled, batch inference can bring a speedup of about 40%. The example code is shown below:
|
||||||
```
|
```python
|
||||||
import torch
|
import torch
|
||||||
from tokenization_qwen import QWenTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from modeling_qwen import QWenLMHeadModel
|
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig
|
||||||
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
|
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
|
'./',
|
||||||
model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
|
pad_token='<|extra_0|>',
|
||||||
model.generation_config = GenerationConfig.from_pretrained('./')
|
eos_token='<|endoftext|>',
|
||||||
stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
|
padding_side='left',
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
'./',
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
|
).eval()
|
||||||
|
model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)
|
||||||
|
|
||||||
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
|
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
|
||||||
batch_raw_text = []
|
batch_raw_text = []
|
||||||
@@ -400,7 +408,6 @@ batch_input_ids = tokenizer(batch_raw_text, padding='longest')
|
|||||||
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
|
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
|
||||||
batch_out_ids = model.generate(
|
batch_out_ids = model.generate(
|
||||||
batch_input_ids,
|
batch_input_ids,
|
||||||
stop_words_ids=stop_words_ids,
|
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
generation_config=model.generation_config
|
generation_config=model.generation_config
|
||||||
)
|
)
|
||||||
@@ -411,7 +418,7 @@ batch_response = [
|
|||||||
batch_out_ids[i][padding_lens[i]:],
|
batch_out_ids[i][padding_lens[i]:],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
raw_text_len=len(batch_raw_text[i]),
|
raw_text_len=len(batch_raw_text[i]),
|
||||||
context_length=batch_input_ids[i].size(0),
|
context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
|
||||||
chat_format="chatml",
|
chat_format="chatml",
|
||||||
verbose=False,
|
verbose=False,
|
||||||
errors='replace'
|
errors='replace'
|
||||||
|
|||||||
27
README_CN.md
27
README_CN.md
@@ -359,18 +359,26 @@ model = AutoModelForCausalLM.from_pretrained(
|
|||||||
|
|
||||||
## Batch推理
|
## Batch推理
|
||||||
千问支持batch批量推理。在开启flash-attention的状态下,使用batch推理可以带来约40%的提速。示例代码如下所示:
|
千问支持batch批量推理。在开启flash-attention的状态下,使用batch推理可以带来约40%的提速。示例代码如下所示:
|
||||||
```
|
```python
|
||||||
import torch
|
import torch
|
||||||
from tokenization_qwen import QWenTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from modeling_qwen import QWenLMHeadModel
|
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig
|
||||||
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
|
from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
tokenizer = QWenTokenizer.from_pretrained('./', pad_token='<|extra_0|>', eos_token='<|endoftext|>', padding_side='left')
|
'./',
|
||||||
model = QWenLMHeadModel.from_pretrained('./', device_map="auto").eval()
|
pad_token='<|extra_0|>',
|
||||||
model.generation_config = GenerationConfig.from_pretrained('./')
|
eos_token='<|endoftext|>',
|
||||||
stop_words_ids = get_stop_words_ids(model.generation_config.chat_format, tokenizer)
|
padding_side='left',
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
'./',
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
|
).eval()
|
||||||
|
model.generation_config = GenerationConfig.from_pretrained('./', pad_token_id=tokenizer.pad_token_id)
|
||||||
|
|
||||||
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
|
all_raw_text = ["我想听你说爱我。", "今天我想吃点啥,甜甜的,推荐下", "我马上迟到了,怎么做才能不迟到"]
|
||||||
batch_raw_text = []
|
batch_raw_text = []
|
||||||
@@ -388,7 +396,6 @@ batch_input_ids = tokenizer(batch_raw_text, padding='longest')
|
|||||||
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
|
batch_input_ids = torch.LongTensor(batch_input_ids['input_ids']).to(model.device)
|
||||||
batch_out_ids = model.generate(
|
batch_out_ids = model.generate(
|
||||||
batch_input_ids,
|
batch_input_ids,
|
||||||
stop_words_ids=stop_words_ids,
|
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
generation_config=model.generation_config
|
generation_config=model.generation_config
|
||||||
)
|
)
|
||||||
@@ -399,7 +406,7 @@ batch_response = [
|
|||||||
batch_out_ids[i][padding_lens[i]:],
|
batch_out_ids[i][padding_lens[i]:],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
raw_text_len=len(batch_raw_text[i]),
|
raw_text_len=len(batch_raw_text[i]),
|
||||||
context_length=batch_input_ids[i].size(0),
|
context_length=(batch_input_ids[i].size(0)-padding_lens[i]),
|
||||||
chat_format="chatml",
|
chat_format="chatml",
|
||||||
verbose=False,
|
verbose=False,
|
||||||
errors='replace'
|
errors='replace'
|
||||||
|
|||||||
Reference in New Issue
Block a user