Update openai_api.py

This commit is contained in:
Junyang Lin
2023-08-13 12:28:05 +08:00
committed by GitHub
parent 3006ef34e9
commit dfced1aec0

View File

@@ -35,6 +35,7 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
class ModelCard(BaseModel): class ModelCard(BaseModel):
id: str id: str
object: str = "model" object: str = "model"
@@ -136,7 +137,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None finish_reason=None
) )
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
#yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.model_dump_json(exclude_unset=True)) yield "{}".format(chunk.model_dump_json(exclude_unset=True))
current_length = 0 current_length = 0
@@ -154,7 +154,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason=None finish_reason=None
) )
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
#yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.model_dump_json(exclude_unset=True)) yield "{}".format(chunk.model_dump_json(exclude_unset=True))
@@ -164,13 +163,12 @@ async def predict(query: str, history: List[List[str]], model_id: str):
finish_reason="stop" finish_reason="stop"
) )
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
#yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield "{}".format(chunk.model_dump_json(exclude_unset=True)) yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]' yield '[DONE]'
def _get_args(): def _get_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat',
help="Checkpoint name or path, default to %(default)r") help="Checkpoint name or path, default to %(default)r")
parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
parser.add_argument("--server-port", type=int, default=8000, parser.add_argument("--server-port", type=int, default=8000,
@@ -181,7 +179,6 @@ def _get_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'
if __name__ == "__main__": if __name__ == "__main__":
args = _get_args() args = _get_args()