Hazzzardous committed
Commit
6b466a0
1 Parent(s): 7bc916c

Update app.py

Files changed (1)
  1. app.py +234 -31
app.py CHANGED
@@ -1,38 +1,241 @@
  from rwkvstic.load import RWKV
  import torch
- model = RWKV(
-     "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
-     "pytorch(cpu/gpu)",
-     runtimedtype=torch.float32,
-     useGPU=torch.cuda.is_available(),
-     dtype=torch.float32
- )
- import gradio as gr
-
-
- def predict(input, history=None):
-     model.setState(history[1])
-     model.loadContext(newctx=f"Prompt: {input}\n\nExpert Long Detailed Response: ")
-     r = model.forward(number=100,stopStrings=["\n\nPrompt"])
-     rr = [(input,r["output"])]
-     return [*history[0],*rr], [[*history[0],*rr],r["state"]]
-
- def freegen(input):
      model.resetState()
-     model.loadContext(newctx=input)
-     return model.forward(number=100)["output"]
- with gr.Blocks() as demo:
-     with gr.Tab("Chatbot"):
-         chatbot = gr.Chatbot()
-         state = model.emptyState
-         state = gr.State([[],state])
-         with gr.Row():
-             txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
-
-         txt.submit(predict, [txt, state], [chatbot, state])
-     with gr.Tab("Free Gen"):
-         with gr.Row():
-             input = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
-             outtext = gr.Textbox(show_label=False)
-             input.submit(freegen,input,outtext)
- demo.launch()
+ """
+ RWKV RNN Model - Gradio Space for HuggingFace
+ YT - Mean Gene Hacks - https://www.youtube.com/@MeanGeneHacks
+ (C) Gene Ruebsamen - 2/7/2023
+ License: GPL3
+ """
+
+ import gradio as gr
+ import codecs
+ from ast import literal_eval
+ from datetime import datetime
  from rwkvstic.load import RWKV
+ from rwkvstic.agnostic.backends import TORCH, TORCH_QUANT, TORCH_STREAM
  import torch
+ import gc
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def to_md(text):
+     return text.replace("\n", "<br />")
+
+
+ def get_model():
+     model = None
+     model = RWKV(
+         "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
+         "pytorch(cpu/gpu)",
+         runtimedtype=torch.float32,
+         useGPU=torch.cuda.is_available(),
+         dtype=torch.float32
+     )
+     return model
+
+ model = None

+ def infer(
+     prompt,
+     mode = "generative",
+     max_new_tokens=10,
+     temperature=0.1,
+     top_p=1.0,
+     stop="<|endoftext|>",
+     seed=42,
+ ):
+     global model
+
+     if model == None:
+         gc.collect()
+         if (DEVICE == "cuda"):
+             torch.cuda.empty_cache()
+         model = get_model()
+
+     max_new_tokens = int(max_new_tokens)
+     temperature = float(temperature)
+     top_p = float(top_p)
+     stop = [x.strip(' ') for x in stop.split(',')]
+     seed = seed
+
+     assert 1 <= max_new_tokens <= 384
+     assert 0.0 <= temperature <= 1.0
+     assert 0.0 <= top_p <= 1.0
+
+     if temperature == 0.0:
+         temperature = 0.01
+     if prompt == "":
+         prompt = " "
+
+     # Clear model state for generative mode
      model.resetState()
+     if (mode == "Q/A"):
+         prompt = f"Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\n{prompt}\n\nFull Answer:"
+
+     print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}")
+     print(f"OUTPUT ({datetime.now()}):\n-------\n")
+     # Load prompt
+     model.loadContext(newctx=prompt)
+     generated_text = ""
+     done = False
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             char = model.forward(stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
+             print(char, end='', flush=True)
+             generated_text += char
+             generated_text = generated_text.lstrip("\n ")
+
+             for stop_word in stop:
+                 stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+                 if stop_word != '' and stop_word in generated_text:
+                     done = True
+                     break
+             yield generated_text
+             if done:
+                 print("<stopped>\n")
+                 break
+
+     #print(f"{generated_text}")
+
+     for stop_word in stop:
+         stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+         if stop_word != '' and stop_word in generated_text:
+             generated_text = generated_text[:generated_text.find(stop_word)]
+
+     gc.collect()
+     yield generated_text
+
+
+ def chat(
+     prompt,
+     history,
+     max_new_tokens=10,
+     temperature=0.1,
+     top_p=1.0,
+     stop="<|endoftext|>",
+     seed=42,
+ ):
+     global model
+     history = history or []
+
+     if model == None:
+         gc.collect()
+         if (DEVICE == "cuda"):
+             torch.cuda.empty_cache()
+         model = get_model()
+
+     if len(history) == 0:
+         # no history, so lets reset chat state
+         model.resetState()
+
+     max_new_tokens = int(max_new_tokens)
+     temperature = float(temperature)
+     top_p = float(top_p)
+     stop = [x.strip(' ') for x in stop.split(',')]
+     seed = seed
+
+     assert 1 <= max_new_tokens <= 384
+     assert 0.0 <= temperature <= 1.0
+     assert 0.0 <= top_p <= 1.0
+
+     if temperature == 0.0:
+         temperature = 0.01
+     if prompt == "":
+         prompt = " "
+
+     print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
+     print(f"OUTPUT ({datetime.now()}):\n-------\n")
+     # Load prompt
+     model.loadContext(newctx=prompt)
+     generated_text = ""
+     done = False
+     generated_text = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
+
+     generated_text = generated_text.lstrip("\n ")
+     print(f"{generated_text}")
+
+     for stop_word in stop:
+         stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+         if stop_word != '' and stop_word in generated_text:
+             generated_text = generated_text[:generated_text.find(stop_word)]
+
+     gc.collect()
+     history.append((prompt, generated_text))
+     return history,history
+
+
+ examples = [
+     [
+         # Question Answering
+         '''What is the capital of Germany?''',"Q/A", 25, 0.2, 1.0, "<|endoftext|>"],
+     [
+         # Question Answering
+         '''Are humans good or bad?''',"Q/A", 150, 0.8, 0.8, "<|endoftext|>"],
+     [
+         # Chatbot
+         '''This is a conversation between two AI large language models named Alex and Fritz. They are exploring each other's capabilities, and trying to ask interesting questions of one another to explore the limits of each others AI.
+ Conversation:
+ Alex: Good morning, Fritz, what type of LLM are you based upon?
+ Fritz: Morning Alex, I am an RNN with transformer level performance. My language model is 100% attention free.
+ Alex:''', "generative", 220, 0.9, 0.9, "\\n\\n,<|endoftext|>"],
+     [
+         # Generate List
+         '''Q. Give me list of fiction books.
+ 1. Harry Potter
+ 2. Lord of the Rings
+ 3. Game of Thrones
+ Q. Give me a list of vegetables.
+ 1. Broccoli
+ 2. Celery
+ 3. Tomatoes
+ Q. Give me a list of car manufacturers.''', "generative", 80, 0.2, 1.0, "\\n\\n,<|endoftext|>"],
+     [
+         # Natural Language Interface
+         '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest horror novel, set to be completed in the spring of 2024. Create a summary of this work.
+ Summary:''',"generative", 200, 0.85, 0.8, "<|endoftext|>"]
+ ]
+
+
+ iface = gr.Interface(
+     fn=infer,
+     description='''<p>RNN With Transformer-level LLM Performance. (<a href='https://github.com/BlinkDL/RWKV-LM'>github</a>)
+ According to the author: "It combines the best of RNN and transformers - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding"
+ <p>Thanks to <a href='https://www.rftcapital.com'>RFT Capital</a> for donating compute capability for our experiments. Additional thanks to the author of the <a href="https://github.com/harrisonvanderbyl/rwkvstic">rwkvstic</a> library.</p>''',
+     allow_flagging="never",
+     inputs=[
+         gr.Textbox(lines=20, label="Prompt"),  # prompt
+         gr.Radio(["generative","Q/A"], value="generative", label="Choose Mode"),
+         gr.Slider(1, 256, value=40),  # max_tokens
+         gr.Slider(0.0, 1.0, value=0.8),  # temperature
+         gr.Slider(0.0, 1.0, value=0.85),  # top_p
+         gr.Textbox(lines=1, value="<|endoftext|>")  # stop
+     ],
+     outputs=gr.Textbox(lines=25),
+     examples=examples,
+     cache_examples=False,
+ ).queue()
+
+ chatiface = gr.Interface(
+     fn=chat,
+     description='''<p>RNN With Transformer-level LLM Performance. (<a href='https://github.com/BlinkDL/RWKV-LM'>github</a>)
+ According to the author: "It combines the best of RNN and transformers - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding"
+ <p>Thanks to <a href='https://www.rftcapital.com'>RFT Capital</a> for donating compute capability for our experiments. Additional thanks to the author of the <a href="https://github.com/harrisonvanderbyl/rwkvstic">rwkvstic</a> library.</p>''',
+     allow_flagging="never",
+     inputs=[
+         gr.Textbox(lines=5, label="Message"),  # prompt
+         "state",
+         gr.Slider(1, 256, value=60),  # max_tokens
+         gr.Slider(0.0, 1.0, value=0.8),  # temperature
+         gr.Slider(0.0, 1.0, value=0.85),  # top_p
+         gr.Textbox(lines=1, value="<|endoftext|>")  # stop
+     ],
+     outputs=[gr.Chatbot(color_map=("green", "pink")),"state"],
+ ).queue()
+
+ demo = gr.TabbedInterface(
+     [iface,chatiface], ["Generative","Chatbot"],
+     title="RWKV-4 (1.5b Instruct)",
+ )
+
+ demo.queue()
+ demo.launch(share=False)
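
For reviewers who want to exercise the new entry points without the web UI, a minimal smoke test could look like the sketch below. It is hypothetical and not part of this commit: it assumes the definitions above are already in scope (for example, pasted into a REPL before the script reaches demo.launch()) and that the 1.5B checkpoint download succeeds.

# Hypothetical smoke test for the functions added above (not part of the commit).
# infer() is a generator: each yield is the full text generated so far, which is
# what lets the Gradio output box update incrementally. chat() returns the
# updated history twice because the Interface wires it to both the Chatbot
# widget and the "state" input.
last = ""
for partial in infer("What is the capital of Germany?", mode="Q/A",
                     max_new_tokens=25, temperature=0.2, top_p=1.0,
                     stop="<|endoftext|>"):
    last = partial  # consume the stream; the final yield is the cleaned output
print(last)

history, _ = chat("Hello, who are you?", None, max_new_tokens=60,
                  temperature=0.8, top_p=0.85, stop="<|endoftext|>")
print(history[-1])  # (prompt, generated_text) pair shown in the Chatbot tab

Passing None for history works because chat() normalizes it with "history = history or []" and resets the model state when the history is empty.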