Pclanglais commited on
Commit
b053d03
1 Parent(s): d6b6a6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -21
app.py CHANGED
@@ -100,19 +100,7 @@ class StopOnTokens(StoppingCriteria):
100
  return False
101
 
102
 
103
- def predict(message, history):
104
-
105
- global source_text
106
- global assess_rag
107
- #For now, we only query the vector database once, at the start.
108
- if len(history) == 0:
109
- assess_rag = classification_chatrag(message)
110
- if assess_rag:
111
- source_text = vector_search(message)
112
- else:
113
- source_text = "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."
114
-
115
- history_transformer_format = history + [[message, ""]]
116
 
117
  print(history_transformer_format)
118
  stop = StopOnTokens()
@@ -141,6 +129,8 @@ def predict(message, history):
141
 
142
  messages = system_prompt + messages
143
 
 
 
144
  model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
145
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
146
  generate_kwargs = dict(
@@ -155,12 +145,27 @@ def predict(message, history):
155
  t = Thread(target=model.generate, kwargs=generate_kwargs)
156
  t.start()
157
 
158
- partial_message = ""
159
  for new_token in streamer:
160
  if new_token != '<':
161
- partial_message += new_token
162
- yield partial_message
163
- return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # Define the Gradio interface
166
  title = "Tchap"
@@ -171,9 +176,21 @@ examples = [
171
  0.7 # temperature
172
  ]
173
  ]
174
- demo = gr.Blocks()
175
  with gr.Blocks() as demo:
176
- gr.ChatInterface(predict)
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- if __name__ == "__main__":
179
- demo.queue().launch()
 
100
  return False
101
 
102
 
103
+ def predict(history_transformer_format):
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  print(history_transformer_format)
106
  stop = StopOnTokens()
 
129
 
130
  messages = system_prompt + messages
131
 
132
+ print(messages)
133
+
134
  model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
135
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
136
  generate_kwargs = dict(
 
145
  t = Thread(target=model.generate, kwargs=generate_kwargs)
146
  t.start()
147
 
148
+ history_transformer_format[-1][1] = ""
149
  for new_token in streamer:
150
  if new_token != '<':
151
+ history_transformer_format[-1][1] += new_token
152
+ yield history_transformer_format
153
+
154
+ def user(message, history):
155
+ global source_text
156
+ global assess_rag
157
+ #For now, we only query the vector database once, at the start.
158
+ if len(history) == 0:
159
+ assess_rag = classification_chatrag(message)
160
+ if assess_rag:
161
+ source_text = vector_search(message)
162
+ else:
163
+ source_text = "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."
164
+
165
+ history_transformer_format = history + [[message, ""]]
166
+
167
+ print(history_transformer_format)
168
+ return source_text, history_transformer_format
169
 
170
  # Define the Gradio interface
171
  title = "Tchap"
 
176
  0.7 # temperature
177
  ]
178
  ]
179
+
180
  with gr.Blocks() as demo:
181
+ chatbot = gr.Chatbot()
182
+ msg = gr.Textbox()
183
+ clear = gr.Button("Clear")
184
+
185
+ user_output = gr.HTML() # To display the user's message
186
+ history = gr.State()
187
+
188
+ msg.submit(user, inputs=[msg, history], outputs=[user_output, history], queue=False).then(
189
+ predict, chatbot, chatbot
190
+ )
191
+
192
+ clear.click(lambda: None, None, chatbot, queue=False)
193
+
194
 
195
+ demo.queue()
196
+ demo.launch()