BlinkDL committed
Commit
c47e784
1 Parent(s): 0269c32

Update app.py

Files changed (1)
  1. app.py +302 -135
app.py CHANGED
@@ -1,59 +1,69 @@
  import gradio as gr
- import os, gc, copy, torch, re
  from datetime import datetime
  from huggingface_hub import hf_hub_download
  from pynvml import *
  nvmlInit()
  gpu_h = nvmlDeviceGetHandleByIndex(0)
- ctx_limit = 1024
  gen_limit = 500
  gen_limit_long = 800
- title = "RWKV-x060-World-7B-v3-20241112-ctx4096"

- os.environ["RWKV_JIT_ON"] = '1'
- os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)

- from rwkv.model import RWKV

- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title}.pth")
- model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
- # model_path = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-7b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-2; python app_tab.py
- # model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')

- from rwkv.utils import PIPELINE, PIPELINE_ARGS
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

- args = model.args
- eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
  eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
- state_eng_raw = torch.load(eng_file)
- state_eng = [None] * args.n_layer * 3
-
- chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
  chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
- state_chn_raw = torch.load(chn_file)
- state_chn = [None] * args.n_layer * 3
-
- wyw_name = 'rwkv-x060-chn_文言文和古典名著_single_round_qa-7B-20240601-ctx2048'
- wyw_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{wyw_name}.pth")
- state_wyw_raw = torch.load(wyw_file)
- state_wyw = [None] * args.n_layer * 3

  for i in range(args.n_layer):
- dd = model.strategy[i]
  dev = dd.device
  atype = dd.atype
  state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
- state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
- state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-
  state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
  state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
  state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()

- state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
- state_wyw[i*3+1] = state_wyw_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
- state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()

  def generate_prompt(instruction, input=""):
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -68,15 +78,13 @@ def qa_prompt(instruction):
  instruction = re.sub(r'\n+', '\n', instruction)
  return f"User: {instruction}\n\nAssistant:"""

- penalty_decay = 0.996
-
  def evaluate(
  ctx,
- token_count=gen_limit,
  temperature=1.0,
- top_p=0.3,
- presencePenalty = 0.3,
- countPenalty = 0.3,
  ):
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
  alpha_frequency = countPenalty,
@@ -90,22 +98,30 @@ def evaluate(
  occurrence = {}
  state = None
  for i in range(int(token_count)):
- out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
  for n in occurrence:
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
  if token in args.token_stop:
  break
  all_tokens += [token]
  for xxx in occurrence:
  occurrence[xxx] *= penalty_decay
  if token not in occurrence:
- occurrence[token] = 1
  else:
- occurrence[token] += 1
-
- tmp = pipeline.decode(all_tokens[out_last:])
  if '\ufffd' not in tmp:
  out_str += tmp
  yield out_str.strip()
@@ -113,7 +129,7 @@ def evaluate(

  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
  del out
  del state
  gc.collect()
@@ -122,11 +138,11 @@ def evaluate(

  def evaluate_eng(
  ctx,
- token_count=gen_limit,
  temperature=1.0,
- top_p=0.3,
- presencePenalty=0.3,
- countPenalty=0.3,
  ):
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
  alpha_frequency = countPenalty,
@@ -140,22 +156,30 @@ def evaluate_eng(
  occurrence = {}
  state = copy.deepcopy(state_eng)
  for i in range(int(token_count)):
- out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
  for n in occurrence:
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
  if token in args.token_stop:
  break
  all_tokens += [token]
  for xxx in occurrence:
  occurrence[xxx] *= penalty_decay
  if token not in occurrence:
- occurrence[token] = 1
  else:
- occurrence[token] += 1
-
- tmp = pipeline.decode(all_tokens[out_last:])
  if '\ufffd' not in tmp:
  out_str += tmp
  yield out_str.strip()
@@ -163,7 +187,7 @@ def evaluate_eng(

  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
  del out
  del state
  gc.collect()
@@ -172,11 +196,11 @@ def evaluate_chn(

  def evaluate_chn(
  ctx,
- token_count=gen_limit,
  temperature=1.0,
- top_p=0.3,
- presencePenalty=0.3,
- countPenalty=0.3,
  ):
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
  alpha_frequency = countPenalty,
@@ -190,22 +214,30 @@ def evaluate_chn(
  occurrence = {}
  state = copy.deepcopy(state_chn)
  for i in range(int(token_count)):
- out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
  for n in occurrence:
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
  if token in args.token_stop:
  break
  all_tokens += [token]
  for xxx in occurrence:
  occurrence[xxx] *= penalty_decay
  if token not in occurrence:
- occurrence[token] = 1
  else:
- occurrence[token] += 1
-
- tmp = pipeline.decode(all_tokens[out_last:])
  if '\ufffd' not in tmp:
  out_str += tmp
  yield out_str.strip()
@@ -213,57 +245,7 @@ def evaluate_chn(

  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
- del out
- del state
- gc.collect()
- torch.cuda.empty_cache()
- yield out_str.strip()
-
- def evaluate_wyw(
- ctx,
- token_count=gen_limit,
- temperature=1.0,
- top_p=0.3,
- presencePenalty=0.3,
- countPenalty=0.3,
- ):
- args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
- alpha_frequency = countPenalty,
- alpha_presence = presencePenalty,
- token_ban = [], # ban the generation of some tokens
- token_stop = [0]) # stop generation whenever you see any token here
- ctx = qa_prompt(ctx)
- all_tokens = []
- out_last = 0
- out_str = ''
- occurrence = {}
- state = copy.deepcopy(state_wyw)
- for i in range(int(token_count)):
- out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
- for n in occurrence:
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
-
- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
- if token in args.token_stop:
- break
- all_tokens += [token]
- for xxx in occurrence:
- occurrence[xxx] *= penalty_decay
- if token not in occurrence:
- occurrence[token] = 1
- else:
- occurrence[token] += 1
-
- tmp = pipeline.decode(all_tokens[out_last:])
- if '\ufffd' not in tmp:
- out_str += tmp
- yield out_str.strip()
- out_last = i + 1
-
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
  del out
  del state
  gc.collect()
@@ -292,36 +274,163 @@ examples_eng = [
  ["Write an outline for a fantasy novel where dreams can alter reality.", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ["Can fish get thirsty?", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ["Write a Bash script to check disk usage and send alerts if it's too high.", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ]

  examples_chn = [
  ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ["比较苹果和谷歌的商业模式。", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ["鱼会口渴吗?", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ["以 JSON 格式解释冰箱是如何工作的。", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
  ]

- examples_wyw = [
- ["我和前男友分手了", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ["量子计算机的原理", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ["李白和杜甫的结拜故事", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ["林黛玉和伏地魔的关系是什么?", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ["我被同事陷害了,帮我写一篇文言文骂他", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ]
-
- ##########################################################################
-
- with gr.Blocks(title=title) as demo:
- gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title}</h1>\n</div>")

  with gr.Tab("=== Base Model (Raw Generation) ==="):
- gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) base model. Supports 100+ world languages and code. RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [400+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
  with gr.Row():
  with gr.Column():
- prompt = gr.Textbox(lines=2, label="Raw Input", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
@@ -332,10 +441,68 @@ with gr.Blocks(title=title) as demo:
  submit = gr.Button("Submit", variant="primary")
  clear = gr.Button("Clear", variant="secondary")
  output = gr.Textbox(label="Output", lines=30)
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
  submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
  clear.click(lambda: None, [], [output])
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

  demo.queue(concurrency_count=1, max_size=10)
- demo.launch(share=False)

+ import os, copy
+ os.environ["RWKV_JIT_ON"] = '1'
+ os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+ # make sure cuda dir is in the same level as modeling_rwkv.py
+ from modeling_rwkv import RWKV
+
+ import gc, re
  import gradio as gr
+ import base64
+ from io import BytesIO
+ import torch
+ import torch.nn.functional as F
  from datetime import datetime
+ from transformers import CLIPImageProcessor
  from huggingface_hub import hf_hub_download
  from pynvml import *
  nvmlInit()
  gpu_h = nvmlDeviceGetHandleByIndex(0)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ ctx_limit = 2500
  gen_limit = 500
  gen_limit_long = 800
+ ENABLE_VISUAL = False

+ ########################## text rwkv ################################################################
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS

+ title_v6 = "RWKV-x060-World-3B-v2.1-20240417-ctx4096"
+ model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title_v6}.pth")
+ # model_path_v6 = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-3b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-1; python app.py
+ model_v6 = RWKV(model=model_path_v6, strategy='cuda fp16')
+ pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")

+ args = model_v6.args
+ eng_name = 'rwkv-x060-eng_single_round_qa-3B-20240516-ctx2048'
+ chn_name = 'rwkv-x060-chn_single_round_qa-3B-20240516-ctx2048'

+ # state_eng_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{eng_name}.pth', map_location=torch.device('cpu'))
+ # state_chn_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{chn_name}.pth', map_location=torch.device('cpu'))

  eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
  chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
+ state_eng_raw = torch.load(eng_file, map_location=torch.device('cpu'))
+ state_chn_raw = torch.load(chn_file, map_location=torch.device('cpu'))

+ state_eng = [None] * args.n_layer * 3
+ state_chn = [None] * args.n_layer * 3
  for i in range(args.n_layer):
+ dd = model_v6.strategy[i]
  dev = dd.device
  atype = dd.atype
  state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
  state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+ state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
  state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+ state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
  state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()

+ penalty_decay = 0.996
+
+ if ENABLE_VISUAL:
+ title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
+ model = RWKV(model=model_path, strategy='cuda fp16')
+ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

  def generate_prompt(instruction, input=""):
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')

  instruction = re.sub(r'\n+', '\n', instruction)
  return f"User: {instruction}\n\nAssistant:"""

  def evaluate(
  ctx,
+ token_count=200,
  temperature=1.0,
+ top_p=0.7,
+ presencePenalty = 0.1,
+ countPenalty = 0.1,
  ):
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
  alpha_frequency = countPenalty,

  occurrence = {}
  state = None
  for i in range(int(token_count)):
+ input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+ out, state = model_v6.forward(tokens=input_ids, state=state)
  for n in occurrence:
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

+ token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
  if token in args.token_stop:
  break
  all_tokens += [token]
  for xxx in occurrence:
  occurrence[xxx] *= penalty_decay
+
+ ttt = pipeline_v6.decode([token])
+ www = 1
+ if ttt in ' \t0123456789':
+ www = 0
+ #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+ # www = 0.5
  if token not in occurrence:
+ occurrence[token] = www
  else:
+ occurrence[token] += www
+
+ tmp = pipeline_v6.decode(all_tokens[out_last:])
  if '\ufffd' not in tmp:
  out_str += tmp
  yield out_str.strip()

  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
  del out
  del state
  gc.collect()

  def evaluate_eng(
  ctx,
+ token_count=200,
  temperature=1.0,
+ top_p=0.7,
+ presencePenalty = 0.1,
+ countPenalty = 0.1,
  ):
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
  alpha_frequency = countPenalty,

  occurrence = {}
  state = copy.deepcopy(state_eng)
  for i in range(int(token_count)):
+ input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+ out, state = model_v6.forward(tokens=input_ids, state=state)
  for n in occurrence:
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

+ token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
  if token in args.token_stop:
  break
  all_tokens += [token]
  for xxx in occurrence:
  occurrence[xxx] *= penalty_decay
+
+ ttt = pipeline_v6.decode([token])
+ www = 1
+ if ttt in ' \t0123456789':
+ www = 0
+ #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+ # www = 0.5
  if token not in occurrence:
+ occurrence[token] = www
  else:
+ occurrence[token] += www
+
+ tmp = pipeline_v6.decode(all_tokens[out_last:])
  if '\ufffd' not in tmp:
  out_str += tmp
  yield out_str.strip()

  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
  del out
  del state
  gc.collect()

  def evaluate_chn(
  ctx,
+ token_count=200,
  temperature=1.0,
+ top_p=0.7,
+ presencePenalty = 0.1,
+ countPenalty = 0.1,
  ):
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
  alpha_frequency = countPenalty,

  occurrence = {}
  state = copy.deepcopy(state_chn)
  for i in range(int(token_count)):
+ input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+ out, state = model_v6.forward(tokens=input_ids, state=state)
  for n in occurrence:
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

+ token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
  if token in args.token_stop:
  break
  all_tokens += [token]
  for xxx in occurrence:
  occurrence[xxx] *= penalty_decay
+
+ ttt = pipeline_v6.decode([token])
+ www = 1
+ if ttt in ' \t0123456789':
+ www = 0
+ #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+ # www = 0.5
  if token not in occurrence:
+ occurrence[token] = www
  else:
+ occurrence[token] += www
+
+ tmp = pipeline_v6.decode(all_tokens[out_last:])
  if '\ufffd' not in tmp:
  out_str += tmp
  yield out_str.strip()

  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
  del out
  del state
  gc.collect()

  ["Write an outline for a fantasy novel where dreams can alter reality.", gen_limit_long, 1, 0.2, 0.3, 0.3],
275
  ["Can fish get thirsty?", gen_limit_long, 1, 0.2, 0.3, 0.3],
276
  ["Write a Bash script to check disk usage and send alerts if it's too high.", gen_limit_long, 1, 0.2, 0.3, 0.3],
277
+ ["Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes.", gen_limit_long, 1, 0.2, 0.3, 0.3],
278
  ]
279
 
280
  examples_chn = [
281
  ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit_long, 1, 0.2, 0.3, 0.3],
282
  ["比较苹果和谷歌的商业模式。", gen_limit_long, 1, 0.2, 0.3, 0.3],
283
  ["鱼会口渴吗?", gen_limit_long, 1, 0.2, 0.3, 0.3],
284
+ ["以 JSON 格式列举���京的美食。", gen_limit_long, 1, 0.2, 0.3, 0.3],
285
  ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit_long, 1, 0.2, 0.3, 0.3],
286
  ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
287
  ]
288
 
289
+ if ENABLE_VISUAL:
+ ########################## visual rwkv ################################################################
+ visual_title = 'ViusualRWKV-v5'
+ rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
+ vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
+ vision_tower_name = 'openai/clip-vit-large-patch14-336'
+
+ model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
+ visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
+
+ ##########################################################################
+ from modeling_vision import VisionEncoder, VisionEncoderConfig
+ config = VisionEncoderConfig(n_embd=model.args.n_embd,
+ vision_tower_name=vision_tower_name,
+ grid_size=-1)
+ visual_encoder = VisionEncoder(config)
+ vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
+ vision_state_dict = torch.load(vision_local_path, map_location='cpu')
+ visual_encoder.load_state_dict(vision_state_dict)
+ image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
+ visual_encoder = visual_encoder.to(device)
+ ##########################################################################
+ def visual_generate_prompt(instruction):
+ instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+ return f"\n{instruction}\n\nAssistant:"
+
+ def generate(
+ ctx,
+ image_state,
+ token_count=200,
+ temperature=1.0,
+ top_p=0.1,
+ presencePenalty = 0.0,
+ countPenalty = 1.0,
+ ):
+ args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.1,
+ alpha_frequency = 1.0,
+ alpha_presence = 0.0,
+ token_ban = [], # ban the generation of some tokens
+ token_stop = [0, 261]) # stop generation whenever you see any token here
+ ctx = ctx.strip()
+ all_tokens = []
+ out_last = 0
+ out_str = ''
+ occurrence = {}
+ for i in range(int(token_count)):
+ if i == 0:
+ input_ids = pipeline.encode(ctx)[-ctx_limit:]
+ out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
+ else:
+ input_ids = [token]
+ out, state = visual_rwkv.forward(tokens=input_ids, state=state)
+ for n in occurrence:
+ out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+ if token in args.token_stop:
+ break
+ all_tokens += [token]
+ for xxx in occurrence:
+ occurrence[xxx] *= 0.994
+ if token not in occurrence:
+ occurrence[token] = 1
+ else:
+ occurrence[token] += 1
+
+ tmp = pipeline.decode(all_tokens[out_last:])
+ if '\ufffd' not in tmp:
+ out_str += tmp
+ yield out_str.strip()
+ out_last = i + 1
+
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+ del out
+ del state
+ gc.collect()
+ torch.cuda.empty_cache()
+ yield out_str.strip()
+
+
+ ##########################################################################
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ visual_examples = [
+ [
+ f"{cur_dir}/examples_pizza.jpg",
+ "What are steps to cook it?"
+ ],
+ [
+ f"{cur_dir}/examples_bluejay.jpg",
+ "what is the name of this bird?",
+ ],
+ [
+ f"{cur_dir}/examples_woman_and_dog.png",
+ "describe this image",
+ ],
+ ]
+
+
+ def pil_image_to_base64(pil_image):
+ buffered = BytesIO()
+ pil_image.save(buffered, format="JPEG") # You can change the format as needed (JPEG, PNG, etc.)
+ # Encodes the image data into base64 format as a bytes object
+ base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
+ return base64_image
+
+ image_cache = {}
+ ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
+ ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
+ def compute_image_state(image):
+ base64_image = pil_image_to_base64(image)
+ if base64_image in image_cache:
+ image_state = image_cache[base64_image]
+ else:
+ image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
+ image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
+ # apply layer norm to image feature, very important
+ image_features = F.layer_norm(image_features,
+ (image_features.shape[-1],),
+ weight=ln0_weight,
+ bias=ln0_bias)
+ _, image_state = model.forward(embs=image_features, state=None)
+ image_cache[base64_image] = image_state
+ return image_state
+
+ def chatbot(image, question):
+ if image is None:
+ yield "Please upload an image."
+ return
+ image_state = compute_image_state(image)
+ input_text = visual_generate_prompt(question)
+ for output in generate(input_text, image_state):
+ yield output
+
+
+ ##################################################################################################################
+ with gr.Blocks(title=title_v6) as demo:
+ gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")

  with gr.Tab("=== Base Model (Raw Generation) ==="):
+ gr.Markdown(f"This is [RWKV-6 World v2](https://huggingface.co/BlinkDL/rwkv-6-world) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. Check [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.")
  with gr.Row():
  with gr.Column():
+ prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)

  submit = gr.Button("Submit", variant="primary")
  clear = gr.Button("Clear", variant="secondary")
  output = gr.Textbox(label="Output", lines=30)
+ data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
  submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
  clear.click(lambda: None, [], [output])
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

+ with gr.Tab("=== English Q/A ==="):
450
+ gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [English Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{eng_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
451
+ with gr.Row():
452
+ with gr.Column():
453
+ prompt = gr.Textbox(lines=2, label="Prompt", value="How can I craft an engaging story featuring vampires on Mars?")
454
+ token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
455
+ temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
456
+ top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
457
+ presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
458
+ count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
459
+ with gr.Column():
460
+ with gr.Row():
461
+ submit = gr.Button("Submit", variant="primary")
462
+ clear = gr.Button("Clear", variant="secondary")
463
+ output = gr.Textbox(label="Output", lines=30)
464
+ data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_eng, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
465
+ submit.click(evaluate_eng, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
466
+ clear.click(lambda: None, [], [output])
467
+ data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
468
+
469
+ with gr.Tab("=== Chinese Q/A ==="):
470
+ gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [Chinese Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{chn_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
471
+ with gr.Row():
472
+ with gr.Column():
473
+ prompt = gr.Textbox(lines=2, label="Prompt", value="怎样写一个在火星上的吸血鬼的有趣故事?")
474
+ token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
475
+ temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
476
+ top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
477
+ presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
478
+ count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
479
+ with gr.Column():
480
+ with gr.Row():
481
+ submit = gr.Button("Submit", variant="primary")
482
+ clear = gr.Button("Clear", variant="secondary")
483
+ output = gr.Textbox(label="Output", lines=30)
484
+ data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_chn, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
485
+ submit.click(evaluate_chn, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
486
+ clear.click(lambda: None, [], [output])
487
+ data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
488
+
489
+ if ENABLE_VISUAL:
+ with gr.Tab("Visual RWKV-5 1.5B"):
+ with gr.Row():
+ with gr.Column():
+ image = gr.Image(type='pil', label="Image")
+ with gr.Column():
+ prompt = gr.Textbox(lines=8, label="Prompt",
+ value="Render a clear and concise summary of the photo.")
+ with gr.Row():
+ submit = gr.Button("Submit", variant="primary")
+ clear = gr.Button("Clear", variant="secondary")
+ with gr.Column():
+ output = gr.Textbox(label="Output", lines=10)
+ data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
+ submit.click(chatbot, [image, prompt], [output])
+ clear.click(lambda: None, [], [output])
+ data.click(lambda x: x, [data], [image, prompt])
+
  demo.queue(concurrency_count=1, max_size=10)
+ demo.launch(share=False)
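
Note on the sampling change visible in the evaluate loops above: both the old and the new code apply a decaying repetition penalty, but the new code weights each generated token by `www`, so digits and whitespace are exempt from the penalty. The snippet below is a minimal, self-contained sketch of that bookkeeping only; the RWKV model and tokenizer are replaced by random logits and a toy `decode()` stand-in (hypothetical, not part of this repo), so it runs without downloading any weights.

import torch

penalty_decay = 0.996   # same value as in app.py
alpha_presence = 0.3
alpha_frequency = 0.3

def decode(token_id):
    # toy stand-in for pipeline_v6.decode([token]); maps an id to one printable character
    return chr(32 + (token_id % 95))

occurrence = {}
generated = []
for step in range(16):
    logits = torch.randn(65536)  # stand-in for the logits returned by model_v6.forward(...)
    # subtract presence + count penalties for tokens generated so far
    for t, count in occurrence.items():
        logits[t] -= alpha_presence + count * alpha_frequency
    token = int(torch.argmax(logits))  # greedy pick stands in for pipeline_v6.sample_logits
    generated.append(token)
    # decay old counts so the penalty fades as generation moves on
    for t in occurrence:
        occurrence[t] *= penalty_decay
    # new in this commit: digits and whitespace get weight 0, so they are never penalized
    www = 0 if decode(token) in ' \t0123456789' else 1
    occurrence[token] = occurrence.get(token, 0) + www

print(len(generated), "tokens sampled")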