johnpaulbin and BlinkDL committed
Commit 8ff6ba7
0 Parent(s)

Duplicate from BlinkDL/RWKV-World-7B


Co-authored-by: BlinkDL <BlinkDL@users.noreply.huggingface.co>

Files changed (5)
  1. .gitattributes +34 -0
  2. 20B_tokenizer.json +0 -0
  3. README.md +14 -0
  4. app.py +301 -0
  5. requirements.txt +7 -0
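
This commit mirrors the source Space file-for-file. For reference, the same duplication can also be done programmatically; the sketch below is illustrative only (the destination repo name is hypothetical, and it assumes a logged-in account or a write-scoped token plus a reasonably recent `huggingface_hub`):

```python
from huggingface_hub import duplicate_space

# Duplicate the source Space under your own namespace.
# `to_id` is a hypothetical name; omit it to reuse the source repo name.
new_space = duplicate_space(
    from_id="BlinkDL/RWKV-World-7B",
    to_id="your-username/RWKV-World-7B",
)
print(new_space)  # URL of the newly created Space
```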
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
20B_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Raven RWKV 7B
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.23.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: BlinkDL/RWKV-World-7B
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
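
The YAML front matter above is the standard Space configuration (SDK, version, entry point, license, provenance). If you prefer to inspect it programmatically rather than reading the card, something along these lines should work (field names on the returned object can vary across `huggingface_hub` versions):

```python
from huggingface_hub import HfApi

# Fetch Space metadata from the Hub; the parsed front matter
# (title, sdk, sdk_version, app_file, ...) is included in the result.
info = HfApi().space_info("BlinkDL/RWKV-World-7B")
print(info)
```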
app.py ADDED
@@ -0,0 +1,301 @@
+ import gradio as gr
+ import os, gc, copy, torch, re
+ from datetime import datetime
+ from huggingface_hub import hf_hub_download
+ from pynvml import *
+ nvmlInit()
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
+ ctx_limit = 1536
+ title = "RWKV-4-World-7B-v1-20230626-ctx4096"
+
+ os.environ["RWKV_JIT_ON"] = '1'
+ os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+
+ from rwkv.model import RWKV
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-world", filename=f"{title}.pth")
+ model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
+ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
+
+ def generate_prompt(instruction, input=None):
+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n').replace('\n\n','\n')
+     input = (input or '').strip().replace('\r\n','\n').replace('\n\n','\n').replace('\n\n','\n')  # guard against input=None
+     if input:
+         return f"""Instruction: {instruction}
+
+ Input: {input}
+
+ Response:"""
+     else:
+         return f"""Question: {instruction}
+
+ Answer:"""
+
+ def evaluate(
+     instruction,
+     input=None,
+     token_count=200,
+     temperature=1.0,
+     top_p=0.7,
+     presencePenalty = 0.1,
+     countPenalty = 0.1,
+ ):
+     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                      alpha_frequency = countPenalty,
+                      alpha_presence = presencePenalty,
+                      token_ban = [], # ban the generation of some tokens
+                      token_stop = [0]) # stop generation whenever you see any token here
+
+     instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
+     input = re.sub(r'\n{2,}', '\n', input or '').strip().replace('\r\n','\n')  # guard against input=None
+     ctx = generate_prompt(instruction, input)
+
+     all_tokens = []
+     out_last = 0
+     out_str = ''
+     occurrence = {}
+     state = None
+     for i in range(int(token_count)):
+         out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         if token in args.token_stop:
+             break
+         all_tokens += [token]
+         for xxx in occurrence:
+             occurrence[xxx] *= 0.996
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:
+             out_str += tmp
+             yield out_str.strip()
+             out_last = i + 1
+         if '\n\n' in out_str:
+             break
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+     del out
+     del state
+     gc.collect()
+     torch.cuda.empty_cache()
+     yield out_str.strip()
+
+ examples = [
+     ["東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Write a song about ravens.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Explain the following metaphor: Life is like cats.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Write a story using the following information", "A man named Alex chops a tree down", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Generate a list of adjectives that describe a person as brave.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with a detailed plan.", "", 300, 1.2, 0.5, 0.4, 0.4],
+ ]
+
+ ##########################################################################
+
+ chat_intro = '''The following is a coherent verbose detailed conversation between <|user|> and an AI girl named <|bot|>.
+
+ <|user|>: Hi <|bot|>, would you like to chat with me for a while?
+
+ <|bot|>: Hi <|user|>. Sure. What would you like to talk about? I'm listening.
+ '''
+
+ def user(message, chatbot):
+     chatbot = chatbot or []
+     # print(f"User: {message}")
+     return "", chatbot + [[message, None]]
+
+ def alternative(chatbot, history):
+     if not chatbot or not history:
+         return chatbot, history
+
+     chatbot[-1][1] = None
+     history[0] = copy.deepcopy(history[1])
+
+     return chatbot, history
+
+ def chat(
+     prompt,
+     user,
+     bot,
+     chatbot,
+     history,
+     temperature=1.0,
+     top_p=0.8,
+     presence_penalty=0.1,
+     count_penalty=0.1,
+ ):
+     args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                      alpha_frequency=float(count_penalty),
+                      alpha_presence=float(presence_penalty),
+                      token_ban=[], # ban the generation of some tokens
+                      token_stop=[]) # stop generation whenever you see any token here
+
+     if not chatbot:
+         return chatbot, history
+
+     message = chatbot[-1][0]
+     message = message.strip().replace('\r\n','\n').replace('\n\n','\n')
+     ctx = f"{user}: {message}\n\n{bot}:"
+
+     if not history:
+         prompt = prompt.replace("<|user|>", user.strip())
+         prompt = prompt.replace("<|bot|>", bot.strip())
+         prompt = prompt.strip()
+         prompt = f"\n{prompt}\n\n"
+
+         out, state = model.forward(pipeline.encode(prompt), None)
+         history = [state, None, []] # [state, state_pre, tokens]
+         # print("History reloaded.")
+
+     [state, _, all_tokens] = history
+     state_pre_0 = copy.deepcopy(state)
+
+     out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
+     state_pre_1 = copy.deepcopy(state) # For recovery
+
+     # print("Bot:", end='')
+
+     begin = len(all_tokens)
+     out_last = begin
+     out_str: str = ''
+     occurrence = {}
+     for i in range(300):
+         if i <= 0:
+             nl_bias = -float('inf')
+         elif i <= 30:
+             nl_bias = (i - 30) * 0.1
+         elif i <= 130:
+             nl_bias = 0
+         else:
+             nl_bias = (i - 130) * 0.25
+         out[11] += nl_bias
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         next_tokens = [token]
+         if token == 0:
+             next_tokens = pipeline.encode('\n\n')
+         all_tokens += next_tokens
+         for xxx in occurrence:
+             occurrence[xxx] *= 0.996
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         out, state = model.forward(next_tokens, state)
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:
+             # print(tmp, end='', flush=True)
+             out_last = begin + i + 1
+             out_str += tmp
+
+         chatbot[-1][1] = out_str.strip()
+         history = [state, state_pre_0, all_tokens]  # keep the [state, state_pre, tokens] layout expected by chat() and alternative()
+         yield chatbot, history
+
+         out_str = pipeline.decode(all_tokens[begin:])
+         out_str = out_str.replace("\r\n", '\n')
+
+         if '\n\n' in out_str:
+             break
+
+         # State recovery
+         if f'{user}:' in out_str or f'{bot}:' in out_str:
+             idx_user = out_str.find(f'{user}:')
+             idx_user = len(out_str) if idx_user == -1 else idx_user
+             idx_bot = out_str.find(f'{bot}:')
+             idx_bot = len(out_str) if idx_bot == -1 else idx_bot
+             idx = min(idx_user, idx_bot)
+
+             if idx < len(out_str):
+                 out_str = f" {out_str[:idx].strip()}\n\n"
+                 tokens = pipeline.encode(out_str)
+
+                 all_tokens = all_tokens[:begin] + tokens
+                 out, state = model.forward(tokens, state_pre_1)
+             break
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     chatbot[-1][1] = out_str.strip()
+     history = [state, state_pre_0, all_tokens]
+     yield chatbot, history
+
+ ##########################################################################
+
+ with gr.Blocks(title=title) as demo:
+     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>🌍World - {title}</h1>\n</div>")
+     with gr.Tab("Instruct mode"):
+         gr.Markdown(f"World is a 100% RNN [RWKV 7B](https://github.com/BlinkDL/ChatRWKV) model ([RWKV-LM](https://github.com/BlinkDL/RWKV-LM)) ***trained on 100+ world languages*** and finetuned on alpaca, gpt4all, codealpaca and more. ***Please try the examples first (bottom of page)*** and edit them to ask your own question. The demo is limited to ctxlen {ctx_limit}. For best results, ***keep your prompt short and clear***.") # UPDATE: the Chat tab is turned off for now due to a VRAM leak caused by buggy code.
+         with gr.Row():
+             with gr.Column():
+                 instruction = gr.Textbox(lines=2, label="Instruction", value='東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。')
+                 input = gr.Textbox(lines=2, label="Input", placeholder="none")
+                 token_count = gr.Slider(10, 300, label="Max Tokens", step=10, value=300)
+                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
+                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
+                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.4)
+                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.4)
+             with gr.Column():
+                 with gr.Row():
+                     submit = gr.Button("Submit", variant="primary")
+                     clear = gr.Button("Clear", variant="secondary")
+                 output = gr.Textbox(label="Output", lines=5)
+         data = gr.Dataset(components=[instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+         submit.click(evaluate, [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
+         clear.click(lambda: None, [], [output])
+         data.click(lambda x: x, [data], [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty])
+
+     # with gr.Tab("Chat (Experimental - Might be buggy - use ChatRWKV for reference)"):
+     #     gr.Markdown(f'''<b>*** The length of response is restricted in this demo. Use ChatRWKV for longer generations. ***</b> Saying "go on" or "continue" can sometimes continue the response. If you'd like to edit the scenario, make sure to follow the exact same format: empty lines between (and only between) different speakers. Changes only take effect after you press [Clear]. <b>The default "Bob" & "Alice" names work the best.</b>''', label="Description")
+     #     with gr.Row():
+     #         with gr.Column():
+     #             chatbot = gr.Chatbot()
+     #             state = gr.State()
+     #             message = gr.Textbox(label="Message", value="Write me a Python program to land on the moon.")
+     #             with gr.Row():
+     #                 send = gr.Button("Send", variant="primary")
+     #                 alt = gr.Button("Alternative", variant="secondary")
+     #                 clear = gr.Button("Clear", variant="secondary")
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 user_name = gr.Textbox(lines=1, max_lines=1, label="User Name", value="Bob")
+     #                 bot_name = gr.Textbox(lines=1, max_lines=1, label="Bot Name", value="Alice")
+     #             prompt = gr.Textbox(lines=10, max_lines=50, label="Scenario", value=chat_intro)
+     #             temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
+     #             top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
+     #             presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.4)
+     #             count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.4)
+     #     chat_inputs = [
+     #         prompt,
+     #         user_name,
+     #         bot_name,
+     #         chatbot,
+     #         state,
+     #         temperature,
+     #         top_p,
+     #         presence_penalty,
+     #         count_penalty
+     #     ]
+     #     chat_outputs = [chatbot, state]
+     #     message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
+     #     send.click(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
+     #     alt.click(alternative, [chatbot, state], [chatbot, state], queue=False).then(chat, chat_inputs, chat_outputs)
+     #     clear.click(lambda: ([], None, ""), [], [chatbot, state, message], queue=False)
+
+ demo.queue(concurrency_count=1, max_size=10)
+ demo.launch(share=False)
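
Both `evaluate` and `chat` control repetition the same way: every sampled token is tracked in an `occurrence` dict, existing counts are decayed by 0.996 per step, and a presence + frequency penalty proportional to those counts is subtracted from the logits before the next sample. A minimal stand-alone sketch of that mechanism (toy vocabulary, argmax used as a stand-in for `pipeline.sample_logits`):

```python
import torch

def apply_decayed_penalty(logits, occurrence, alpha_presence=0.4, alpha_frequency=0.4):
    # Penalize every token seen so far, in proportion to its (decayed) count.
    for tok, count in occurrence.items():
        logits[tok] -= alpha_presence + count * alpha_frequency
    return logits

occurrence = {}
for step in range(5):
    logits = torch.zeros(8)                      # toy logits for an 8-token vocabulary
    logits = apply_decayed_penalty(logits, occurrence)
    token = int(torch.argmax(logits))            # stand-in for pipeline.sample_logits
    for tok in occurrence:                       # older tokens gradually regain probability
        occurrence[tok] *= 0.996
    occurrence[token] = occurrence.get(token, 0) + 1
    print(step, token, dict(occurrence))
```

With all-zero base logits the argmax cycles through fresh tokens, which is exactly the effect the penalty has on repeated phrases during generation.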
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ ninja
+ tokenizers
+ rwkv==0.7.5
+ pynvml
+ huggingface_hub
+ gradio>=3.17.1
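
With these packages installed, and a CUDA GPU with enough free VRAM for the `fp16i8 *8 -> fp16` strategy, a minimal local smoke test along the lines of app.py might look like this (the prompt string and token budget are illustrative):

```python
import os
os.environ["RWKV_JIT_ON"] = '1'  # set before importing rwkv, as app.py does

from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Same checkpoint and loading strategy as the Space.
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-world",
                             filename="RWKV-4-World-7B-v1-20230626-ctx4096.pth")
model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

prompt = "Question: What is the capital of France?\n\nAnswer:"
args = PIPELINE_ARGS(temperature=1.0, top_p=0.5,
                     alpha_presence=0.4, alpha_frequency=0.4)
print(pipeline.generate(prompt, token_count=64, args=args))
```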