FredZhang7 committed on
Commit
c470f73
1 Parent(s): 5b27e8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -52
app.py CHANGED
@@ -8,7 +8,7 @@ ctx_limit = 4096
8
  title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
9
 
10
  model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
11
- model = RWKV(model=model_path, strategy='cpu bf16')
12
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
13
 
14
 
@@ -20,8 +20,15 @@ def generate_prompt(instruction, input=None, history=None):
20
  for pair in history:
21
  history_str += f"User: {pair[0]}\n\nAssistant: {pair[1]}\n\n"
22
 
23
- instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n').replace('\n\n','\n')
24
- input = input.strip().replace('\r\n','\n').replace('\n\n','\n').replace('\n\n','\n')
 
 
 
 
 
 
 
25
  if input and len(input) > 0:
26
  return f"""{history_str}Instruction: {instruction}
27
 
@@ -36,17 +43,50 @@ Assistant:"""
36
 
37
  examples = [
38
  ["東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", "", 300, 1.2, 0.5, 0.5, 0.5],
39
- ["Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires.", "", 300, 1.2, 0.5, 0.5, 0.5],
 
 
 
 
 
 
 
 
40
  ["Write a song about ravens.", "", 300, 1.2, 0.5, 0.5, 0.5],
41
  ["Explain the following metaphor: Life is like cats.", "", 300, 1.2, 0.5, 0.5, 0.5],
42
- ["Write a story using the following information", "A man named Alex chops a tree down", 300, 1.2, 0.5, 0.5, 0.5],
43
- ["Generate a list of adjectives that describe a person as brave.", "", 300, 1.2, 0.5, 0.5, 0.5],
44
- ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.", "", 300, 1.2, 0.5, 0.5, 0.5],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ]
46
 
 
47
  def respond(history=None):
48
  global token_count_chat, temperature_chat, top_p_chat, presence_penalty_chat, count_penalty_chat
49
-
50
  # get the latest user message and the additional parameters
51
  instruction = msg.value
52
  token_count = token_count_chat.value
@@ -57,42 +97,58 @@ def respond(history=None):
57
  count_penalty = count_penalty_chat.value
58
 
59
  history[-1][1] = ""
60
-
61
- for character in generator(instruction, None, token_count, temperature, top_p, presence_penalty, count_penalty, history):
 
 
 
 
 
 
 
 
62
  history[-1][1] += character
63
  yield history
64
 
 
65
  def generator(
66
  instruction,
67
  input=None,
68
  token_count=333,
69
  temperature=1.0,
70
  top_p=0.5,
71
- presencePenalty = 0.5,
72
- countPenalty = 0.5
73
  ):
74
- args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
75
- alpha_frequency = countPenalty,
76
- alpha_presence = presencePenalty,
77
- token_ban = [], # ban the generation of some tokens
78
- token_stop = [0]) # stop generation whenever you see any token here
79
-
80
- instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
81
- input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
 
 
 
82
  ctx = generate_prompt(instruction, input, history)
83
  print(ctx + "\n")
84
-
85
  all_tokens = []
86
  out_last = 0
87
- out_str = ''
88
  occurrence = {}
89
  state = None
90
  for i in range(int(token_count)):
91
- out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
 
 
92
  for n in occurrence:
93
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
94
 
95
- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
 
 
96
  if token in args.token_stop:
97
  break
98
  all_tokens += [token]
@@ -102,13 +158,13 @@ def generator(
102
  occurrence[token] = 1
103
  else:
104
  occurrence[token] += 1
105
-
106
  tmp = pipeline.decode(all_tokens[out_last:])
107
- if '\ufffd' not in tmp:
108
  out_str += tmp
109
  yield out_str.strip()
110
  out_last = i + 1
111
- if '\n\n' in out_str:
112
  break
113
 
114
  del out
@@ -116,14 +172,16 @@ def generator(
116
  gc.collect()
117
  yield out_str.strip()
118
 
 
119
  def user(message, chatbot):
120
  chatbot = chatbot or []
121
  return "", chatbot + [[message, None]]
122
 
 
123
  def alternative(chatbot, history):
124
  if not chatbot or not history:
125
  return chatbot, history
126
-
127
  chatbot[-1][1] = None
128
  history[0] = copy.deepcopy(history[1])
129
 
@@ -131,53 +189,129 @@ def alternative(chatbot, history):
131
 
132
 
133
  with gr.Blocks(title=title) as demo:
134
- gr.HTML(f"<div style=\"text-align: center;\">\n<h1>🌍World - {title}</h1>\n</div>")
135
-
136
  with gr.Tab("Chat mode"):
137
  with gr.Row():
138
  with gr.Column():
139
  chatbot = gr.Chatbot()
140
- msg = gr.Textbox(scale=4, show_label=False, placeholder="Enter text and press enter", container=False)
 
 
 
 
 
141
  clear = gr.ClearButton([msg, chatbot])
142
  with gr.Column():
143
- token_count_chat = gr.Slider(10, 512, label="Max Tokens", step=10, value=333)
144
- temperature_chat = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
 
 
 
 
145
  top_p_chat = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
146
- presence_penalty_chat = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
147
- count_penalty_chat = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.7)
148
-
 
 
 
 
149
  def clear_chat():
150
  return "", []
151
-
152
  def user_msg(message, history):
153
  history = history or []
154
  return "", history + [[message, None]]
155
-
156
  msg.submit(user_msg, [msg, chatbot], [msg, chatbot], queue=False).then(
157
  respond, chatbot, chatbot, api_name="chat"
158
  )
159
-
160
  with gr.Tab("Instruct mode"):
161
- gr.Markdown(f"100% RNN RWKV-LM **trained on 100+ natural languages**. Demo limited to ctxlen {ctx_limit}. For best results, <b>keep your prompt short and clear</b>.")
 
 
162
  with gr.Row():
163
  with gr.Column():
164
- instruction = gr.Textbox(lines=2, label="Instruction", value='東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。')
165
- input_instruct = gr.Textbox(lines=2, label="Input", placeholder="", value="")
166
- token_count_instruct = gr.Slider(10, 512, label="Max Tokens", step=10, value=333)
167
- temperature_instruct = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
168
- top_p_instruct = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
169
- presence_penalty_instruct = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
170
- count_penalty_instruct = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  with gr.Column():
172
  with gr.Row():
173
  submit = gr.Button("Submit", variant="primary")
174
  clear = gr.Button("Clear", variant="secondary")
175
  output = gr.Textbox(label="Output", lines=5)
176
- data = gr.Dataset(components=[instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
177
- submit.click(generator, [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], [output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  clear.click(lambda: None, [], [output])
179
- data.click(lambda x: x, [data], [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct])
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  demo.queue(max_size=10)
183
- demo.launch(share=False)
 
8
  title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
9
 
10
  model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
11
+ model = RWKV(model=model_path, strategy="cpu bf16")
12
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
13
 
14
 
 
20
  for pair in history:
21
  history_str += f"User: {pair[0]}\n\nAssistant: {pair[1]}\n\n"
22
 
23
+ instruction = (
24
+ instruction.strip()
25
+ .replace("\r\n", "\n")
26
+ .replace("\n\n", "\n")
27
+ .replace("\n\n", "\n")
28
+ )
29
+ input = (
30
+ input.strip().replace("\r\n", "\n").replace("\n\n", "\n").replace("\n\n", "\n")
31
+ )
32
  if input and len(input) > 0:
33
  return f"""{history_str}Instruction: {instruction}
34
 
 
43
 
44
  examples = [
45
  ["東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", "", 300, 1.2, 0.5, 0.5, 0.5],
46
+ [
47
+ "Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires.",
48
+ "",
49
+ 300,
50
+ 1.2,
51
+ 0.5,
52
+ 0.5,
53
+ 0.5,
54
+ ],
55
  ["Write a song about ravens.", "", 300, 1.2, 0.5, 0.5, 0.5],
56
  ["Explain the following metaphor: Life is like cats.", "", 300, 1.2, 0.5, 0.5, 0.5],
57
+ [
58
+ "Write a story using the following information",
59
+ "A man named Alex chops a tree down",
60
+ 300,
61
+ 1.2,
62
+ 0.5,
63
+ 0.5,
64
+ 0.5,
65
+ ],
66
+ [
67
+ "Generate a list of adjectives that describe a person as brave.",
68
+ "",
69
+ 300,
70
+ 1.2,
71
+ 0.5,
72
+ 0.5,
73
+ 0.5,
74
+ ],
75
+ [
76
+ "You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.",
77
+ "",
78
+ 300,
79
+ 1.2,
80
+ 0.5,
81
+ 0.5,
82
+ 0.5,
83
+ ],
84
  ]
85
 
86
+
87
  def respond(history=None):
88
  global token_count_chat, temperature_chat, top_p_chat, presence_penalty_chat, count_penalty_chat
89
+
90
+ # get the latest user message and the additional parameters
91
  instruction = msg.value
92
  token_count = token_count_chat.value
 
97
  count_penalty = count_penalty_chat.value
98
 
99
  history[-1][1] = ""
100
+
101
+ for character in generator(
102
+ instruction,
103
+ None,
104
+ token_count,
105
+ temperature,
106
+ top_p,
107
+ presence_penalty,
108
+ count_penalty,
109
+ ):
110
  history[-1][1] += character
111
  yield history
112
 
113
+
114
  def generator(
115
  instruction,
116
  input=None,
117
  token_count=333,
118
  temperature=1.0,
119
  top_p=0.5,
120
+ presencePenalty=0.5,
121
+ countPenalty=0.5,
122
  ):
123
+ args = PIPELINE_ARGS(
124
+ temperature=max(0.2, float(temperature)),
125
+ top_p=float(top_p),
126
+ alpha_frequency=countPenalty,
127
+ alpha_presence=presencePenalty,
128
+ token_ban=[], # ban the generation of some tokens
129
+ token_stop=[0],
130
+ ) # stop generation whenever you see any token here
131
+
132
+ instruction = re.sub(r"\n{2,}", "\n", instruction).strip().replace("\r\n", "\n")
133
+ input = re.sub(r"\n{2,}", "\n", input).strip().replace("\r\n", "\n")
134
  ctx = generate_prompt(instruction, input, history)
135
  print(ctx + "\n")
136
+
137
  all_tokens = []
138
  out_last = 0
139
+ out_str = ""
140
  occurrence = {}
141
  state = None
142
  for i in range(int(token_count)):
143
+ out, state = model.forward(
144
+ pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state
145
+ )
146
  for n in occurrence:
147
+ out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
148
 
149
+ token = pipeline.sample_logits(
150
+ out, temperature=args.temperature, top_p=args.top_p
151
+ )
152
  if token in args.token_stop:
153
  break
154
  all_tokens += [token]
 
158
  occurrence[token] = 1
159
  else:
160
  occurrence[token] += 1
161
+
162
  tmp = pipeline.decode(all_tokens[out_last:])
163
+ if "\ufffd" not in tmp:
164
  out_str += tmp
165
  yield out_str.strip()
166
  out_last = i + 1
167
+ if "\n\n" in out_str:
168
  break
169
 
170
  del out
 
172
  gc.collect()
173
  yield out_str.strip()
174
 
175
+
176
  def user(message, chatbot):
177
  chatbot = chatbot or []
178
  return "", chatbot + [[message, None]]
179
 
180
+
181
  def alternative(chatbot, history):
182
  if not chatbot or not history:
183
  return chatbot, history
184
+
185
  chatbot[-1][1] = None
186
  history[0] = copy.deepcopy(history[1])
187
 
 
189
 
190
 
191
  with gr.Blocks(title=title) as demo:
192
+ gr.HTML(f'<div style="text-align: center;">\n<h1>🌍World - {title}</h1>\n</div>')
193
+
194
  with gr.Tab("Chat mode"):
195
  with gr.Row():
196
  with gr.Column():
197
  chatbot = gr.Chatbot()
198
+ msg = gr.Textbox(
199
+ scale=4,
200
+ show_label=False,
201
+ placeholder="Enter text and press enter",
202
+ container=False,
203
+ )
204
  clear = gr.ClearButton([msg, chatbot])
205
  with gr.Column():
206
+ token_count_chat = gr.Slider(
207
+ 10, 512, label="Max Tokens", step=10, value=333
208
+ )
209
+ temperature_chat = gr.Slider(
210
+ 0.2, 2.0, label="Temperature", step=0.1, value=1.2
211
+ )
212
  top_p_chat = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
213
+ presence_penalty_chat = gr.Slider(
214
+ 0.0, 1.0, label="Presence Penalty", step=0.1, value=0
215
+ )
216
+ count_penalty_chat = gr.Slider(
217
+ 0.0, 1.0, label="Count Penalty", step=0.1, value=0.7
218
+ )
219
+
220
  def clear_chat():
221
  return "", []
222
+
223
  def user_msg(message, history):
224
  history = history or []
225
  return "", history + [[message, None]]
226
+
227
  msg.submit(user_msg, [msg, chatbot], [msg, chatbot], queue=False).then(
228
  respond, chatbot, chatbot, api_name="chat"
229
  )
230
+
231
  with gr.Tab("Instruct mode"):
232
+ gr.Markdown(
233
+ f"100% RNN RWKV-LM **trained on 100+ natural languages**. Demo limited to ctxlen {ctx_limit}. For best results, <b>keep your prompt short and clear</b>."
234
+ )
235
  with gr.Row():
236
  with gr.Column():
237
+ instruction = gr.Textbox(
238
+ lines=2,
239
+ label="Instruction",
240
+ value="東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。",
241
+ )
242
+ input_instruct = gr.Textbox(
243
+ lines=2, label="Input", placeholder="", value=""
244
+ )
245
+ token_count_instruct = gr.Slider(
246
+ 10, 512, label="Max Tokens", step=10, value=333
247
+ )
248
+ temperature_instruct = gr.Slider(
249
+ 0.2, 2.0, label="Temperature", step=0.1, value=1.2
250
+ )
251
+ top_p_instruct = gr.Slider(
252
+ 0.0, 1.0, label="Top P", step=0.05, value=0.3
253
+ )
254
+ presence_penalty_instruct = gr.Slider(
255
+ 0.0, 1.0, label="Presence Penalty", step=0.1, value=0
256
+ )
257
+ count_penalty_instruct = gr.Slider(
258
+ 0.0, 1.0, label="Count Penalty", step=0.1, value=0.7
259
+ )
260
  with gr.Column():
261
  with gr.Row():
262
  submit = gr.Button("Submit", variant="primary")
263
  clear = gr.Button("Clear", variant="secondary")
264
  output = gr.Textbox(label="Output", lines=5)
265
+ data = gr.Dataset(
266
+ components=[
267
+ instruction,
268
+ input_instruct,
269
+ token_count_instruct,
270
+ temperature_instruct,
271
+ top_p_instruct,
272
+ presence_penalty_instruct,
273
+ count_penalty_instruct,
274
+ ],
275
+ samples=examples,
276
+ label="Example Instructions",
277
+ headers=[
278
+ "Instruction",
279
+ "Input",
280
+ "Max Tokens",
281
+ "Temperature",
282
+ "Top P",
283
+ "Presence Penalty",
284
+ "Count Penalty",
285
+ ],
286
+ )
287
+ submit.click(
288
+ generator,
289
+ [
290
+ instruction,
291
+ input_instruct,
292
+ token_count_instruct,
293
+ temperature_instruct,
294
+ top_p_instruct,
295
+ presence_penalty_instruct,
296
+ count_penalty_instruct,
297
+ ],
298
+ [output],
299
+ )
300
  clear.click(lambda: None, [], [output])
301
+ data.click(
302
+ lambda x: x,
303
+ [data],
304
+ [
305
+ instruction,
306
+ input_instruct,
307
+ token_count_instruct,
308
+ temperature_instruct,
309
+ top_p_instruct,
310
+ presence_penalty_instruct,
311
+ count_penalty_instruct,
312
+ ],
313
+ )
314
 
315
 
316
  demo.queue(max_size=10)
317
+ demo.launch(share=False)