better install shell based on issue comment

This commit is contained in:
wsl-wy
2023-08-08 00:45:59 +08:00
parent ad66116fe5
commit 92c5c47a4c
2 changed files with 77 additions and 42 deletions

View File

@@ -7,21 +7,14 @@ import gradio as gr
import mdtex2html
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from argparse import ArgumentParser
import sys
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
)
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen-7B-Chat",
device_map="auto",
offload_folder="offload",
trust_remote_code=True,
resume_download=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
"Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", offload_folder="offload", trust_remote_code=True, resume_download=True).eval()
model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True)
if len(sys.argv) > 1 and sys.argv[1] == "--exit":
sys.exit(0)
@@ -82,7 +75,7 @@ def predict(input, chatbot):
chatbot.append((parse_text(input), ""))
fullResponse = ""
for response in model.chat(tokenizer, input, history=task_history, stream=True):
for response in model.chat_stream(tokenizer, input, history=task_history):
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot
@@ -108,9 +101,7 @@ with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
query = gr.Textbox(
show_label=False, placeholder="Input...", lines=10
).style(container=False)
query = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
@@ -120,4 +111,18 @@ with gr.Blocks() as demo:
submitBtn.click(reset_user_input, [], [query])
emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
demo.queue().launch(share=False, inbrowser=True, server_port=80, server_name="0.0.0.0")
if len(sys.argv) > 1:
print("Call args:" + str(sys.argv))
parser = ArgumentParser()
parser.add_argument("--share", action="store_true", default=False)
parser.add_argument("--inbrowser", action="store_true", default=False)
parser.add_argument("--server_port", type=int, default=80)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
args = parser.parse_args(sys.argv[1:])
print("Args:" + str(args))
print("Args:" + str(args))
demo.queue().launch(args)
else:
demo.queue().launch()