diff --git a/web_demo.py b/web_demo.py
index 8b37b02..9f812cf 100755
--- a/web_demo.py
+++ b/web_demo.py
@@ -9,12 +9,23 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import sys
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
-model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+ "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
+)
+model = AutoModelForCausalLM.from_pretrained(
+ "Qwen/Qwen-7B-Chat",
+ device_map="auto",
+ offload_folder="offload",
+ trust_remote_code=True,
+ resume_download=True,
+).eval()
+model.generation_config = GenerationConfig.from_pretrained(
+ "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
+)
if len(sys.argv) > 1 and sys.argv[1] == "--exit":
- exit(0)
+ sys.exit(0)
+
def postprocess(self, y):
if y is None:
@@ -58,21 +69,26 @@ def parse_text(text):
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
- lines[i] = "
"+line
+ lines[i] = "
" + line
text = "".join(lines)
return text
-def predict(input, chatbot, history, past_key_values):
+
+task_history = []
+
+
+def predict(input, chatbot):
print('Q: ' + parse_text(input))
chatbot.append((parse_text(input), ""))
- fullResponse = "";
+ fullResponse = ""
- for response in model.chat(tokenizer, input, history=history, stream=True):
+ for response in model.chat(tokenizer, input, history=task_history, stream=True):
chatbot[-1] = (parse_text(input), parse_text(response))
- yield chatbot, history, past_key_values
- fullResponse = parse_text(response);
-
+ yield chatbot
+ fullResponse = parse_text(response)
+
+ task_history.append((input, fullResponse))
print("A: " + parse_text(fullResponse))
@@ -81,7 +97,8 @@ def reset_user_input():
def reset_state():
- return [], [], None
+ task_history = []
+ return []
with gr.Blocks() as demo:
@@ -91,19 +108,16 @@ with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
- user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
- container=False)
+ query = gr.Textbox(
+ show_label=False, placeholder="Input...", lines=10
+ ).style(container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
- history = gr.State([])
- past_key_values = gr.State(None)
-
- submitBtn.click(predict, [user_input, chatbot, history, past_key_values],
- [chatbot, history, past_key_values], show_progress=True)
- submitBtn.click(reset_user_input, [], [user_input])
- emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
+ submitBtn.click(predict, [query, chatbot], [chatbot], show_progress=True)
+ submitBtn.click(reset_user_input, [], [query])
+ emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
demo.queue().launch(share=False, inbrowser=True, server_port=80, server_name="0.0.0.0")