FredZhang7 commited on
Commit
402ce86
1 Parent(s): 80c7823

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -11
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import gc, copy, re
4
  import urllib.request
@@ -15,7 +14,7 @@ model = RWKV(model=title, strategy='cpu bf16')
15
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
16
 
17
  def generate_prompt(instruction, input=None, history=None):
18
- # Parse the chat history into a string of user and assistant messages
19
  history_str = ""
20
  for pair in history:
21
  history_str += f"Instruction: {pair[0]}\n\nAssistant: {pair[1]}\n\n"
@@ -51,7 +50,7 @@ def evaluate(
51
  top_p=0.5,
52
  presencePenalty = 0.5,
53
  countPenalty = 0.5,
54
- history=None # Add the history parameter to the evaluate function
55
  ):
56
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
57
  alpha_frequency = countPenalty,
@@ -61,7 +60,7 @@ def evaluate(
61
 
62
  instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
63
  input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
64
- ctx = generate_prompt(instruction, input, history) # Pass the history to the generate_prompt function
65
  print(ctx + "\n")
66
 
67
  all_tokens = []
@@ -119,23 +118,64 @@ with gr.Blocks(title=title) as demo:
119
  gr.Markdown(f"100% RNN RWKV-LM **trained on 100+ natural languages**. Demo limited to ctxlen {ctx_limit}. For best results, <b>keep your prompt short and clear</b>.")
120
  with gr.Row():
121
  with gr.Column():
122
- instruction = gr.Textbox(lines=2, label="Instruction", value="Please show me a table with a cheat sheet of Python's syntax.")
123
  input = gr.Textbox(lines=2, label="Input", placeholder="")
124
  token_count = gr.Slider(10, 512, label="Max Tokens", step=10, value=333)
125
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
126
  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
127
  presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
128
  count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.7)
 
 
 
 
 
129
  data = gr.Dataset(components=[instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
 
 
130
  data.click(lambda x: x, [data], [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty])
131
 
132
  with gr.Tab("Chat mode"):
133
- chatbot = gr.ChatInterface(fn=evaluate,
134
- additional_inputs=[instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty],
135
- additional_inputs_accordion="Parameters",
136
- examples=["Hello", "Write a poem about love", "Generate a list of prime numbers"],
137
- title="RWKV Chatbot",
138
- description="A chatbot that can generate creative and informative content based on instructions and inputs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  demo.queue(max_size=10)
141
  demo.launch(share=False)
 
 
1
  import gradio as gr
2
  import gc, copy, re
3
  import urllib.request
 
14
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
15
 
16
  def generate_prompt(instruction, input=None, history=None):
17
+ # parse the chat history into a string of user and assistant messages
18
  history_str = ""
19
  for pair in history:
20
  history_str += f"Instruction: {pair[0]}\n\nAssistant: {pair[1]}\n\n"
 
50
  top_p=0.5,
51
  presencePenalty = 0.5,
52
  countPenalty = 0.5,
53
+ history=None # add the history parameter to the evaluate function
54
  ):
55
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
56
  alpha_frequency = countPenalty,
 
60
 
61
  instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
62
  input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
63
+ ctx = generate_prompt(instruction, input, history) # pass the history to the generate_prompt function
64
  print(ctx + "\n")
65
 
66
  all_tokens = []
 
118
  gr.Markdown(f"100% RNN RWKV-LM **trained on 100+ natural languages**. Demo limited to ctxlen {ctx_limit}. For best results, <b>keep your prompt short and clear</b>.")
119
  with gr.Row():
120
  with gr.Column():
121
+ instruction = gr.Textbox(lines=2, label="Instruction", value='東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。')
122
  input = gr.Textbox(lines=2, label="Input", placeholder="")
123
  token_count = gr.Slider(10, 512, label="Max Tokens", step=10, value=333)
124
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
125
  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
126
  presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
127
  count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.7)
128
+ with gr.Column():
129
+ with gr.Row():
130
+ submit = gr.Button("Submit", variant="primary")
131
+ clear = gr.Button("Clear", variant="secondary")
132
+ output = gr.Textbox(label="Output", lines=5)
133
  data = gr.Dataset(components=[instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
134
+ submit.click(evaluate, [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty, []], [output])
135
+ clear.click(lambda: None, [], [output])
136
  data.click(lambda x: x, [data], [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty])
137
 
138
  with gr.Tab("Chat mode"):
139
+ with gr.Row():
140
+ chatbot = gr.Chatbot()
141
+ with gr.Column():
142
+ msg = gr.Textbox(scale=4, show_label=False, placeholder="Enter text and press enter", container=False)
143
+ clear = gr.Button("Clear")
144
+ with gr.Column():
145
+ token_count = gr.Slider(10, 512, label="Max Tokens", step=10, value=333)
146
+ temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
147
+ top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
148
+ presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
149
+ count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.7)
150
+
151
+ def clear_chat():
152
+ return "", []
153
+
154
+ def user_msg(message, history):
155
+ history = history or []
156
+ return "", history + [[message, None]]
157
+
158
+ def chat(history):
159
+ # get the last user message and the additional parameters
160
+ message = history[-1][0]
161
+ instruction = msg.value
162
+ token_count = token_count.value
163
+
164
+ temperature = temperature.value
165
+ top_p = top_p.value
166
+ presence_penalty = presence_penalty.value
167
+ count_penalty = count_penalty.value
168
+
169
+ response = evaluate(instruction, None, token_count, temperature, top_p, presence_penalty, count_penalty, history)
170
+
171
+ history[-1][1] = response
172
+ return history
173
+
174
+
175
+ msg.submit(user_msg, [msg, chatbot], [msg, chatbot], queue=False).then(
176
+ chat, chatbot, chatbot, api_name="chat"
177
+ )
178
+ clear.click(clear_chat, None, [chatbot], queue=False)
179
 
180
  demo.queue(max_size=10)
181
  demo.launch(share=False)