Hazzzardous committed on
Commit
0791700
1 Parent(s): 5667dbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -31
app.py CHANGED
@@ -46,7 +46,8 @@ def get_model():
46
  return model
47
 
48
 
49
- model = None
 
50
 
51
 
52
  def infer(
@@ -126,12 +127,29 @@ def infer(
126
 
127
  gc.collect()
128
  yield generated_text
 
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
 
 
 
131
  def chat(
132
  prompt,
133
  history,
134
- username,
135
  max_new_tokens=10,
136
  temperature=0.1,
137
  top_p=1.0,
@@ -151,31 +169,17 @@ def chat(
151
  username = username.strip()
152
  username = username or "USER"
153
 
154
- intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
155
-
156
- {username}: What year was the french revolution?
157
- FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
158
- {username}: 3+5=?
159
- FRITZ: The answer is 8.
160
- {username}: What year did the Berlin Wall fall?
161
- FRITZ: The Berlin wall stood for 28 years and fell in 1989.
162
- {username}: solve for a: 9-a=2
163
- FRITZ: The answer is a=7, because 9-7 = 2.
164
- {username}: wat is lhc
165
- FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
166
- {username}: Tell me about yourself.
167
- FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
168
- '''
169
 
170
  if len(history) == 0:
171
  # no history, so lets reset chat state
172
- model.resetState()
173
  history = [[], model.emptyState]
174
  print("reset chat state")
175
  else:
176
  if (history[0][0][0].split(':')[0] != username):
177
- model.resetState()
178
- history = [[], model.emptyState]
179
  print("username changed, reset state")
180
  else:
181
  model.setState(history[1])
@@ -186,8 +190,8 @@ def chat(
186
  top_p = float(top_p)
187
  seed = seed
188
 
189
- assert 1 <= max_new_tokens <= 384
190
- assert 0.0 <= temperature <= 1.0
191
  assert 0.0 <= top_p <= 1.0
192
 
193
  temperature = max(0.05, temperature)
@@ -197,13 +201,13 @@ def chat(
197
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
198
  # Load prompt
199
 
200
- model.loadContext(newctx=intro+prompt)
201
 
202
  out = model.forward(number=max_new_tokens, stopStrings=[
203
  "<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
204
 
205
  generated_text = out["output"].lstrip("\n ")
206
- generated_text = generated_text.rstrip("USER:")
207
  print(f"{generated_text}")
208
 
209
  gc.collect()
@@ -251,12 +255,13 @@ iface = gr.Interface(
251
  inputs=[
252
  gr.Textbox(lines=20, label="Prompt"), # prompt
253
  gr.Radio(["generative", "Q/A","ELDR","EFR","BFR"],
254
- value="generative", label="Choose Mode"),
255
- gr.Slider(1, 256, value=40), # max_tokens
256
- gr.Slider(0.0, 1.0, value=0.8), # temperature
257
  gr.Slider(0.0, 1.0, value=0.85), # top_p
 
258
  gr.Slider(-999, 0.0, value=0.0), # end_adj
259
- gr.Textbox(lines=1, value="<|endoftext|>") # stop
260
  ],
261
  outputs=gr.Textbox(label="Generated Output", lines=25),
262
  examples=examples,
@@ -270,8 +275,6 @@ chatiface = gr.Interface(
270
  inputs=[
271
  gr.Textbox(lines=5, label="Message"), # prompt
272
  "state",
273
- gr.Text(lines=1, value="USER", label="Your Name",
274
- placeholder="Enter your Name"),
275
  gr.Slider(1, 256, value=60), # max_tokens
276
  gr.Slider(0.0, 1.0, value=0.8), # temperature
277
  gr.Slider(0.0, 1.0, value=0.85) # top_p
@@ -282,7 +285,7 @@ chatiface = gr.Interface(
282
 
283
  demo = gr.TabbedInterface(
284
 
285
- [iface, chatiface], ["Generative", "Chatbot"],
286
  title=title,
287
 
288
  )
 
46
  return model
47
 
48
 
49
+ model = get_model()
50
+
51
 
52
 
53
  def infer(
 
127
 
128
  gc.collect()
129
  yield generated_text
130
+ username = "USER"
131
+ intro = f'''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
132
 
133
+ {username}: What year was the french revolution?
134
+ FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
135
+ {username}: 3+5=?
136
+ FRITZ: The answer is 8.
137
+ {username}: What year did the Berlin Wall fall?
138
+ FRITZ: The Berlin wall stood for 28 years and fell in 1989.
139
+ {username}: solve for a: 9-a=2
140
+ FRITZ: The answer is a=7, because 9-7 = 2.
141
+ {username}: wat is lhc
142
+ FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
143
+ {username}: Tell me about yourself.
144
+ FRITZ: My name is Fritz. I am an RNN based Large Language Model (LLM).
145
+ '''
146
 
147
+ model.loadContext(intro)
148
+ chatState = model.getState().clone()
149
+ model.resetState()
150
  def chat(
151
  prompt,
152
  history,
 
153
  max_new_tokens=10,
154
  temperature=0.1,
155
  top_p=1.0,
 
169
  username = username.strip()
170
  username = username or "USER"
171
 
172
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  if len(history) == 0:
175
  # no history, so lets reset chat state
176
+ model.setState(chatState)
177
  history = [[], model.emptyState]
178
  print("reset chat state")
179
  else:
180
  if (history[0][0][0].split(':')[0] != username):
181
+ model.setState(chatState)
182
+ history = [[], model.chatState]
183
  print("username changed, reset state")
184
  else:
185
  model.setState(history[1])
 
190
  top_p = float(top_p)
191
  seed = seed
192
 
193
+ assert 1 <= max_new_tokens <= 512
194
+ assert 0.0 <= temperature <= 3.0
195
  assert 0.0 <= top_p <= 1.0
196
 
197
  temperature = max(0.05, temperature)
 
201
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
202
  # Load prompt
203
 
204
+ model.loadContext(newctx=prompt)
205
 
206
  out = model.forward(number=max_new_tokens, stopStrings=[
207
  "<|endoftext|>", username+":"], temp=temperature, top_p_usual=top_p)
208
 
209
  generated_text = out["output"].lstrip("\n ")
210
+ generated_text = generated_text.rstrip(username+":")
211
  print(f"{generated_text}")
212
 
213
  gc.collect()
 
255
  inputs=[
256
  gr.Textbox(lines=20, label="Prompt"), # prompt
257
  gr.Radio(["generative", "Q/A","ELDR","EFR","BFR"],
258
+ value="ELDR", label="Choose Mode"),
259
+ gr.Slider(1, 512, value=40), # max_tokens
260
+ gr.Slider(0.0, 5.0, value=1.0), # temperature
261
  gr.Slider(0.0, 1.0, value=0.85), # top_p
262
+ gr.Textbox(lines=1, value="<|endoftext|>"), # stop
263
  gr.Slider(-999, 0.0, value=0.0), # end_adj
264
+
265
  ],
266
  outputs=gr.Textbox(label="Generated Output", lines=25),
267
  examples=examples,
 
275
  inputs=[
276
  gr.Textbox(lines=5, label="Message"), # prompt
277
  "state",
 
 
278
  gr.Slider(1, 256, value=60), # max_tokens
279
  gr.Slider(0.0, 1.0, value=0.8), # temperature
280
  gr.Slider(0.0, 1.0, value=0.85) # top_p
 
285
 
286
  demo = gr.TabbedInterface(
287
 
288
+ [iface, chatiface], ["ELDR", "Chatbot"],
289
  title=title,
290
 
291
  )