mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
openai_api: support temperature=0
This commit is contained in:
@@ -321,7 +321,7 @@ def parse_response(response):
|
||||
|
||||
|
||||
# completion mode, not chat mode
|
||||
def text_complete_last_message(history, stop_words_ids):
|
||||
def text_complete_last_message(history, stop_words_ids, gen_kwargs):
|
||||
im_start = "<|im_start|>"
|
||||
im_end = "<|im_end|>"
|
||||
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
|
||||
@@ -339,7 +339,7 @@ def text_complete_last_message(history, stop_words_ids):
|
||||
stop_words_ids = _stop_words_ids
|
||||
|
||||
input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
|
||||
output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
|
||||
output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
|
||||
output = tokenizer.decode(output, errors="ignore")
|
||||
assert output.startswith(prompt)
|
||||
output = output[len(prompt) :]
|
||||
@@ -352,6 +352,16 @@ def text_complete_last_message(history, stop_words_ids):
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
global model, tokenizer
|
||||
|
||||
gen_kwargs = {}
|
||||
if request.temperature is not None:
|
||||
if request.temperature < 0.01:
|
||||
gen_kwargs['top_k'] = 1 # greedy decoding
|
||||
else:
|
||||
# Not recommended. Please tune top_p instead.
|
||||
gen_kwargs['temperature'] = request.temperature
|
||||
if request.top_p is not None:
|
||||
gen_kwargs['top_p'] = request.top_p
|
||||
|
||||
stop_words = add_extra_stop_words(request.stop)
|
||||
if request.functions:
|
||||
stop_words = stop_words or []
|
||||
@@ -366,12 +376,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
status_code=400,
|
||||
detail="Invalid request: Function calling is not yet implemented for stream mode.",
|
||||
)
|
||||
generate = predict(query, history, request.model, stop_words)
|
||||
generate = predict(query, history, request.model, stop_words, gen_kwargs)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
|
||||
if query is _TEXT_COMPLETION_CMD:
|
||||
response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
|
||||
response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
|
||||
else:
|
||||
response, _ = model.chat(
|
||||
tokenizer,
|
||||
@@ -379,6 +389,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
history=history,
|
||||
stop_words_ids=stop_words_ids,
|
||||
append_history=False,
|
||||
**gen_kwargs
|
||||
)
|
||||
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
||||
response = trim_stop_words(response, stop_words)
|
||||
@@ -396,7 +407,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
|
||||
|
||||
async def predict(
|
||||
query: str, history: List[List[str]], model_id: str, stop_words: List[str]
|
||||
query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
|
||||
):
|
||||
global model, tokenizer
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
@@ -416,7 +427,7 @@ async def predict(
|
||||
detail="Invalid request: custom stop words are not yet supported for stream mode.",
|
||||
)
|
||||
response_generator = model.chat_stream(
|
||||
tokenizer, query, history=history, stop_words_ids=stop_words_ids
|
||||
tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
|
||||
)
|
||||
for new_response in response_generator:
|
||||
if len(new_response) == current_length:
|
||||
|
||||
Reference in New Issue
Block a user