cryscan committed
Commit 1834a63 · 1 parent: e176af0

Implement basic chat mode


Add a chat mode. The functionality is very basic, but hopefully it works.

Files changed (1)
  1. app.py +151 -20
app.py CHANGED
@@ -1,5 +1,9 @@
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+from rwkv.model import RWKV
 import gradio as gr
-import os, gc, torch
+import os
+import gc
+import torch
 from datetime import datetime
 from huggingface_hub import hf_hub_download
 from pynvml import *
@@ -11,31 +15,31 @@ desc = f'''Links:
 <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a>
 <a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a>
 <a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a>
-<a href="https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B" target="_blank" style="margin:0 0.5em">Raven 7B (alpaca-style)</a>
 '''
 
 os.environ["RWKV_JIT_ON"] = '1'
-os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+# if '1' then use CUDA kernel for seq mode (much faster)
+os.environ["RWKV_CUDA_ON"] = '1'
 
-from rwkv.model import RWKV
 model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename=f"{title}.pth")
-model = RWKV(model=model_path, strategy='cuda fp16i8 *24 -> cuda fp16')
-from rwkv.utils import PIPELINE, PIPELINE_ARGS
+model = RWKV(model=model_path, strategy='cuda fp16i8 *20 -> cuda fp16')
 pipeline = PIPELINE(model, "20B_tokenizer.json")
 
+########################################################################################################
+
 def infer(
     ctx,
     token_count=10,
     temperature=1.0,
     top_p=0.8,
-    presencePenalty = 0.1,
-    countPenalty = 0.1,
+    presence_penalty=0.1,
+    count_penalty=0.1,
 ):
-    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
-                         alpha_frequency = countPenalty,
-                         alpha_presence = presencePenalty,
-                         token_ban = [0], # ban the generation of some tokens
-                         token_stop = []) # stop generation whenever you see any token here
+    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                         alpha_frequency=float(count_penalty),
+                         alpha_presence=float(presence_penalty),
+                         token_ban=[0], # ban the generation of some tokens
+                         token_stop=[]) # stop generation whenever you see any token here
 
     ctx = ctx.strip(' ')
     if ctx.endswith('\n'):
@@ -45,7 +49,7 @@ def infer(
 
     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
-
+
     all_tokens = []
     out_last = 0
     out_str = ''
@@ -66,7 +70,7 @@ def infer(
             occurrence[token] = 1
         else:
             occurrence[token] += 1
-
+
         tmp = pipeline.decode(all_tokens[out_last:])
         if '\ufffd' not in tmp:
             out_str += tmp
@@ -106,10 +110,9 @@ Arrange the given numbers in ascending order.
     ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
 ]
 
-
-iface = gr.Interface(
+infer_interface = gr.Interface(
     fn=infer,
-    description=f'''{desc} *** <b>Please try examples first (bottom of page)</b> *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
+    description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
     allow_flagging="never",
     inputs=[
         gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
@@ -124,10 +127,138 @@ iface = gr.Interface(
     cache_examples=False,
 ).queue()
 
+########################################################################################################
+
+user = "Bob"
+bot = "Alice"
+interface = ":"
+
+chat_intro = f'''
+The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
+{bot} is very intelligent, creative and friendly. \
+She is unlikely to disagree with {user}, and she doesn't like to ask {user} questions. \
+She also likes to tell {user} a lot about herself and her opinions, and she usually gives {user} kind, helpful and informative advices.
+
+{user}{interface} Hello, how are you doing?
+
+{bot}{interface} Hi {user}! Thanks, I'm fine. What about you?
+
+{user}{interface} I am fine. It's nice to see you. Look, here is a store selling tea and juice.
+
+{bot}{interface} Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!
+
+{user}{interface} What is it?
+
+{bot}{interface} Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its flavors are frequently sweet.
+
+{user}{interface} Sounds tasty. I'll try it next time. Would you like to chat with me for a while?
+
+{bot}{interface} Of course! I'm glad to answer your questions or give helpful advices. You know, I am confident with my expertise. So please go ahead!
+
+'''
+
+_, intro_state = model.forward(pipeline.encode(chat_intro), None)
+
+def chat(
+    message: str,
+    history,
+    token_count=10,
+    temperature=1.0,
+    top_p=0.8,
+    presence_penalty=0.1,
+    count_penalty=0.1,
+):
+    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                         alpha_frequency=float(count_penalty),
+                         alpha_presence=float(presence_penalty),
+                         token_ban=[], # ban the generation of some tokens
+                         token_stop=[]) # stop generation whenever you see any token here
+
+    message = message.strip(' ')
+    message = message.replace('\n', '')
+    ctx = f"{user}{interface} {message}\n\n{bot}{interface}"
+
+    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+
+    history = history or [[], intro_state, []] # [chat, state, all_tokens]
+
+    [chat_log, state, all_tokens] = history
+    out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
+
+    begin = len(all_tokens)
+    out_last = begin
+    out_str: str = ''
+    occurrence = {}
+    for i in range(int(token_count)):
+        if i <= 0:
+            nl_bias = -float('inf')
+        elif i <= 30:
+            nl_bias = (i - 30) * 0.1
+        elif i <= 130:
+            nl_bias = 0
+        else:
+            nl_bias = (i - 130) * 0.25
+        out[187] += nl_bias # 187 is the '\n' token in the 20B tokenizer
+        for n in occurrence:
+            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+        next_tokens = [token]
+        if token == 0:
+            next_tokens = pipeline.encode('\n\n')
+        all_tokens += next_tokens
+
+        if token not in occurrence:
+            occurrence[token] = 1
+        else:
+            occurrence[token] += 1
+
+        out, state = model.forward(next_tokens, state)
+
+        tmp = pipeline.decode(all_tokens[out_last:])
+        if '\ufffd' not in tmp:
+            print(tmp, end='', flush=True)
+            out_last = begin + i + 1
+
+        out_str = pipeline.decode(all_tokens[begin:])
+        out_str = out_str.replace("\r\n", '\n').replace('\\n', '\n')
+
+        if '\n\n' in out_str:
+            break
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    chat_log.append((message, out_str.strip()))
+    history = [chat_log, state, all_tokens]
+    return chat_log, history
+
+chat_interface = gr.Interface(
+    fn=chat,
+    description=f'''You are {user}, bot is {bot}.''',
+    allow_flagging="never",
+    inputs=[
+        gr.Textbox(label="Message"),
+        "state",
+        gr.Slider(10, 1000, step=10, value=250), # token_count
+        gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
+        gr.Slider(0.0, 1.0, step=0.05, value=0.8), # top_p
+        gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presence_penalty
+        gr.Slider(0.0, 1.0, step=0.1, value=0.2), # count_penalty
+    ],
+    outputs=[
+        gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
+        "state"
+    ]
+).queue()
+
+########################################################################################################
+
 demo = gr.TabbedInterface(
-    [iface], ["Generative"],
+    [infer_interface, chat_interface], ["Generative", "Chat"],
     title=title,
 )
 
 demo.queue(max_size=10)
-demo.launch(share=False)
+demo.launch(share=True)
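
Note on the design: the new Chat tab relies on Gradio's session-state pattern. The "state" input/output pair carries [chat_log, state, all_tokens] between calls, so each chat() invocation resumes the RWKV recurrent state from the previous turn (seeded from intro_state on the first call) and stops generating once a blank line ('\n\n') appears in the reply. Below is a minimal sketch of that state-threading pattern in isolation, assuming the same Gradio 3.x behaviour this Space uses ("state" shorthand inputs/outputs, tuple-style Chatbot); chat_sketch and its echo reply are illustrative stand-ins for the real chat() / model.forward() loop and are not part of this commit.

import gradio as gr

# Sketch only: per-session state is a list [chat_log, model_state, all_tokens],
# created lazily on the first call and returned so Gradio stores it for the session.
def chat_sketch(message, history):
    history = history or [[], None, []]   # [chat_log, model_state, all_tokens]
    chat_log, model_state, all_tokens = history
    reply = f"(echo) {message}"           # placeholder for the RWKV generation loop
    chat_log.append((message, reply))     # (user, bot) tuples, as gr.Chatbot expects in Gradio 3.x
    return chat_log, [chat_log, model_state, all_tokens]

sketch_interface = gr.Interface(
    fn=chat_sketch,
    inputs=[gr.Textbox(label="Message"), "state"],    # "state" shorthand = per-session state
    outputs=[gr.Chatbot(label="Chat Log"), "state"],
    allow_flagging="never",
)

if __name__ == "__main__":
    sketch_interface.launch()

Because the state list is rebuilt only when history is empty, a fresh browser session starts from the intro conversation in the real app, while follow-up messages keep extending the same model state and token history.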