Pclanglais committed
Commit d6b6a6e
1 Parent(s): 50af2bb

Update app.py

Files changed (1):
  app.py +21 -34
app.py CHANGED
@@ -100,7 +100,19 @@ class StopOnTokens(StoppingCriteria):
         return False
 
 
-def predict(history_transformer_format):
+def predict(message, history):
+
+    global source_text
+    global assess_rag
+    #For now, we only query the vector database once, at the start.
+    if len(history) == 0:
+        assess_rag = classification_chatrag(message)
+        if assess_rag:
+            source_text = vector_search(message)
+        else:
+            source_text = "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."
+
+    history_transformer_format = history + [[message, ""]]
 
     print(history_transformer_format)
     stop = StopOnTokens()
@@ -129,8 +141,6 @@ def predict(history_transformer_format):
 
     messages = system_prompt + messages
 
-    print(messages)
-
     model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
@@ -145,27 +155,12 @@ def predict(history_transformer_format):
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    history_transformer_format[-1][1] = ""
+    partial_message = ""
     for new_token in streamer:
         if new_token != '<':
-            history_transformer_format[-1][1] += new_token
-            yield history_transformer_format
-
-def user(message, history):
-    global source_text
-    global assess_rag
-    #For now, we only query the vector database once, at the start.
-    if len(history) == 0:
-        assess_rag = classification_chatrag(message)
-        if assess_rag:
-            source_text = vector_search(message)
-        else:
-            source_text = "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."
-
-    history_transformer_format = history + [[message, ""]]
-
-    print(history_transformer_format)
-    return "", history_transformer_format
+            partial_message += new_token
+            yield partial_message
+    return messages
 
 # Define the Gradio interface
 title = "Tchap"
@@ -176,17 +171,9 @@ examples = [
         0.7 # temperature
     ]
 ]
-
+demo = gr.Blocks()
 with gr.Blocks() as demo:
-    chatbot = gr.Chatbot()
-    msg = gr.Textbox()
-    clear = gr.Button("Clear")
-
-    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        predict, chatbot, chatbot
-    )
-    clear.click(lambda: None, None, chatbot, queue=False)
-
+    gr.ChatInterface(predict)
 
-    demo.queue()
-    demo.launch()
+if __name__ == "__main__":
+    demo.queue().launch()
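
The core of the rewritten predict is untouched by this commit: model.generate runs on a worker thread while TextIteratorStreamer hands back decoded tokens as they arrive, and each yield pushes the reply-so-far to the UI. A minimal sketch of that Thread + TextIteratorStreamer pattern, using gpt2 as a stand-in for the app's actual model and tokenizer (which app.py loads elsewhere); the prompt and max_new_tokens value here are illustrative only:

# Sketch: the Thread + TextIteratorStreamer pattern used inside predict.
# "gpt2" is a stand-in; app.py loads its own model and tokenizer on CUDA.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt):
    model_inputs = tokenizer([prompt], return_tensors="pt")
    # skip_prompt=True drops the echoed prompt; skip_special_tokens drops EOS markers.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=64)
    # generate() blocks until done, so it runs on a worker thread
    # while this generator consumes tokens from the streamer's queue.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message  # yield the whole reply-so-far, as predict does

for update in stream_reply("The capital of France is"):
    print(update)

Yielding the accumulated string rather than the per-token delta is what lets Gradio simply re-render the last chat message on every update.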
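
On the UI side, the hand-rolled Chatbot/Textbox/Clear wiring and the separate user callback collapse into a single gr.ChatInterface(predict) call: when the callable given to ChatInterface is a generator taking (message, history), Gradio streams each yielded string into the chat as the progressively updated assistant reply. (Two quirks of the committed code: the added demo = gr.Blocks() line is immediately rebound by with gr.Blocks() as demo:, and the trailing return messages in a generator only sets its StopIteration value; neither affects the streamed output.) A self-contained sketch of that contract, with a hypothetical echo generator standing in for the model-backed predict:

# Sketch: the gr.ChatInterface contract this commit adopts. The echo
# generator below is a stand-in for the model-backed predict in app.py.
import time

import gradio as gr

def predict(message, history):
    # history is a list of [user, assistant] pairs in the format this app uses;
    # it is ignored in this stand-in.
    partial_message = ""
    for token in f"Echo: {message}".split():
        partial_message += token + " "
        time.sleep(0.1)        # stand-in for waiting on the model's streamer
        yield partial_message  # Gradio re-renders the reply on every yield

with gr.Blocks() as demo:
    gr.ChatInterface(predict)

if __name__ == "__main__":
    demo.queue().launch()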