alexkueck commited on
Commit
bde5bdd
1 Parent(s): 9ab420f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -7
app.py CHANGED
@@ -30,23 +30,32 @@ def predict(text,
30
  temperature,
31
  max_length_tokens,
32
  max_context_length_tokens,):
 
 
 
33
  if text=="":
34
- yield chatbotGr,history,"Empty context."
35
  return
 
 
36
  try:
37
  model
38
  except:
39
- yield [[text,"No Model Found"]],[],"No Model Found"
40
  return
41
 
 
42
  inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
43
  if inputs is None:
44
- yield chatbotGr,history,"Input too long."
45
  return
46
  else:
47
  prompt,inputs=inputs
48
  begin_length = len(prompt)
49
-
 
 
 
50
  input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
51
  torch.cuda.empty_cache()
52
 
@@ -78,8 +87,18 @@ def predict(text,
78
  yield a,b,"Generate: Success"
79
  except:
80
  pass
 
 
 
 
 
81
 
 
 
 
82
 
 
 
83
  def reset_chat():
84
  #id_new = chatbot.new_conversation()
85
  #chatbot.change_conversation(id_new)
@@ -162,15 +181,14 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
162
  predict_args = dict(
163
  fn=predict,
164
  inputs=[
165
- user_question,
166
- chatbotGr,
167
  history,
168
  top_p,
169
  temperature,
170
  max_length_tokens,
171
  max_context_length_tokens,
172
  ],
173
- outputs=[chatbotGr, history, status_display],
174
  show_progress=True,
175
  )
176
 
 
30
  temperature,
31
  max_length_tokens,
32
  max_context_length_tokens,):
33
+ global model, tokenizer, device
34
+
35
+ #wenn eingabe leer - nix tun
36
  if text=="":
37
+ yield history,"Empty context."
38
  return
39
+
40
+ #wenn Model nicht gefunden -> Fehler
41
  try:
42
  model
43
  except:
44
+ yield [],"No Model Found"
45
  return
46
 
47
+ #Prompt generieren -> mit Kontext bezogen auch auf vorhergehende Eingaben in dem chat
48
  inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
49
  if inputs is None:
50
+ yield history,"Input too long."
51
  return
52
  else:
53
  prompt,inputs=inputs
54
  begin_length = len(prompt)
55
+
56
+ #####################################################################################################
57
+ #ist glaube ich unnötig, da ich mit Pipeline arbeiten -> mal schauen, ich behalte es noch...
58
+ """
59
  input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
60
  torch.cuda.empty_cache()
61
 
 
87
  yield a,b,"Generate: Success"
88
  except:
89
  pass
90
+ """
91
+ ##########################################################################
92
+ #Prompt ist erzeugt, nun mit pipeline eine Antwort von der KI bekommen!
93
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
94
+ bot_message = pipe(prompt)
95
 
96
+ #chatbot - history erweitern und an chatbotGr zurückschicken
97
+ history = history.append((text, bot_message))
98
+ return history, "Erfolg!"
99
 
100
+
101
+ #neuen Chat beginnen
102
  def reset_chat():
103
  #id_new = chatbot.new_conversation()
104
  #chatbot.change_conversation(id_new)
 
181
  predict_args = dict(
182
  fn=predict,
183
  inputs=[
184
+ user_input,
 
185
  history,
186
  top_p,
187
  temperature,
188
  max_length_tokens,
189
  max_context_length_tokens,
190
  ],
191
+ outputs=[chatbotGr, status_display],
192
  show_progress=True,
193
  )
194