BlinkDL commited on
Commit
0269c32
1 Parent(s): ea44723

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -302
app.py CHANGED
@@ -1,69 +1,59 @@
1
- import os, copy
2
- os.environ["RWKV_JIT_ON"] = '1'
3
- os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
4
- # make sure cuda dir is in the same level as modeling_rwkv.py
5
- from modeling_rwkv import RWKV
6
-
7
- import gc, re
8
  import gradio as gr
9
- import base64
10
- from io import BytesIO
11
- import torch
12
- import torch.nn.functional as F
13
  from datetime import datetime
14
- from transformers import CLIPImageProcessor
15
  from huggingface_hub import hf_hub_download
16
  from pynvml import *
17
  nvmlInit()
18
  gpu_h = nvmlDeviceGetHandleByIndex(0)
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
-
21
- ctx_limit = 2500
22
  gen_limit = 500
23
  gen_limit_long = 800
24
- ENABLE_VISUAL = False
25
 
26
- ########################## text rwkv ################################################################
27
- from rwkv.utils import PIPELINE, PIPELINE_ARGS
28
 
29
- title_v6 = "RWKV-x060-World-3B-v2.1-20240417-ctx4096"
30
- model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title_v6}.pth")
31
- # model_path_v6 = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-3b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-1; python app.py
32
- model_v6 = RWKV(model=model_path_v6, strategy='cuda fp16')
33
- pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
34
 
35
- args = model_v6.args
36
- eng_name = 'rwkv-x060-eng_single_round_qa-3B-20240516-ctx2048'
37
- chn_name = 'rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048'
 
38
 
39
- # state_eng_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{eng_name}.pth', map_location=torch.device('cpu'))
40
- # state_chn_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{chn_name}.pth', map_location=torch.device('cpu'))
41
 
 
 
42
  eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
43
- chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
44
- state_eng_raw = torch.load(eng_file, map_location=torch.device('cpu'))
45
- state_chn_raw = torch.load(chn_file, map_location=torch.device('cpu'))
46
-
47
  state_eng = [None] * args.n_layer * 3
 
 
 
 
48
  state_chn = [None] * args.n_layer * 3
 
 
 
 
 
 
49
  for i in range(args.n_layer):
50
- dd = model_v6.strategy[i]
51
  dev = dd.device
52
  atype = dd.atype
53
  state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
54
- state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
55
  state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
56
- state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
57
  state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
58
- state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
59
 
60
- penalty_decay = 0.996
 
 
61
 
62
- if ENABLE_VISUAL:
63
- title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
64
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
65
- model = RWKV(model=model_path, strategy='cuda fp16')
66
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
67
 
68
  def generate_prompt(instruction, input=""):
69
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -78,13 +68,15 @@ def qa_prompt(instruction):
78
  instruction = re.sub(r'\n+', '\n', instruction)
79
  return f"User: {instruction}\n\nAssistant:"""
80
 
 
 
81
  def evaluate(
82
  ctx,
83
- token_count=200,
84
  temperature=1.0,
85
- top_p=0.7,
86
- presencePenalty = 0.1,
87
- countPenalty = 0.1,
88
  ):
89
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
90
  alpha_frequency = countPenalty,
@@ -98,30 +90,22 @@ def evaluate(
98
  occurrence = {}
99
  state = None
100
  for i in range(int(token_count)):
101
- input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
102
- out, state = model_v6.forward(tokens=input_ids, state=state)
103
  for n in occurrence:
104
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
105
 
106
- token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
107
  if token in args.token_stop:
108
  break
109
  all_tokens += [token]
110
  for xxx in occurrence:
111
  occurrence[xxx] *= penalty_decay
112
-
113
- ttt = pipeline_v6.decode([token])
114
- www = 1
115
- if ttt in ' \t0123456789':
116
- www = 0
117
- #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
118
- # www = 0.5
119
  if token not in occurrence:
120
- occurrence[token] = www
121
  else:
122
- occurrence[token] += www
123
-
124
- tmp = pipeline_v6.decode(all_tokens[out_last:])
125
  if '\ufffd' not in tmp:
126
  out_str += tmp
127
  yield out_str.strip()
@@ -129,7 +113,7 @@ def evaluate(
129
 
130
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
131
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
132
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
133
  del out
134
  del state
135
  gc.collect()
@@ -138,11 +122,11 @@ def evaluate(
138
 
139
  def evaluate_eng(
140
  ctx,
141
- token_count=200,
142
  temperature=1.0,
143
- top_p=0.7,
144
- presencePenalty = 0.1,
145
- countPenalty = 0.1,
146
  ):
147
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
148
  alpha_frequency = countPenalty,
@@ -156,30 +140,22 @@ def evaluate_eng(
156
  occurrence = {}
157
  state = copy.deepcopy(state_eng)
158
  for i in range(int(token_count)):
159
- input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
160
- out, state = model_v6.forward(tokens=input_ids, state=state)
161
  for n in occurrence:
162
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
163
 
164
- token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
165
  if token in args.token_stop:
166
  break
167
  all_tokens += [token]
168
  for xxx in occurrence:
169
  occurrence[xxx] *= penalty_decay
170
-
171
- ttt = pipeline_v6.decode([token])
172
- www = 1
173
- if ttt in ' \t0123456789':
174
- www = 0
175
- #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
176
- # www = 0.5
177
  if token not in occurrence:
178
- occurrence[token] = www
179
  else:
180
- occurrence[token] += www
181
-
182
- tmp = pipeline_v6.decode(all_tokens[out_last:])
183
  if '\ufffd' not in tmp:
184
  out_str += tmp
185
  yield out_str.strip()
@@ -187,7 +163,7 @@ def evaluate_eng(
187
 
188
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
189
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
190
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
191
  del out
192
  del state
193
  gc.collect()
@@ -196,11 +172,11 @@ def evaluate_eng(
196
 
197
  def evaluate_chn(
198
  ctx,
199
- token_count=200,
200
  temperature=1.0,
201
- top_p=0.7,
202
- presencePenalty = 0.1,
203
- countPenalty = 0.1,
204
  ):
205
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
206
  alpha_frequency = countPenalty,
@@ -214,30 +190,22 @@ def evaluate_chn(
214
  occurrence = {}
215
  state = copy.deepcopy(state_chn)
216
  for i in range(int(token_count)):
217
- input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
218
- out, state = model_v6.forward(tokens=input_ids, state=state)
219
  for n in occurrence:
220
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
221
 
222
- token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
223
  if token in args.token_stop:
224
  break
225
  all_tokens += [token]
226
  for xxx in occurrence:
227
  occurrence[xxx] *= penalty_decay
228
-
229
- ttt = pipeline_v6.decode([token])
230
- www = 1
231
- if ttt in ' \t0123456789':
232
- www = 0
233
- #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
234
- # www = 0.5
235
  if token not in occurrence:
236
- occurrence[token] = www
237
  else:
238
- occurrence[token] += www
239
-
240
- tmp = pipeline_v6.decode(all_tokens[out_last:])
241
  if '\ufffd' not in tmp:
242
  out_str += tmp
243
  yield out_str.strip()
@@ -245,7 +213,57 @@ def evaluate_chn(
245
 
246
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
247
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
248
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  del out
250
  del state
251
  gc.collect()
@@ -259,7 +277,7 @@ examples = [
259
  [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), gen_limit, 1, 0.3, 0.5, 0.5],
260
  ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", gen_limit, 1, 0.3, 0.5, 0.5],
261
  ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.3, 0.5, 0.5],
262
- [generate_prompt("Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 500, 1, 0.3, 0.5, 0.5],
263
  ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境��は、特別な雰囲気に包まれていた。\n\nEnglish:''', gen_limit, 1, 0.3, 0.5, 0.5],
264
  ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", gen_limit, 1, 0.3, 0.5, 0.5],
265
  ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", gen_limit, 1, 0.3, 0.5, 0.5],
@@ -281,156 +299,29 @@ examples_chn = [
281
  ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit_long, 1, 0.2, 0.3, 0.3],
282
  ["比较苹果和谷歌的商业模式。", gen_limit_long, 1, 0.2, 0.3, 0.3],
283
  ["鱼会口渴吗?", gen_limit_long, 1, 0.2, 0.3, 0.3],
284
- ["以 JSON 格式列举北京的美食。", gen_limit_long, 1, 0.2, 0.3, 0.3],
285
  ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit_long, 1, 0.2, 0.3, 0.3],
286
  ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
287
  ]
288
 
289
- if ENABLE_VISUAL:
290
- ########################## visual rwkv ################################################################
291
- visual_title = 'ViusualRWKV-v5'
292
- rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
293
- vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
294
- vision_tower_name = 'openai/clip-vit-large-patch14-336'
295
-
296
- model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
297
- visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
298
-
299
- ##########################################################################
300
- from modeling_vision import VisionEncoder, VisionEncoderConfig
301
- config = VisionEncoderConfig(n_embd=model.args.n_embd,
302
- vision_tower_name=vision_tower_name,
303
- grid_size=-1)
304
- visual_encoder = VisionEncoder(config)
305
- vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
306
- vision_state_dict = torch.load(vision_local_path, map_location='cpu')
307
- visual_encoder.load_state_dict(vision_state_dict)
308
- image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
309
- visual_encoder = visual_encoder.to(device)
310
- ##########################################################################
311
- def visual_generate_prompt(instruction):
312
- instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
313
- return f"\n{instruction}\n\nAssistant:"
314
-
315
- def generate(
316
- ctx,
317
- image_state,
318
- token_count=200,
319
- temperature=1.0,
320
- top_p=0.1,
321
- presencePenalty = 0.0,
322
- countPenalty = 1.0,
323
- ):
324
- args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.1,
325
- alpha_frequency = 1.0,
326
- alpha_presence = 0.0,
327
- token_ban = [], # ban the generation of some tokens
328
- token_stop = [0, 261]) # stop generation whenever you see any token here
329
- ctx = ctx.strip()
330
- all_tokens = []
331
- out_last = 0
332
- out_str = ''
333
- occurrence = {}
334
- for i in range(int(token_count)):
335
- if i == 0:
336
- input_ids = pipeline.encode(ctx)[-ctx_limit:]
337
- out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
338
- else:
339
- input_ids = [token]
340
- out, state = visual_rwkv.forward(tokens=input_ids, state=state)
341
- for n in occurrence:
342
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
343
-
344
- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
345
- if token in args.token_stop:
346
- break
347
- all_tokens += [token]
348
- for xxx in occurrence:
349
- occurrence[xxx] *= 0.994
350
- if token not in occurrence:
351
- occurrence[token] = 1
352
- else:
353
- occurrence[token] += 1
354
-
355
- tmp = pipeline.decode(all_tokens[out_last:])
356
- if '\ufffd' not in tmp:
357
- out_str += tmp
358
- yield out_str.strip()
359
- out_last = i + 1
360
-
361
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
362
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
363
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
364
- del out
365
- del state
366
- gc.collect()
367
- torch.cuda.empty_cache()
368
- yield out_str.strip()
369
-
370
-
371
- ##########################################################################
372
- cur_dir = os.path.dirname(os.path.abspath(__file__))
373
- visual_examples = [
374
- [
375
- f"{cur_dir}/examples_pizza.jpg",
376
- "What are steps to cook it?"
377
- ],
378
- [
379
- f"{cur_dir}/examples_bluejay.jpg",
380
- "what is the name of this bird?",
381
- ],
382
- [
383
- f"{cur_dir}/examples_woman_and_dog.png",
384
- "describe this image",
385
- ],
386
- ]
387
-
388
-
389
- def pil_image_to_base64(pil_image):
390
- buffered = BytesIO()
391
- pil_image.save(buffered, format="JPEG") # You can change the format as needed (JPEG, PNG, etc.)
392
- # Encodes the image data into base64 format as a bytes object
393
- base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
394
- return base64_image
395
-
396
- image_cache = {}
397
- ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
398
- ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
399
- def compute_image_state(image):
400
- base64_image = pil_image_to_base64(image)
401
- if base64_image in image_cache:
402
- image_state = image_cache[base64_image]
403
- else:
404
- image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
405
- image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
406
- # apply layer norm to image feature, very important
407
- image_features = F.layer_norm(image_features,
408
- (image_features.shape[-1],),
409
- weight=ln0_weight,
410
- bias=ln0_bias)
411
- _, image_state = model.forward(embs=image_features, state=None)
412
- image_cache[base64_image] = image_state
413
- return image_state
414
-
415
- def chatbot(image, question):
416
- if image is None:
417
- yield "Please upload an image."
418
- return
419
- image_state = compute_image_state(image)
420
- input_text = visual_generate_prompt(question)
421
- for output in generate(input_text, image_state):
422
- yield output
423
-
424
-
425
- ##################################################################################################################
426
- with gr.Blocks(title=title_v6) as demo:
427
- gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")
428
 
429
  with gr.Tab("=== Base Model (Raw Generation) ==="):
430
- gr.Markdown(f"This is [RWKV-6 World v2](https://huggingface.co/BlinkDL/rwkv-6-world) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. Check [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.")
431
  with gr.Row():
432
  with gr.Column():
433
- prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
434
  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
435
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
436
  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
@@ -441,68 +332,10 @@ with gr.Blocks(title=title_v6) as demo:
441
  submit = gr.Button("Submit", variant="primary")
442
  clear = gr.Button("Clear", variant="secondary")
443
  output = gr.Textbox(label="Output", lines=30)
444
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
445
  submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
446
  clear.click(lambda: None, [], [output])
447
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
448
 
449
- with gr.Tab("=== English Q/A ==="):
450
- gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [English Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{eng_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
451
- with gr.Row():
452
- with gr.Column():
453
- prompt = gr.Textbox(lines=2, label="Prompt", value="How can I craft an engaging story featuring vampires on Mars?")
454
- token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
455
- temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
456
- top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
457
- presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
458
- count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
459
- with gr.Column():
460
- with gr.Row():
461
- submit = gr.Button("Submit", variant="primary")
462
- clear = gr.Button("Clear", variant="secondary")
463
- output = gr.Textbox(label="Output", lines=30)
464
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_eng, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
465
- submit.click(evaluate_eng, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
466
- clear.click(lambda: None, [], [output])
467
- data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
468
-
469
- with gr.Tab("=== Chinese Q/A ==="):
470
- gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [Chinese Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{chn_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
471
- with gr.Row():
472
- with gr.Column():
473
- prompt = gr.Textbox(lines=2, label="Prompt", value="怎样写一个在火星上的吸血鬼的有趣故事?")
474
- token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
475
- temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
476
- top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
477
- presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
478
- count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
479
- with gr.Column():
480
- with gr.Row():
481
- submit = gr.Button("Submit", variant="primary")
482
- clear = gr.Button("Clear", variant="secondary")
483
- output = gr.Textbox(label="Output", lines=30)
484
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_chn, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
485
- submit.click(evaluate_chn, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
486
- clear.click(lambda: None, [], [output])
487
- data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
488
-
489
- if ENABLE_VISUAL:
490
- with gr.Tab("Visual RWKV-5 1.5B"):
491
- with gr.Row():
492
- with gr.Column():
493
- image = gr.Image(type='pil', label="Image")
494
- with gr.Column():
495
- prompt = gr.Textbox(lines=8, label="Prompt",
496
- value="Render a clear and concise summary of the photo.")
497
- with gr.Row():
498
- submit = gr.Button("Submit", variant="primary")
499
- clear = gr.Button("Clear", variant="secondary")
500
- with gr.Column():
501
- output = gr.Textbox(label="Output", lines=10)
502
- data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
503
- submit.click(chatbot, [image, prompt], [output])
504
- clear.click(lambda: None, [], [output])
505
- data.click(lambda x: x, [data], [image, prompt])
506
-
507
  demo.queue(concurrency_count=1, max_size=10)
508
- demo.launch(share=False)
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os, gc, copy, torch, re
 
 
 
3
  from datetime import datetime
 
4
  from huggingface_hub import hf_hub_download
5
  from pynvml import *
6
  nvmlInit()
7
  gpu_h = nvmlDeviceGetHandleByIndex(0)
8
+ ctx_limit = 1024
 
 
9
  gen_limit = 500
10
  gen_limit_long = 800
11
+ title = "RWKV-x060-World-7B-v3-20241112-ctx4096"
12
 
13
+ os.environ["RWKV_JIT_ON"] = '1'
14
+ os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
15
 
16
+ from rwkv.model import RWKV
 
 
 
 
17
 
18
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title}.pth")
19
+ model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
20
+ # model_path = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-7b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-2; python app_tab.py
21
+ # model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
22
 
23
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
24
+ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
25
 
26
+ args = model.args
27
+ eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
28
  eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
29
+ state_eng_raw = torch.load(eng_file)
 
 
 
30
  state_eng = [None] * args.n_layer * 3
31
+
32
+ chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
33
+ chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
34
+ state_chn_raw = torch.load(chn_file)
35
  state_chn = [None] * args.n_layer * 3
36
+
37
+ wyw_name = 'rwkv-x060-chn_文言文和古典名著_single_round_qa-7B-20240601-ctx2048'
38
+ wyw_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{wyw_name}.pth")
39
+ state_wyw_raw = torch.load(wyw_file)
40
+ state_wyw = [None] * args.n_layer * 3
41
+
42
  for i in range(args.n_layer):
43
+ dd = model.strategy[i]
44
  dev = dd.device
45
  atype = dd.atype
46
  state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
 
47
  state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
 
48
  state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
 
49
 
50
+ state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
51
+ state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
52
+ state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
53
 
54
+ state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
55
+ state_wyw[i*3+1] = state_wyw_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
56
+ state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
 
 
57
 
58
  def generate_prompt(instruction, input=""):
59
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
 
68
  instruction = re.sub(r'\n+', '\n', instruction)
69
  return f"User: {instruction}\n\nAssistant:"""
70
 
71
+ penalty_decay = 0.996
72
+
73
  def evaluate(
74
  ctx,
75
+ token_count=gen_limit,
76
  temperature=1.0,
77
+ top_p=0.3,
78
+ presencePenalty = 0.3,
79
+ countPenalty = 0.3,
80
  ):
81
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
82
  alpha_frequency = countPenalty,
 
90
  occurrence = {}
91
  state = None
92
  for i in range(int(token_count)):
93
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
 
94
  for n in occurrence:
95
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
96
 
97
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
98
  if token in args.token_stop:
99
  break
100
  all_tokens += [token]
101
  for xxx in occurrence:
102
  occurrence[xxx] *= penalty_decay
 
 
 
 
 
 
 
103
  if token not in occurrence:
104
+ occurrence[token] = 1
105
  else:
106
+ occurrence[token] += 1
107
+
108
+ tmp = pipeline.decode(all_tokens[out_last:])
109
  if '\ufffd' not in tmp:
110
  out_str += tmp
111
  yield out_str.strip()
 
113
 
114
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
115
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
116
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
117
  del out
118
  del state
119
  gc.collect()
 
122
 
123
  def evaluate_eng(
124
  ctx,
125
+ token_count=gen_limit,
126
  temperature=1.0,
127
+ top_p=0.3,
128
+ presencePenalty=0.3,
129
+ countPenalty=0.3,
130
  ):
131
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
132
  alpha_frequency = countPenalty,
 
140
  occurrence = {}
141
  state = copy.deepcopy(state_eng)
142
  for i in range(int(token_count)):
143
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
 
144
  for n in occurrence:
145
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
146
 
147
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
148
  if token in args.token_stop:
149
  break
150
  all_tokens += [token]
151
  for xxx in occurrence:
152
  occurrence[xxx] *= penalty_decay
 
 
 
 
 
 
 
153
  if token not in occurrence:
154
+ occurrence[token] = 1
155
  else:
156
+ occurrence[token] += 1
157
+
158
+ tmp = pipeline.decode(all_tokens[out_last:])
159
  if '\ufffd' not in tmp:
160
  out_str += tmp
161
  yield out_str.strip()
 
163
 
164
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
165
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
166
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
167
  del out
168
  del state
169
  gc.collect()
 
172
 
173
  def evaluate_chn(
174
  ctx,
175
+ token_count=gen_limit,
176
  temperature=1.0,
177
+ top_p=0.3,
178
+ presencePenalty=0.3,
179
+ countPenalty=0.3,
180
  ):
181
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
182
  alpha_frequency = countPenalty,
 
190
  occurrence = {}
191
  state = copy.deepcopy(state_chn)
192
  for i in range(int(token_count)):
193
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
 
194
  for n in occurrence:
195
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
196
 
197
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
198
  if token in args.token_stop:
199
  break
200
  all_tokens += [token]
201
  for xxx in occurrence:
202
  occurrence[xxx] *= penalty_decay
 
 
 
 
 
 
 
203
  if token not in occurrence:
204
+ occurrence[token] = 1
205
  else:
206
+ occurrence[token] += 1
207
+
208
+ tmp = pipeline.decode(all_tokens[out_last:])
209
  if '\ufffd' not in tmp:
210
  out_str += tmp
211
  yield out_str.strip()
 
213
 
214
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
215
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
216
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
217
+ del out
218
+ del state
219
+ gc.collect()
220
+ torch.cuda.empty_cache()
221
+ yield out_str.strip()
222
+
223
+ def evaluate_wyw(
224
+ ctx,
225
+ token_count=gen_limit,
226
+ temperature=1.0,
227
+ top_p=0.3,
228
+ presencePenalty=0.3,
229
+ countPenalty=0.3,
230
+ ):
231
+ args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
232
+ alpha_frequency = countPenalty,
233
+ alpha_presence = presencePenalty,
234
+ token_ban = [], # ban the generation of some tokens
235
+ token_stop = [0]) # stop generation whenever you see any token here
236
+ ctx = qa_prompt(ctx)
237
+ all_tokens = []
238
+ out_last = 0
239
+ out_str = ''
240
+ occurrence = {}
241
+ state = copy.deepcopy(state_wyw)
242
+ for i in range(int(token_count)):
243
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
244
+ for n in occurrence:
245
+ out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
246
+
247
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
248
+ if token in args.token_stop:
249
+ break
250
+ all_tokens += [token]
251
+ for xxx in occurrence:
252
+ occurrence[xxx] *= penalty_decay
253
+ if token not in occurrence:
254
+ occurrence[token] = 1
255
+ else:
256
+ occurrence[token] += 1
257
+
258
+ tmp = pipeline.decode(all_tokens[out_last:])
259
+ if '\ufffd' not in tmp:
260
+ out_str += tmp
261
+ yield out_str.strip()
262
+ out_last = i + 1
263
+
264
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
265
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
266
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
267
  del out
268
  del state
269
  gc.collect()
 
277
  [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), gen_limit, 1, 0.3, 0.5, 0.5],
278
  ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", gen_limit, 1, 0.3, 0.5, 0.5],
279
  ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.3, 0.5, 0.5],
280
+ [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 500, 1, 0.3, 0.5, 0.5],
281
  ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境��は、特別な雰囲気に包まれていた。\n\nEnglish:''', gen_limit, 1, 0.3, 0.5, 0.5],
282
  ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", gen_limit, 1, 0.3, 0.5, 0.5],
283
  ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", gen_limit, 1, 0.3, 0.5, 0.5],
 
299
  ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit_long, 1, 0.2, 0.3, 0.3],
300
  ["比较苹果和谷歌的商业模式。", gen_limit_long, 1, 0.2, 0.3, 0.3],
301
  ["鱼会口渴吗?", gen_limit_long, 1, 0.2, 0.3, 0.3],
302
+ ["以 JSON 格式解释冰箱是如何工作的。", gen_limit_long, 1, 0.2, 0.3, 0.3],
303
  ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit_long, 1, 0.2, 0.3, 0.3],
304
  ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
305
  ]
306
 
307
+ examples_wyw = [
308
+ ["我和前男友分手了", gen_limit_long, 1, 0.2, 0.3, 0.3],
309
+ ["量子计算机的原理", gen_limit_long, 1, 0.2, 0.3, 0.3],
310
+ ["李白和杜甫的结拜故事", gen_limit_long, 1, 0.2, 0.3, 0.3],
311
+ ["林黛玉和伏地魔的关系是什么?", gen_limit_long, 1, 0.2, 0.3, 0.3],
312
+ ["我被同事陷害了,帮我写一篇文言文骂他", gen_limit_long, 1, 0.2, 0.3, 0.3],
313
+ ]
314
+
315
+ ##########################################################################
316
+
317
+ with gr.Blocks(title=title) as demo:
318
+ gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title}</h1>\n</div>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  with gr.Tab("=== Base Model (Raw Generation) ==="):
321
+ gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) base model. Supports 100+ world languages and code. RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [400+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
322
  with gr.Row():
323
  with gr.Column():
324
+ prompt = gr.Textbox(lines=2, label="Raw Input", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
325
  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
326
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
327
  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
 
332
  submit = gr.Button("Submit", variant="primary")
333
  clear = gr.Button("Clear", variant="secondary")
334
  output = gr.Textbox(label="Output", lines=30)
335
+ data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
336
  submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
337
  clear.click(lambda: None, [], [output])
338
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  demo.queue(concurrency_count=1, max_size=10)
341
+ demo.launch(share=False)