Update app.py
Browse files
app.py
CHANGED
@@ -3,19 +3,22 @@ import time
|
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
|
6 |
-
from strings import TITLE, ABSTRACT
|
7 |
from gen import get_pretrained_models, get_output
|
8 |
|
9 |
-
os.environ["RANK"] = "0"
|
10 |
-
os.environ["WORLD_SIZE"] = "1"
|
11 |
-
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
12 |
-
os.environ["MASTER_PORT"] = "50505"
|
13 |
-
|
14 |
generator = get_pretrained_models("13B", "tokenizer")
|
15 |
|
16 |
history = []
|
17 |
|
18 |
-
def chat(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
bot_response = get_output(
|
20 |
generator=generator,
|
21 |
prompt=user_input,
|
@@ -24,13 +27,16 @@ def chat(user_input, top_p, temperature, max_gen_len, state_chatbot):
|
|
24 |
top_p=top_p)
|
25 |
|
26 |
# remove the first phrase identical to user prompt
|
27 |
-
|
28 |
-
|
|
|
|
|
29 |
# trip the last phrase
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
34 |
|
35 |
history.append({
|
36 |
"role": "user",
|
@@ -61,6 +67,14 @@ with gr.Blocks(css = """#col_container {width: 95%; margin-left: auto; margin-ri
|
|
61 |
|
62 |
with gr.Column(elem_id='col_container'):
|
63 |
gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
chatbot = gr.Chatbot(elem_id='chatbot')
|
65 |
textbox = gr.Textbox(placeholder="Enter a prompt")
|
66 |
|
@@ -69,7 +83,11 @@ with gr.Blocks(css = """#col_container {width: 95%; margin-left: auto; margin-ri
|
|
69 |
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
|
70 |
temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
71 |
|
72 |
-
textbox.submit(
|
|
|
|
|
|
|
|
|
73 |
textbox.submit(reset_textbox, [], [textbox])
|
74 |
|
75 |
demo.queue(api_open=False).launch()
|
|
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
|
6 |
+
from strings import TITLE, ABSTRACT, EXAMPLES
|
7 |
from gen import get_pretrained_models, get_output
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
generator = get_pretrained_models("13B", "tokenizer")
|
10 |
|
11 |
history = []
|
12 |
|
13 |
+
def chat(
|
14 |
+
user_input,
|
15 |
+
include_input,
|
16 |
+
truncate,
|
17 |
+
top_p,
|
18 |
+
temperature,
|
19 |
+
max_gen_len,
|
20 |
+
state_chatbot
|
21 |
+
):
|
22 |
bot_response = get_output(
|
23 |
generator=generator,
|
24 |
prompt=user_input,
|
|
|
27 |
top_p=top_p)
|
28 |
|
29 |
# remove the first phrase identical to user prompt
|
30 |
+
if not include_input:
|
31 |
+
bot_response = bot_response[0][len(user_input):]
|
32 |
+
bot_response = bot_response.replace("\n", "<br>")
|
33 |
+
|
34 |
# trip the last phrase
|
35 |
+
if truncate:
|
36 |
+
try:
|
37 |
+
bot_response = bot_response[:bot_response.rfind(".")+1]
|
38 |
+
except:
|
39 |
+
pass
|
40 |
|
41 |
history.append({
|
42 |
"role": "user",
|
|
|
67 |
|
68 |
with gr.Column(elem_id='col_container'):
|
69 |
gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
|
70 |
+
|
71 |
+
with gr.Accordion("Example prompts", open=False):
|
72 |
+
example_str = "\n"
|
73 |
+
for example in EXAMPLES:
|
74 |
+
example_str += f"- {example}\n"
|
75 |
+
|
76 |
+
gr.Markdown(example_str)
|
77 |
+
|
78 |
chatbot = gr.Chatbot(elem_id='chatbot')
|
79 |
textbox = gr.Textbox(placeholder="Enter a prompt")
|
80 |
|
|
|
83 |
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
|
84 |
temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
85 |
|
86 |
+
textbox.submit(
|
87 |
+
chat,
|
88 |
+
[textbox, include_input, truncate, top_p, temperature, max_gen_len, state_chatbot],
|
89 |
+
[state_chatbot, chatbot]
|
90 |
+
)
|
91 |
textbox.submit(reset_textbox, [], [textbox])
|
92 |
|
93 |
demo.queue(api_open=False).launch()
|