Aston-xMAD committed
Commit
177a926
1 Parent(s): b37c16f

enabled streaming

Files changed (1)
  1. app.py +56 -1
app.py CHANGED
@@ -165,8 +165,63 @@ def chatbot_response(message, history):
     return response + metrics
 
 
+def process_dialog_streaming(message, history):
+    terminators = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+    ]
+
+    dialog = [
+        {"role": "user" if i % 2 == 0 else "assistant", "content": msg}
+        for i, (msg, _) in enumerate(history)
+    ]
+    dialog.append({"role": "user", "content": message})
+
+    prompt = tokenizer.apply_chat_template(
+        dialog, tokenize=False, add_generation_prompt=True
+    )
+    tokenized_input_prompt_ids = tokenizer(
+        prompt, return_tensors="pt"
+    ).input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = dict(
+        inputs=tokenized_input_prompt_ids,
+        streamer=streamer,
+        max_new_tokens=512,
+        temperature=0.4,
+        do_sample=True,
+        eos_token_id=terminators,
+        pad_token_id=tokenizer.pad_token_id,
+    )
+
+    start_time = time.time()
+    total_tokens = 0
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    generated_text = ""
+    for new_text in streamer:
+        generated_text += new_text
+        total_tokens += 1
+        current_time = time.time()
+        elapsed_time = current_time - start_time
+        tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0
+        print(f"Tokens per second: {tokens_per_second:.2f}", end="\r")
+        yield generated_text, elapsed_time, tokens_per_second
+
+    thread.join()
+
+def chatbot_response_streaming(message, history):
+    for response, generation_time, tokens_per_second in process_dialog_streaming(message, history):
+        metrics = f"\n\n---\n\n **Metrics**\t*Answer Generation Time:* `{generation_time:.2f} sec`\t*Tokens per Second:* `{tokens_per_second:.2f}`\n\n"
+        yield response + metrics
+
+
 demo = gr.ChatInterface(
-    fn=chatbot_response,
+    fn=chatbot_response_streaming,
     examples=["Hello", "How are you?", "Tell me a joke"],
     title="Chat with xMAD's: 1-bit-Llama-3-8B-Instruct Model",
     description="Contact support@xmad.ai to set up a demo",
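Note: this hunk begins at line 165 of app.py, so the imports and model setup the new code relies on are out of view. Judging only from the names it uses (time, Thread, TextIteratorStreamer, gr, tokenizer, model), the top of the file presumably contains something like the sketch below; treat it as an assumption, not part of this commit:

    # Imports the streaming code depends on; assumed to exist near the top
    # of app.py (not shown in this hunk).
    import time                   # wall-clock timing for the tokens/sec metric
    from threading import Thread  # runs model.generate in the background

    import gradio as gr
    from transformers import TextIteratorStreamer

    # `tokenizer`, `model`, and the original `chatbot_response` are likewise
    # assumed to be defined earlier in app.py.

Running model.generate on a Thread is the standard pattern with TextIteratorStreamer: generation blocks the worker thread while the main thread iterates the streamer and yields partial text to Gradio.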
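One thing a reviewer might flag: the dialog comprehension unpacks each Gradio history pair as (msg, _), so it keeps only the user half of every (user, assistant) tuple while still alternating roles. Assistant turns never reach the chat template, and any user turn at an odd index is mislabeled "assistant". A hypothetical fix, not part of this commit, that flattens both sides of each pair:

    # Hypothetical replacement for the dialog comprehension (not in this
    # commit): keep both halves of each Gradio (user, assistant) pair.
    dialog = []
    for user_msg, assistant_msg in history:
        dialog.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:  # last pair may have no reply yet
            dialog.append({"role": "assistant", "content": assistant_msg})
    dialog.append({"role": "user", "content": message})

Separately, total_tokens increments once per chunk yielded by TextIteratorStreamer; chunks are roughly but not exactly one token each (decoded text can be buffered across tokens), so the reported tokens-per-second figure is an approximation.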