alexkueck commited on
Commit
9e85ff2
·
1 Parent(s): 3054dce

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +58 -0
utils.py CHANGED
@@ -150,6 +150,64 @@ def greedy_search(input_ids: torch.Tensor,
150
  gc.collect()
151
  return
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def convert_to_markdown(text):
154
  text = text.replace("$","$")
155
  def replace_leading_tabs_and_spaces(line):
 
150
  gc.collect()
151
  return
152
 
153
+ ########################################
154
+ #Predict
155
+ def predict(text,
156
+ history,
157
+ top_p,
158
+ temperature,
159
+ max_length_tokens,
160
+ max_context_length_tokens,):
161
+ if text=="":
162
+ yield history,"Empty context."
163
+ return
164
+ try:
165
+ model
166
+ except:
167
+ yield [[text,"No Model Found"]],[],"No Model Found"
168
+ return
169
+
170
+ inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
171
+ if inputs is None:
172
+ yield history,"Input too long."
173
+ return
174
+ else:
175
+ prompt,inputs=inputs
176
+ begin_length = len(prompt)
177
+
178
+ input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
179
+ torch.cuda.empty_cache()
180
+
181
+ #torch.no_grad() bedeutet, dass für die betreffenden tensoren keine Ableitungen berechnet werden bei der backpropagation
182
+ #hier soll das NN ja auch nicht geändert werden 8backprop ist nicht nötig), da es um interference-prompts geht!
183
+ with torch.no_grad():
184
+ #die vergangenen prompts werden alle als Tupel in history abgelegt sortiert nach 'Human' und 'AI'- dass sind daher auch die stop-words, die den jeweils nächsten Eintrag kennzeichnen
185
+ for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
186
+ if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
187
+ if "[|Human|]" in x:
188
+ x = x[:x.index("[|Human|]")].strip()
189
+ if "[|AI|]" in x:
190
+ x = x[:x.index("[|AI|]")].strip()
191
+ x = x.strip()
192
+ a, b= [[y[0],convert_to_markdown(y[1])] for y in history]+[[text, convert_to_markdown(x)]],history + [[text,x]]
193
+ yield a, b, "Generating..."
194
+ if shared_state.interrupted:
195
+ shared_state.recover()
196
+ try:
197
+ yield a, b, "Stop: Success"
198
+ return
199
+ except:
200
+ pass
201
+ del input_ids
202
+ gc.collect()
203
+ torch.cuda.empty_cache()
204
+
205
+ try:
206
+ yield a,b,"Generate: Success"
207
+ except:
208
+ pass
209
+
210
+
211
  def convert_to_markdown(text):
212
  text = text.replace("$","$")
213
  def replace_leading_tabs_and_spaces(line):