Mirror of https://github.com/QwenLM/Qwen.git (synced 2026-05-20 16:35:47 +08:00)
openai_api: support temperature=0
This commit is contained in:
@@ -321,7 +321,7 @@ def parse_response(response):
|
|||||||
|
|
||||||
|
|
||||||
# completion mode, not chat mode
|
# completion mode, not chat mode
|
||||||
def text_complete_last_message(history, stop_words_ids):
|
def text_complete_last_message(history, stop_words_ids, gen_kwargs):
|
||||||
im_start = "<|im_start|>"
|
im_start = "<|im_start|>"
|
||||||
im_end = "<|im_end|>"
|
im_end = "<|im_end|>"
|
||||||
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
|
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
|
||||||
@@ -339,7 +339,7 @@ def text_complete_last_message(history, stop_words_ids):
|
|||||||
stop_words_ids = _stop_words_ids
|
stop_words_ids = _stop_words_ids
|
||||||
|
|
||||||
input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
|
input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
|
||||||
output = model.generate(input_ids, stop_words_ids=stop_words_ids).tolist()[0]
|
output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
|
||||||
output = tokenizer.decode(output, errors="ignore")
|
output = tokenizer.decode(output, errors="ignore")
|
||||||
assert output.startswith(prompt)
|
assert output.startswith(prompt)
|
||||||
output = output[len(prompt) :]
|
output = output[len(prompt) :]
|
||||||
@@ -352,6 +352,16 @@ def text_complete_last_message(history, stop_words_ids):
|
|||||||
async def create_chat_completion(request: ChatCompletionRequest):
|
async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
global model, tokenizer
|
global model, tokenizer
|
||||||
|
|
||||||
|
gen_kwargs = {}
|
||||||
|
if request.temperature is not None:
|
||||||
|
if request.temperature < 0.01:
|
||||||
|
gen_kwargs['top_k'] = 1 # greedy decoding
|
||||||
|
else:
|
||||||
|
# Not recommended. Please tune top_p instead.
|
||||||
|
gen_kwargs['temperature'] = request.temperature
|
||||||
|
if request.top_p is not None:
|
||||||
|
gen_kwargs['top_p'] = request.top_p
|
||||||
|
|
||||||
stop_words = add_extra_stop_words(request.stop)
|
stop_words = add_extra_stop_words(request.stop)
|
||||||
if request.functions:
|
if request.functions:
|
||||||
stop_words = stop_words or []
|
stop_words = stop_words or []
|
||||||
@@ -366,12 +376,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
detail="Invalid request: Function calling is not yet implemented for stream mode.",
|
detail="Invalid request: Function calling is not yet implemented for stream mode.",
|
||||||
)
|
)
|
||||||
generate = predict(query, history, request.model, stop_words)
|
generate = predict(query, history, request.model, stop_words, gen_kwargs)
|
||||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||||
|
|
||||||
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
|
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
|
||||||
if query is _TEXT_COMPLETION_CMD:
|
if query is _TEXT_COMPLETION_CMD:
|
||||||
response = text_complete_last_message(history, stop_words_ids=stop_words_ids)
|
response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
|
||||||
else:
|
else:
|
||||||
response, _ = model.chat(
|
response, _ = model.chat(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -379,6 +389,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||||||
history=history,
|
history=history,
|
||||||
stop_words_ids=stop_words_ids,
|
stop_words_ids=stop_words_ids,
|
||||||
append_history=False,
|
append_history=False,
|
||||||
|
**gen_kwargs
|
||||||
)
|
)
|
||||||
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
||||||
response = trim_stop_words(response, stop_words)
|
response = trim_stop_words(response, stop_words)
|
||||||
@@ -396,7 +407,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||||||
|
|
||||||
|
|
||||||
async def predict(
|
async def predict(
|
||||||
query: str, history: List[List[str]], model_id: str, stop_words: List[str]
|
query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
|
||||||
):
|
):
|
||||||
global model, tokenizer
|
global model, tokenizer
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
@@ -416,7 +427,7 @@ async def predict(
|
|||||||
detail="Invalid request: custom stop words are not yet supported for stream mode.",
|
detail="Invalid request: custom stop words are not yet supported for stream mode.",
|
||||||
)
|
)
|
||||||
response_generator = model.chat_stream(
|
response_generator = model.chat_stream(
|
||||||
tokenizer, query, history=history, stop_words_ids=stop_words_ids
|
tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
|
||||||
)
|
)
|
||||||
for new_response in response_generator:
|
for new_response in response_generator:
|
||||||
if len(new_response) == current_length:
|
if len(new_response) == current_length:
|
||||||
|
|||||||
Reference in New Issue
Block a user