mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 16:35:47 +08:00
add stop word on openai api ChatCompletion
This commit is contained in:
@@ -319,6 +319,7 @@ for chunk in openai.ChatCompletion.create(
|
||||
{"role": "user", "content": "你好"}
|
||||
],
|
||||
stream=True
|
||||
# Specifying stop words in streaming output format is not yet supported and is under development.
|
||||
):
|
||||
if hasattr(chunk.choices[0].delta, "content"):
|
||||
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
@@ -329,7 +330,8 @@ response = openai.ChatCompletion.create(
|
||||
messages=[
|
||||
{"role": "user", "content": "你好"}
|
||||
],
|
||||
stream=False
|
||||
stream=False,
|
||||
stop=[] # You can add custom stop words here, e.g., stop=["Observation:"] for ReAct prompting.
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
@@ -323,6 +323,7 @@ for chunk in openai.ChatCompletion.create(
|
||||
{"role": "user", "content": "你好"}
|
||||
],
|
||||
stream=True
|
||||
# 流式输出的自定义stopwords功能尚未支持,正在开发中
|
||||
):
|
||||
if hasattr(chunk.choices[0].delta, "content"):
|
||||
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
@@ -333,7 +334,8 @@ response = openai.ChatCompletion.create(
|
||||
messages=[
|
||||
{"role": "user", "content": "你好"}
|
||||
],
|
||||
stream=False
|
||||
stream=False,
|
||||
stop=[] # 在此处添加自定义的stop words 例如ReAct prompting时需要增加: stop=["Observation:"]。
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
@@ -68,6 +68,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
top_p: Optional[float] = None
|
||||
max_length: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[List[str]] = []
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
@@ -103,7 +104,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if request.messages[-1].role != "user":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
query = request.messages[-1].content
|
||||
|
||||
stop_words = request.stop
|
||||
stop_words.extend(list(map(lambda x: x[1:], filter(lambda x: x.startswith("\n"), stop_words))))
|
||||
prev_messages = request.messages[:-1]
|
||||
# Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
|
||||
# if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
||||
@@ -120,10 +122,18 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
raise HTTPException(status_code=400, detail="Invalid request.")
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, request.model)
|
||||
generate = predict(query, history, request.model, stop_words)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
if stop_words:
|
||||
react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
|
||||
response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
|
||||
for stop_ in stop_words:
|
||||
if response.endswith(stop_):
|
||||
response = response[:response.find(stop_)]
|
||||
else:
|
||||
response, _ = model.chat(tokenizer, query, history=history)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content=response),
|
||||
@@ -133,9 +143,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
|
||||
|
||||
|
||||
async def predict(query: str, history: List[List[str]], model_id: str):
|
||||
async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
|
||||
global model, tokenizer
|
||||
|
||||
assert stop_words == [], "in stream format, stop word is output"
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
@@ -145,8 +155,13 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
|
||||
current_length = 0
|
||||
if stop_words:
|
||||
react_stop_words_tokens = [tokenizer.encode(stop_) for stop_ in stop_words]
|
||||
response_generator = model.chat_stream(tokenizer, query, history=history, stop_words_ids=react_stop_words_tokens)
|
||||
else:
|
||||
response_generator = model.chat_stream(tokenizer, query, history=history)
|
||||
|
||||
for new_response in model.chat_stream(tokenizer, query, history):
|
||||
for new_response in response_generator:
|
||||
if len(new_response) == current_length:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user