Hazzzardous committed on
Commit
b644119
1 Parent(s): 92fcbd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -23,7 +23,11 @@ def to_md(text):
23
  def get_model():
24
  model = None
25
  model = RWKV(
26
- "https://huggingface.co/Hazzzardous/RWKV-8Bit/resolve/main/RWKV-4-Pile-7B-Instruct.pqth",
 
 
 
 
27
  )
28
  return model
29
 
@@ -118,10 +122,11 @@ def chat(
118
  torch.cuda.empty_cache()
119
  model = get_model()
120
 
121
- if len(history) == 0:
122
  # no history, so lets reset chat state
123
  model.resetState()
124
-
 
125
  max_new_tokens = int(max_new_tokens)
126
  temperature = float(temperature)
127
  top_p = float(top_p)
@@ -143,8 +148,8 @@ def chat(
143
  model.loadContext(newctx=prompt)
144
  generated_text = ""
145
  done = False
146
- generated_text = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
147
-
148
  generated_text = generated_text.lstrip("\n ")
149
  print(f"{generated_text}")
150
 
@@ -154,8 +159,8 @@ def chat(
154
  generated_text = generated_text[:generated_text.find(stop_word)]
155
 
156
  gc.collect()
157
- history.append((prompt, generated_text))
158
- return history,history
159
 
160
 
161
  examples = [
 
23
  def get_model():
24
  model = None
25
  model = RWKV(
26
+ "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
27
+ "pytorch(cpu/gpu)",
28
+ runtimedtype=torch.float32,
29
+ useGPU=torch.cuda.is_available(),
30
+ dtype=torch.float32
31
  )
32
  return model
33
 
 
122
  torch.cuda.empty_cache()
123
  model = get_model()
124
 
125
+ if len(history[0]) == 0:
126
  # no history, so lets reset chat state
127
  model.resetState()
128
+ else:
129
+ model.setState(history[1])
130
  max_new_tokens = int(max_new_tokens)
131
  temperature = float(temperature)
132
  top_p = float(top_p)
 
148
  model.loadContext(newctx=prompt)
149
  generated_text = ""
150
  done = False
151
+ gen = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)
152
+ generated_text = gen["output"]
153
  generated_text = generated_text.lstrip("\n ")
154
  print(f"{generated_text}")
155
 
 
159
  generated_text = generated_text[:generated_text.find(stop_word)]
160
 
161
  gc.collect()
162
+ history[0].append((prompt, generated_text))
163
+ return history[0],[history[0],gen["state"]]
164
 
165
 
166
  examples = [