Singady commited on
Commit
857d4d1
1 Parent(s): 54a33c7

Upload NSFW.py

Browse files
Files changed (1) hide show
  1. NSFW.py +298 -0
NSFW.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os, gc, copy, torch
3
+ from datetime import datetime
4
+ from huggingface_hub import hf_hub_download
5
+ from pynvml import *
6
+ nvmlInit()
7
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
8
+ ctx_limit = 1536
9
+ title = "RWKV-4-Raven-14B-v12-Eng98%-Other2%-20230523-ctx8192"
10
+
11
+ os.environ["RWKV_JIT_ON"] = '1'
12
+ os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
13
+
14
+ from rwkv.model import RWKV
15
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename=f"{title}.pth")
16
+ model = RWKV(model=model_path, strategy='cuda fp16i8 *24 -> cuda fp16')
17
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
18
+ pipeline = PIPELINE(model, "20B_tokenizer.json")
19
+
20
+ def generate_prompt(instruction, input=None):
21
+ instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
22
+ input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
23
+ if input:
24
+ return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
25
+ # Instruction:
26
+ {instruction}
27
+ # Input:
28
+ {input}
29
+ # Response:
30
+ """
31
+ else:
32
+ return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
33
+ # Instruction:
34
+ {instruction}
35
+ # Response:
36
+ """
37
+
38
+ def evaluate(
39
+ instruction,
40
+ input=None,
41
+ token_count=200,
42
+ temperature=1.0,
43
+ top_p=0.7,
44
+ presencePenalty = 0.1,
45
+ countPenalty = 0.1,
46
+ ):
47
+ args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
48
+ alpha_frequency = countPenalty,
49
+ alpha_presence = presencePenalty,
50
+ token_ban = [], # ban the generation of some tokens
51
+ token_stop = [0]) # stop generation whenever you see any token here
52
+
53
+ instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
54
+ input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
55
+ ctx = generate_prompt(instruction, input)
56
+
57
+ all_tokens = []
58
+ out_last = 0
59
+ out_str = ''
60
+ occurrence = {}
61
+ state = None
62
+ for i in range(int(token_count)):
63
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
64
+ for n in occurrence:
65
+ out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
66
+
67
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
68
+ if token in args.token_stop:
69
+ break
70
+ all_tokens += [token]
71
+ if token not in occurrence:
72
+ occurrence[token] = 1
73
+ else:
74
+ occurrence[token] += 1
75
+
76
+ tmp = pipeline.decode(all_tokens[out_last:])
77
+ if '\ufffd' not in tmp:
78
+ out_str += tmp
79
+ yield out_str.strip()
80
+ out_last = i + 1
81
+
82
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
83
+ print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
84
+ del out
85
+ del state
86
+ gc.collect()
87
+ torch.cuda.empty_cache()
88
+ yield out_str.strip()
89
+
90
+ examples = [
91
+ ["Tell me about ravens.", "", 300, 1.2, 0.5, 0.4, 0.4],
92
+ ["Write a python function to mine 1 BTC, with details and comments.", "", 300, 1.2, 0.5, 0.4, 0.4],
93
+ ["Write a song about ravens.", "", 300, 1.2, 0.5, 0.4, 0.4],
94
+ ["Explain the following metaphor: Life is like cats.", "", 300, 1.2, 0.5, 0.4, 0.4],
95
+ ["Write a story using the following information", "A man named Alex chops a tree down", 300, 1.2, 0.5, 0.4, 0.4],
96
+ ["Generate a list of adjectives that describe a person as brave.", "", 300, 1.2, 0.5, 0.4, 0.4],
97
+ ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.", "", 300, 1.2, 0.5, 0.4, 0.4],
98
+ ]
99
+
100
+ ##########################################################################
101
+
102
+ chat_intro = '''The following is a coherent verbose detailed conversation between <|user|> and an AI girl named <|bot|>.
103
+ <|user|>: Hi <|bot|>, Would you like to chat with me for a while?
104
+ <|bot|>: Hi <|user|>. Sure. What would you like to talk about? I'm listening.
105
+ '''
106
+
107
+ def user(message, chatbot):
108
+ chatbot = chatbot or []
109
+ # print(f"User: {message}")
110
+ return "", chatbot + [[message, None]]
111
+
112
+ def alternative(chatbot, history):
113
+ if not chatbot or not history:
114
+ return chatbot, history
115
+
116
+ chatbot[-1][1] = None
117
+ history[0] = copy.deepcopy(history[1])
118
+
119
+ return chatbot, history
120
+
121
+ def chat(
122
+ prompt,
123
+ user,
124
+ bot,
125
+ chatbot,
126
+ history,
127
+ temperature=1.0,
128
+ top_p=0.8,
129
+ presence_penalty=0.1,
130
+ count_penalty=0.1,
131
+ ):
132
+ args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
133
+ alpha_frequency=float(count_penalty),
134
+ alpha_presence=float(presence_penalty),
135
+ token_ban=[], # ban the generation of some tokens
136
+ token_stop=[]) # stop generation whenever you see any token here
137
+
138
+ if not chatbot:
139
+ return chatbot, history
140
+
141
+ message = chatbot[-1][0]
142
+ message = message.strip().replace('\r\n','\n').replace('\n\n','\n')
143
+ ctx = f"{user}: {message}\n\n{bot}:"
144
+
145
+ if not history:
146
+ prompt = prompt.replace("<|user|>", user.strip())
147
+ prompt = prompt.replace("<|bot|>", bot.strip())
148
+ prompt = prompt.strip()
149
+ prompt = f"\n{prompt}\n\n"
150
+
151
+ out, state = model.forward(pipeline.encode(prompt), None)
152
+ history = [state, None, []] # [state, state_pre, tokens]
153
+ # print("History reloaded.")
154
+
155
+ [state, _, all_tokens] = history
156
+ state_pre_0 = copy.deepcopy(state)
157
+
158
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
159
+ state_pre_1 = copy.deepcopy(state) # For recovery
160
+
161
+ # print("Bot:", end='')
162
+
163
+ begin = len(all_tokens)
164
+ out_last = begin
165
+ out_str: str = ''
166
+ occurrence = {}
167
+ for i in range(300):
168
+ if i <= 0:
169
+ nl_bias = -float('inf')
170
+ elif i <= 30:
171
+ nl_bias = (i - 30) * 0.1
172
+ elif i <= 130:
173
+ nl_bias = 0
174
+ else:
175
+ nl_bias = (i - 130) * 0.25
176
+ out[187] += nl_bias
177
+ for n in occurrence:
178
+ out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
179
+
180
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
181
+ next_tokens = [token]
182
+ if token == 0:
183
+ next_tokens = pipeline.encode('\n\n')
184
+ all_tokens += next_tokens
185
+
186
+ if token not in occurrence:
187
+ occurrence[token] = 1
188
+ else:
189
+ occurrence[token] += 1
190
+
191
+ out, state = model.forward(next_tokens, state)
192
+
193
+ tmp = pipeline.decode(all_tokens[out_last:])
194
+ if '\ufffd' not in tmp:
195
+ # print(tmp, end='', flush=True)
196
+ out_last = begin + i + 1
197
+ out_str += tmp
198
+
199
+ chatbot[-1][1] = out_str.strip()
200
+ history = [state, all_tokens]
201
+ yield chatbot, history
202
+
203
+ out_str = pipeline.decode(all_tokens[begin:])
204
+ out_str = out_str.replace("\r\n", '\n').replace('\\n', '\n')
205
+
206
+ if '\n\n' in out_str:
207
+ break
208
+
209
+ # State recovery
210
+ if f'{user}:' in out_str or f'{bot}:' in out_str:
211
+ idx_user = out_str.find(f'{user}:')
212
+ idx_user = len(out_str) if idx_user == -1 else idx_user
213
+ idx_bot = out_str.find(f'{bot}:')
214
+ idx_bot = len(out_str) if idx_bot == -1 else idx_bot
215
+ idx = min(idx_user, idx_bot)
216
+
217
+ if idx < len(out_str):
218
+ out_str = f" {out_str[:idx].strip()}\n\n"
219
+ tokens = pipeline.encode(out_str)
220
+
221
+ all_tokens = all_tokens[:begin] + tokens
222
+ out, state = model.forward(tokens, state_pre_1)
223
+ break
224
+
225
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
226
+ print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
227
+
228
+ gc.collect()
229
+ torch.cuda.empty_cache()
230
+
231
+ chatbot[-1][1] = out_str.strip()
232
+ history = [state, state_pre_0, all_tokens]
233
+ yield chatbot, history
234
+
235
+ ##########################################################################
236
+
237
+ with gr.Blocks(title=title) as demo:
238
+ gr.HTML(f"<div style=\"text-align: center;\">\n<h1>🐦Raven - {title}</h1>\n</div>")
239
+ with gr.Tab("Instruct mode"):
240
+ gr.Markdown(f"Raven is [RWKV 14B](https://github.com/BlinkDL/ChatRWKV) 100% RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM) finetuned to follow instructions. *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}. Finetuned on alpaca, gpt4all, codealpaca and more. For best results, *** keep you prompt short and clear ***. <b>UPDATE: now with Chat (see above, as a tab) ==> turn off as of now due to VRAM leak caused by buggy code.</b>.")
241
+ with gr.Row():
242
+ with gr.Column():
243
+ instruction = gr.Textbox(lines=2, label="Instruction", value="Tell me about ravens.")
244
+ input = gr.Textbox(lines=2, label="Input", placeholder="none")
245
+ token_count = gr.Slider(10, 300, label="Max Tokens", step=10, value=300)
246
+ temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
247
+ top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
248
+ presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.4)
249
+ count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.4)
250
+ with gr.Column():
251
+ with gr.Row():
252
+ submit = gr.Button("Submit", variant="primary")
253
+ clear = gr.Button("Clear", variant="secondary")
254
+ output = gr.Textbox(label="Output", lines=5)
255
+ data = gr.Dataset(components=[instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
256
+ submit.click(evaluate, [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
257
+ clear.click(lambda: None, [], [output])
258
+ data.click(lambda x: x, [data], [instruction, input, token_count, temperature, top_p, presence_penalty, count_penalty])
259
+
260
+ # with gr.Tab("Chat (Experimental - Might be buggy - use ChatRWKV for reference)"):
261
+ # gr.Markdown(f'''<b>*** The length of response is restricted in this demo. Use ChatRWKV for longer generations. ***</b> Say "go on" or "continue" can sometimes continue the response. If you'd like to edit the scenario, make sure to follow the exact same format: empty lines between (and only between) different speakers. Changes only take effect after you press [Clear]. <b>The default "Bob" & "Alice" names work the best.</b>''', label="Description")
262
+ # with gr.Row():
263
+ # with gr.Column():
264
+ # chatbot = gr.Chatbot()
265
+ # state = gr.State()
266
+ # message = gr.Textbox(label="Message", value="Write me a python code to land on moon.")
267
+ # with gr.Row():
268
+ # send = gr.Button("Send", variant="primary")
269
+ # alt = gr.Button("Alternative", variant="secondary")
270
+ # clear = gr.Button("Clear", variant="secondary")
271
+ # with gr.Column():
272
+ # with gr.Row():
273
+ # user_name = gr.Textbox(lines=1, max_lines=1, label="User Name", value="Bob")
274
+ # bot_name = gr.Textbox(lines=1, max_lines=1, label="Bot Name", value="Alice")
275
+ # prompt = gr.Textbox(lines=10, max_lines=50, label="Scenario", value=chat_intro)
276
+ # temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.2)
277
+ # top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
278
+ # presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.4)
279
+ # count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.4)
280
+ # chat_inputs = [
281
+ # prompt,
282
+ # user_name,
283
+ # bot_name,
284
+ # chatbot,
285
+ # state,
286
+ # temperature,
287
+ # top_p,
288
+ # presence_penalty,
289
+ # count_penalty
290
+ # ]
291
+ # chat_outputs = [chatbot, state]
292
+ # message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
293
+ # send.click(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
294
+ # alt.click(alternative, [chatbot, state], [chatbot, state], queue=False).then(chat, chat_inputs, chat_outputs)
295
+ # clear.click(lambda: ([], None, ""), [], [chatbot, state, message], queue=False)
296
+
297
+ demo.queue(concurrency_count=1, max_size=10)
298
+ demo.launch(share=False)