Hazzzardous commited on
Commit
187c661
·
1 Parent(s): d43a442

add config option

Browse files
Files changed (2) hide show
  1. app.py +45 -41
  2. config.py +63 -0
app.py CHANGED
@@ -22,7 +22,7 @@ import codecs
22
  from ast import literal_eval
23
  from datetime import datetime
24
  from rwkvstic.load import RWKV
25
- from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT, TORCH_STREAM
26
  import torch
27
  import gc
28
 
@@ -33,25 +33,25 @@ desc = '''<p>RNN with Transformer-level LLM Performance (<a href='https://github
33
 
34
  thanks = '''<p>Thanks to <a href='https://www.rftcapital.com'>RFT Capital</a> for donating compute capability for our experiments. Additional thanks to the author of the <a href="https://github.com/harrisonvanderbyl/rwkvstic">rwkvstic</a> library.</p>'''
35
 
 
36
  def to_md(text):
37
  return text.replace("\n", "<br />")
38
 
 
39
  def get_model():
40
  model = None
41
  model = RWKV(
42
- "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
43
- "pytorch(cpu/gpu)",
44
- runtimedtype=torch.float32,
45
- useGPU=torch.cuda.is_available(),
46
- dtype=torch.float32
47
  )
48
  return model
49
 
 
50
  model = None
51
 
 
52
  def infer(
53
  prompt,
54
- mode = "generative",
55
  max_new_tokens=10,
56
  temperature=0.1,
57
  top_p=1.0,
@@ -65,18 +65,18 @@ def infer(
65
  if (DEVICE == "cuda"):
66
  torch.cuda.empty_cache()
67
  model = get_model()
68
-
69
  max_new_tokens = int(max_new_tokens)
70
  temperature = float(temperature)
71
  top_p = float(top_p)
72
- stop = [x.strip(' ') for x in stop.split(',')]
73
  seed = seed
74
 
75
  assert 1 <= max_new_tokens <= 384
76
  assert 0.0 <= temperature <= 1.0
77
  assert 0.0 <= top_p <= 1.0
78
 
79
- temperature = max(0.05,temperature)
80
  if prompt == "":
81
  prompt = " "
82
 
@@ -84,7 +84,7 @@ def infer(
84
  model.resetState()
85
  if (mode == "Q/A"):
86
  prompt = f"Ask Expert\n\nQuestion:\n{prompt}\n\nExpert Full Answer:\n"
87
-
88
  print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
89
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
90
  # Load prompt
@@ -93,11 +93,12 @@ def infer(
93
  done = False
94
  with torch.no_grad():
95
  for _ in range(max_new_tokens):
96
- char = model.forward(stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
 
97
  print(char, end='', flush=True)
98
  generated_text += char
99
  generated_text = generated_text.lstrip("\n ")
100
-
101
  for stop_word in stop:
102
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
103
  if stop_word != '' and stop_word in generated_text:
@@ -108,13 +109,13 @@ def infer(
108
  print("<stopped>\n")
109
  break
110
 
111
- #print(f"{generated_text}")
112
-
113
  for stop_word in stop:
114
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
115
  if stop_word != '' and stop_word in generated_text:
116
  generated_text = generated_text[:generated_text.find(stop_word)]
117
-
118
  gc.collect()
119
  yield generated_text
120
 
@@ -130,9 +131,9 @@ def chat(
130
  ):
131
  global model
132
  history = history or []
133
-
134
  intro = ""
135
-
136
  if model == None:
137
  gc.collect()
138
  if (DEVICE == "cuda"):
@@ -141,7 +142,7 @@ def chat(
141
 
142
  username = username.strip()
143
  username = username or "USER"
144
-
145
  intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
146
 
147
  {username}: What year was the french revolution?
@@ -156,23 +157,22 @@ def chat(
156
  FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
157
  {username}: Tell me about yourself.
158
  FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
159
- '''
160
-
161
  if len(history) == 0:
162
  # no history, so lets reset chat state
163
  model.resetState()
164
- history = [[],model.emptyState]
165
  print("reset chat state")
166
  else:
167
  if (history[0][0][0].split(':')[0] != username):
168
  model.resetState()
169
- history = [[],model.emptyState]
170
  print("username changed, reset state")
171
  else:
172
  model.setState(history[1])
173
  intro = ""
174
-
175
-
176
  max_new_tokens = int(max_new_tokens)
177
  temperature = float(temperature)
178
  top_p = float(top_p)
@@ -182,16 +182,17 @@ def chat(
182
  assert 0.0 <= temperature <= 1.0
183
  assert 0.0 <= top_p <= 1.0
184
 
185
- temperature = max(0.05,temperature)
186
 
187
- prompt = f"{username}: " + prompt + "\n"
188
  print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
189
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
190
  # Load prompt
191
 
192
  model.loadContext(newctx=intro+prompt)
193
 
194
- out = model.forward(number=max_new_tokens, stopStrings=["<|endoftext|>",username+":"],temp=temperature,top_p_usual=top_p)
 
195
 
196
  generated_text = out["output"].lstrip("\n ")
197
  generated_text = generated_text.rstrip("USER:")
@@ -199,19 +200,19 @@ def chat(
199
 
200
  gc.collect()
201
  history[0].append((prompt, generated_text))
202
- return history[0],[history[0],out["state"]]
203
 
204
 
205
  examples = [
206
  [
207
  # Question Answering
208
- '''What is the capital of Germany?''',"Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
209
  [
210
  # Question Answering
211
- '''Are humans good or bad?''',"Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
212
  [
213
  # Question Answering
214
- '''What is the purpose of Vitamin A?''',"Q/A", 50, 0.2, 0.8, "<|endoftext|>"],
215
  [
216
  # Chatbot
217
  '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each others AI.
@@ -231,7 +232,7 @@ Best Full Response:
231
  [
232
  # Natural Language Interface
233
  '''Here is a short story (in the style of Tolkien) in which Aiden attacks a robot with a sword:
234
- ''',"generative", 140, 0.85, 0.8, "<|endoftext|>"]
235
  ]
236
 
237
 
@@ -241,11 +242,12 @@ iface = gr.Interface(
241
  allow_flagging="never",
242
  inputs=[
243
  gr.Textbox(lines=20, label="Prompt"), # prompt
244
- gr.Radio(["generative","Q/A"], value="generative", label="Choose Mode"),
 
245
  gr.Slider(1, 256, value=40), # max_tokens
246
  gr.Slider(0.0, 1.0, value=0.8), # temperature
247
  gr.Slider(0.0, 1.0, value=0.85), # top_p
248
- gr.Textbox(lines=1, value="<|endoftext|>") # stop
249
  ],
250
  outputs=gr.Textbox(label="Generated Output", lines=25),
251
  examples=examples,
@@ -259,20 +261,22 @@ chatiface = gr.Interface(
259
  inputs=[
260
  gr.Textbox(lines=5, label="Message"), # prompt
261
  "state",
262
- gr.Text(lines=1, value="USER", label="Your Name", placeholder="Enter your Name"),
 
263
  gr.Slider(1, 256, value=60), # max_tokens
264
  gr.Slider(0.0, 1.0, value=0.8), # temperature
265
  gr.Slider(0.0, 1.0, value=0.85) # top_p
266
  ],
267
- outputs=[gr.Chatbot(label="Chat Log", color_map=("green", "pink")),"state"],
 
268
  ).queue()
269
 
270
  demo = gr.TabbedInterface(
271
 
272
- [iface,chatiface],["Generative","Chatbot"],
273
- title="RWKV-4 (1.5b Instruct)",
274
-
275
- )
276
 
277
  demo.queue()
278
  demo.launch(share=False)
 
22
  from ast import literal_eval
23
  from datetime import datetime
24
  from rwkvstic.load import RWKV
25
+ from config import config, title
26
  import torch
27
  import gc
28
 
 
33
 
34
  thanks = '''<p>Thanks to <a href='https://www.rftcapital.com'>RFT Capital</a> for donating compute capability for our experiments. Additional thanks to the author of the <a href="https://github.com/harrisonvanderbyl/rwkvstic">rwkvstic</a> library.</p>'''
35
 
36
+
37
  def to_md(text):
38
  return text.replace("\n", "<br />")
39
 
40
+
41
  def get_model():
42
  model = None
43
  model = RWKV(
44
+ **config
 
 
 
 
45
  )
46
  return model
47
 
48
+
49
  model = None
50
 
51
+
52
  def infer(
53
  prompt,
54
+ mode="generative",
55
  max_new_tokens=10,
56
  temperature=0.1,
57
  top_p=1.0,
 
65
  if (DEVICE == "cuda"):
66
  torch.cuda.empty_cache()
67
  model = get_model()
68
+
69
  max_new_tokens = int(max_new_tokens)
70
  temperature = float(temperature)
71
  top_p = float(top_p)
72
+ stop = [x.strip(' ') for x in stop.split(',')]
73
  seed = seed
74
 
75
  assert 1 <= max_new_tokens <= 384
76
  assert 0.0 <= temperature <= 1.0
77
  assert 0.0 <= top_p <= 1.0
78
 
79
+ temperature = max(0.05, temperature)
80
  if prompt == "":
81
  prompt = " "
82
 
 
84
  model.resetState()
85
  if (mode == "Q/A"):
86
  prompt = f"Ask Expert\n\nQuestion:\n{prompt}\n\nExpert Full Answer:\n"
87
+
88
  print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
89
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
90
  # Load prompt
 
93
  done = False
94
  with torch.no_grad():
95
  for _ in range(max_new_tokens):
96
+ char = model.forward(stopStrings=stop, temp=temperature, top_p_usual=top_p)[
97
+ "output"]
98
  print(char, end='', flush=True)
99
  generated_text += char
100
  generated_text = generated_text.lstrip("\n ")
101
+
102
  for stop_word in stop:
103
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
104
  if stop_word != '' and stop_word in generated_text:
 
109
  print("<stopped>\n")
110
  break
111
 
112
+ # print(f"{generated_text}")
113
+
114
  for stop_word in stop:
115
  stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
116
  if stop_word != '' and stop_word in generated_text:
117
  generated_text = generated_text[:generated_text.find(stop_word)]
118
+
119
  gc.collect()
120
  yield generated_text
121
 
 
131
  ):
132
  global model
133
  history = history or []
134
+
135
  intro = ""
136
+
137
  if model == None:
138
  gc.collect()
139
  if (DEVICE == "cuda"):
 
142
 
143
  username = username.strip()
144
  username = username or "USER"
145
+
146
  intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
147
 
148
  {username}: What year was the french revolution?
 
157
  FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
158
  {username}: Tell me about yourself.
159
  FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
160
+ '''
161
+
162
  if len(history) == 0:
163
  # no history, so lets reset chat state
164
  model.resetState()
165
+ history = [[], model.emptyState]
166
  print("reset chat state")
167
  else:
168
  if (history[0][0][0].split(':')[0] != username):
169
  model.resetState()
170
+ history = [[], model.emptyState]
171
  print("username changed, reset state")
172
  else:
173
  model.setState(history[1])
174
  intro = ""
175
+
 
176
  max_new_tokens = int(max_new_tokens)
177
  temperature = float(temperature)
178
  top_p = float(top_p)
 
182
  assert 0.0 <= temperature <= 1.0
183
  assert 0.0 <= top_p <= 1.0
184
 
185
+ temperature = max(0.05, temperature)
186
 
187
+ prompt = f"{username}: " + prompt + "\n"
188
  print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
189
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
190
  # Load prompt
191
 
192
  model.loadContext(newctx=intro+prompt)
193
 
194
+ out = model.forward(number=max_new_tokens, stopStrings=[
195
+ "<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
196
 
197
  generated_text = out["output"].lstrip("\n ")
198
  generated_text = generated_text.rstrip("USER:")
 
200
 
201
  gc.collect()
202
  history[0].append((prompt, generated_text))
203
+ return history[0], [history[0], out["state"]]
204
 
205
 
206
  examples = [
207
  [
208
  # Question Answering
209
+ '''What is the capital of Germany?''', "Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
210
  [
211
  # Question Answering
212
+ '''Are humans good or bad?''', "Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
213
  [
214
  # Question Answering
215
+ '''What is the purpose of Vitamin A?''', "Q/A", 50, 0.2, 0.8, "<|endoftext|>"],
216
  [
217
  # Chatbot
218
  '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each others AI.
 
232
  [
233
  # Natural Language Interface
234
  '''Here is a short story (in the style of Tolkien) in which Aiden attacks a robot with a sword:
235
+ ''', "generative", 140, 0.85, 0.8, "<|endoftext|>"]
236
  ]
237
 
238
 
 
242
  allow_flagging="never",
243
  inputs=[
244
  gr.Textbox(lines=20, label="Prompt"), # prompt
245
+ gr.Radio(["generative", "Q/A"],
246
+ value="generative", label="Choose Mode"),
247
  gr.Slider(1, 256, value=40), # max_tokens
248
  gr.Slider(0.0, 1.0, value=0.8), # temperature
249
  gr.Slider(0.0, 1.0, value=0.85), # top_p
250
+ gr.Textbox(lines=1, value="<|endoftext|>") # stop
251
  ],
252
  outputs=gr.Textbox(label="Generated Output", lines=25),
253
  examples=examples,
 
261
  inputs=[
262
  gr.Textbox(lines=5, label="Message"), # prompt
263
  "state",
264
+ gr.Text(lines=1, value="USER", label="Your Name",
265
+ placeholder="Enter your Name"),
266
  gr.Slider(1, 256, value=60), # max_tokens
267
  gr.Slider(0.0, 1.0, value=0.8), # temperature
268
  gr.Slider(0.0, 1.0, value=0.85) # top_p
269
  ],
270
+ outputs=[gr.Chatbot(label="Chat Log", color_map=(
271
+ "green", "pink")), "state"],
272
  ).queue()
273
 
274
  demo = gr.TabbedInterface(
275
 
276
+ [iface, chatiface], ["Generative", "Chatbot"],
277
+ title=title,
278
+
279
+ )
280
 
281
  demo.queue()
282
  demo.launch(share=False)
config.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT
2
+ import torch
3
+
4
+ quantized = {
5
+ "mode": TORCH_QUANT,
6
+ "runtimedtype": torch.bfloat16,
7
+ "useGPU": torch.cuda.is_available(),
8
+ "chunksize": 32, # larger = more accurate, but more memory
9
+ "target": 100 # your gpu max size, excess vram offloaded to cpu
10
+ }
11
+
12
+ # UNCOMMENT TO SELECT OPTIONS
13
+ # Not full list of options, see https://pypi.org/project/rwkvstic/ and https://huggingface.co/BlinkDL/ for more models/modes
14
+
15
+ # RWKV 1B5 instruct test 1 model
16
+ # Approximate
17
+ # [Vram usage: 6.0GB]
18
+ # [File size: 3.0GB]
19
+
20
+
21
+ config = {
22
+ "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
23
+ "mode": TORCH,
24
+ "runtimedtype": torch.float32,
25
+ "useGPU": torch.cuda.is_available(),
26
+ "dtype": torch.float32
27
+ }
28
+
29
+ title = "RWKV-4 (1.5b Instruct)"
30
+
31
+ # RWKV 1B5 instruct model quantized
32
+ # Approximate
33
+ # [Vram usage: 1.3GB]
34
+ # [File size: 3.0GB]
35
+
36
+ # config = {
37
+ # "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
38
+ # **quantized
39
+ # }
40
+
41
+ # title = "RWKV-4 (1.5b Instruct Quantized)"
42
+
43
+ # RWKV 7B instruct pre-quantized (settings baked into model)
44
+ # Approximate
45
+ # [Vram usage: 7.0GB]
46
+ # [File size: 8.0GB]
47
+
48
+ # config = {
49
+ # "path": "https://huggingface.co/Hazzzardous/RWKV-8Bit/resolve/main/RWKV-4-Pile-7B-Instruct.pqth"
50
+ # }
51
+
52
+ # title = "RWKV-4 (7b Instruct Quantized)"
53
+
54
+ # RWKV 14B quantized (latest as of feb 9)
55
+ # Approximate
56
+ # [Vram usage: 15.0GB]
57
+ # [File size: 28.0GB]
58
+
59
+ # config = {
60
+ # "path": "https://huggingface.co/BlinkDL/rwkv-4-pile-14b/resolve/main/RWKV-4-Pile-14B-20230204-7324.pth"
61
+ # }
62
+
63
+ # title = "RWKV-4 (14b)"