cryscan committed
Commit f19425c
1 Parent(s): b4c965c

1. Fix clearing.


2. Add custom chat scenarios.

Files changed (1)
  1. app.py +79 -125
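
Note on change 2: the chat scenario is now an editable "Scenario" textbox whose text carries <|user|> and <|bot|> placeholders; on the first message after a clear, chat() substitutes the chosen names and encodes the scenario from scratch. A minimal sketch of that substitution, pulled out of the diff below (build_intro is a hypothetical helper name, not part of the commit):

    def build_intro(prompt: str, user: str, bot: str) -> str:
        # Swap the scenario placeholders for the user-chosen names.
        prompt = prompt.replace("<|user|>", user.strip())
        prompt = prompt.replace("<|bot|>", bot.strip())
        # Pad with newlines the same way chat() does before encoding.
        return f"\n{prompt.strip()}\n\n"

    # e.g. build_intro("<|user|>: Hello!\n\n<|bot|>: Hi!", "Bob", "Alice")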
app.py CHANGED
@@ -1,9 +1,5 @@
-from rwkv.utils import PIPELINE, PIPELINE_ARGS
-from rwkv.model import RWKV
 import gradio as gr
-import os
-import gc
-import torch
+import os, gc, torch
 from datetime import datetime
 from huggingface_hub import hf_hub_download
 from pynvml import *
@@ -11,35 +7,32 @@ nvmlInit()
 gpu_h = nvmlDeviceGetHandleByIndex(0)
 ctx_limit = 1024
 title = "RWKV-4-Pile-14B-20230313-ctx8192-test1050"
-desc = f'''Links:
-<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a>
-<a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a>
-<a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a>
+desc = f'''Links:<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B" target="_blank" style="margin:0 0.5em">Raven 7B (alpaca-style)</a>
 '''
 
 os.environ["RWKV_JIT_ON"] = '1'
-# if '1' then use CUDA kernel for seq mode (much faster)
-os.environ["RWKV_CUDA_ON"] = '1'
+os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
 
+from rwkv.model import RWKV
 model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename=f"{title}.pth")
-model = RWKV(model=model_path, strategy='cuda fp16i8 *20 -> cuda fp16')
-pipeline = PIPELINE(model, "20B_tokenizer.json")
+model = RWKV(model=model_path, strategy='cuda fp16i8 *24 -> cuda fp16')
 
-########################################################################################################
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+pipeline = PIPELINE(model, "20B_tokenizer.json")
 
 def infer(
         ctx,
         token_count=10,
         temperature=1.0,
         top_p=0.8,
-        presence_enalty=0.1,
-        count_penalty=0.1,
+        presencePenalty = 0.1,
+        countPenalty = 0.1,
 ):
-    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
-                         alpha_frequency=float(count_penalty),
-                         alpha_presence=float(presence_enalty),
-                         token_ban=[0], # ban the generation of some tokens
-                         token_stop=[]) # stop generation whenever you see any token here
+    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                         alpha_frequency = countPenalty,
+                         alpha_presence = presencePenalty,
+                         token_ban = [0], # ban the generation of some tokens
+                         token_stop = []) # stop generation whenever you see any token here
 
     ctx = ctx.strip(' ')
     if ctx.endswith('\n'):
@@ -49,7 +42,7 @@ def infer(
 
     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
-
+
     all_tokens = []
     out_last = 0
     out_str = ''
@@ -70,7 +63,7 @@ def infer(
             occurrence[token] = 1
         else:
             occurrence[token] += 1
-
+
         tmp = pipeline.decode(all_tokens[out_last:])
         if '\ufffd' not in tmp:
            out_str += tmp
@@ -110,92 +103,70 @@ Arrange the given numbers in ascending order.
     ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
 ]
 
-# infer_interface = gr.Interface(
-#     fn=infer,
-#     description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
-#     allow_flagging="never",
-#     inputs=[
-#         gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
-#         gr.Slider(10, 200, step=10, value=150), # token_count
-#         gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
-#         gr.Slider(0.0, 1.0, step=0.05, value=0.7), # top_p
-#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presencePenalty
-#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # countPenalty
-#     ],
-#     outputs=gr.Textbox(label="Generated Output", lines=28),
-#     examples=examples,
-#     cache_examples=False,
-# ).queue()
-
-########################################################################################################
-
-user = "Bob"
-bot = "Alice"
-interface = ":"
-
-chat_intro = f'''
-The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
-{bot} is very intelligent, creative and friendly. \
-She is unlikely to disagree with {user}, and she doesn't like to ask {user} questions. \
-She also likes to tell {user} a lot about herself and her opinions, and she usually gives {user} kind, helpful and informative advices.
-
-{user}{interface} Hello, how are you doing?
+chat_intro = '''The following is a coherent verbose detailed conversation between an AI girl named <|bot|> and <|user|>. One day, they meet at a café.
+Note the following important facts about <|bot|>:
+1. <|bot|> is very intelligent, creative and friendly.
+2. <|bot|> likes to tell <|user|> a lot about herself and her opinions.
+3. <|bot|> usually gives <|user|> kind, helpful and informative advices.
 
-{bot}{interface} Hi {user}! Thanks, I'm fine. What about you?
+<|user|>: Hello, how are you doing?
 
-{user}{interface} I am fine. It's nice to see you. Look, here is a store selling tea and juice.
+<|bot|>: Hi! Thanks, I'm fine. What about you?
 
-{bot}{interface} Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!
-
-{user}{interface} What is it?
-
-{bot}{interface} Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its flavors are frequently sweet.
-
-{user}{interface} Sounds tasty. I'll try it next time. Would you like to chat with me for a while?
-
-{bot}{interface} Of course! I'm glad to answer your questions or give helpful advices. You know, I am confident with my expertise. So please go ahead!
+<|user|>: I am fine. It's nice to see you. Look, here is a store selling tea and juice. We can go and take a look. Would you like to chat with me for a while?
 
+<|bot|>: Sure. Let's go inside. What would you like to talk about? I'm listening.
 '''
 
-_, intro_state = model.forward(pipeline.encode(chat_intro), None)
-
-def user(user_message, chatbot):
+def user(message, chatbot):
     chatbot = chatbot or []
-    return "", chatbot + [[user_message, None]]
+    print(f"User: {message}")
+    return "", chatbot + [[message, None]]
 
 def chat(
+        prompt,
+        user,
+        bot,
         chatbot,
         history,
-        token_count=10,
         temperature=1.0,
         top_p=0.8,
-        presence_enalty=0.1,
+        presence_penalty=0.1,
         count_penalty=0.1,
 ):
     args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
                          alpha_frequency=float(count_penalty),
-                         alpha_presence=float(presence_enalty),
+                         alpha_presence=float(presence_penalty),
                          token_ban=[], # ban the generation of some tokens
                          token_stop=[]) # stop generation whenever you see any token here
 
     message = chatbot[-1][0]
-    message = message.strip(' ')
-    message = message.replace('\n', '')
-    ctx = f"{user}{interface} {message}\n\n{bot}{interface}"
+    message = message.strip().replace('\r\n','\n').replace('\n\n','\n')
+    ctx = f"{user}: {message}\n\n{bot}:"
 
-    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+    # gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+    # print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
 
-    history = history or [intro_state, []] # [chat, state, all_tokens]
+    if not history:
+        prompt = prompt.replace("<|user|>", user.strip())
+        prompt = prompt.replace("<|bot|>", bot.strip())
+        prompt = prompt.strip()
+        prompt = f"\n{prompt}\n\n"
+
+        out, state = model.forward(pipeline.encode(prompt), None)
+        history = [state, []]
+        print("History reloaded.")
 
     [state, all_tokens] = history
     out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
 
+    print("Bot: ", end='')
+
     begin = len(all_tokens)
     out_last = begin
     out_str: str = ''
     occurrence = {}
-    for i in range(int(token_count)):
+    for i in range(300):
         if i <= 0:
             nl_bias = -float('inf')
         elif i <= 30:
@@ -239,77 +210,60 @@ def chat(
     history = [state, all_tokens]
     return chatbot, history
 
-# chat_interface = gr.Interface(
-#     fn=chat,
-#     description=f'''You are {user}, bot is {bot}.''',
-#     allow_flagging="never",
-#     inputs = [
-#         gr.Textbox(label="Message"),
-#         "state",
-#         gr.Slider(10, 1000, step=10, value=250), # token_count
-#         gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
-#         gr.Slider(0.0, 1.0, step=0.05, value=0.8), # top_p
-#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presence_penalty
-#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # count_penalty
-#     ],
-#     outputs=[
-#         gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
-#         "state"
-#     ]
-# ).queue()
-
-########################################################################################################
-
-# demo = gr.TabbedInterface(
-#     [infer_interface, chat_interface], ["Generative", "Chat"],
-#     title=title,
-# )
-
-# demo.queue(max_size=10)
-# demo.launch(share=True)
-
 with gr.Blocks(title=title) as demo:
     with gr.Tab("Generative"):
-        gr.Markdown(f'''{desc}<b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''', label="Description")
+        gr.Markdown(f'''{desc} *** <b>Please try examples first (bottom of page)</b> *** (edit them to your own question).\nDemo limited to ctxlen {ctx_limit}.''', label="Description")
         with gr.Row():
             with gr.Column():
                 prompt = gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n")
-                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
+                with gr.Row():
+                    submit = gr.Button("Submit", variant="primary")
+                    clear = gr.Button("Clear", variant="secondary")
+                token_count = gr.Slider(10, 200, label="Max Tokens", step=10, value=150)
                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
-                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.7)
                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
             with gr.Column():
-                with gr.Row():
-                    submit = gr.Button("Submit")
-                    clear = gr.Button("Clear")
-                output = gr.Textbox(label="Generated Output", lines=28)
+                output = gr.Textbox(label="Generated Output", lines=32)
                 data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Prompts", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
         submit.click(infer, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
         clear.click(lambda: None, [], [output])
         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+
     with gr.Tab("Chat"):
-        gr.Markdown(f'''{desc}Scenario: You are Bob, bot is Alice. You meet at a café.''', label="Description")
+        gr.Markdown(f'''{desc} *** <b>Default Chat Scenario: You (Bob) and Bot (Alice) meet at a café.</b> ***\nIf you want to change the scenario, make sure to use an empty new line to separate different people's words. Also, make sure there is no empty new lines within one person's lines. Changes only take effect after clearing.''', label="Description")
        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot()
                state = gr.State()
                message = gr.Textbox(label="Message")
                with gr.Row():
-                    send = gr.Button("Send")
-                    clear = gr.Button("Clear")
+                    send = gr.Button("Send", variant="primary")
+                    clear = gr.Button("Clear", variant="secondary")
            with gr.Column():
-                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
+                with gr.Row():
+                    user_name = gr.Textbox(lines=1, max_lines=1, label="User Name", value="Bob")
+                    bot_name = gr.Textbox(lines=1, max_lines=1, label="Bot Name", value="Alice")
+                prompt = gr.Textbox(lines=10, max_lines=50, label="Scenario", value=chat_intro)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
-                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.7)
                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
-        message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(
-            chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
-        )
-        send.click(user, [message, chatbot], [message, chatbot], queue=False).then(
-            chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
-        )
+        chat_inputs = [
+            prompt,
+            user_name,
+            bot_name,
+            chatbot,
+            state,
+            temperature,
+            top_p,
+            presence_penalty,
+            count_penalty
+        ]
+        chat_outputs = [chatbot, state]
+        message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
+        send.click(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
        clear.click(lambda: ([], None, ""), [], [chatbot, state, message])
 
 demo.queue(max_size=10)
 
 
 
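
Note on change 1: the chat tab's Clear button returns ([], None, "") to empty the chat log, drop the cached model state, and blank the message box; since chat() rebuilds the state from the scenario whenever history is empty, scenario edits take effect only after clearing. A minimal self-contained sketch of the pattern (assumes gradio is installed; the component names mirror the diff, the handler wiring is illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        state = gr.State()      # holds [model_state, all_tokens] once a chat starts
        message = gr.Textbox(label="Message")
        clear = gr.Button("Clear")
        # One return value per output component, in order: chat log, state, textbox.
        clear.click(lambda: ([], None, ""), [], [chatbot, state, message])

    demo.launch()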