Spaces:
Runtime error
Upload app.py
Browse files# Implement Real-Time Streaming for Chat Responses
## Description
This PR introduces real-time streaming functionality to our chat interface., aiming to enhance the user experience by providing immediate, token-by-token responses.
## Changes
- Enabled streaming in the HuggingFaceEndpoint configuration
- Implemented an asynchronous streaming process using `astream()`
- Modified the chat function to yield partial results in real-time
- Updated Gradio setup to support streaming responses (set queue as False)
## Expected Behavior
- Responses should start appearing immediately after a question is asked
- Text should stream in smoothly, word by word or token by token
- The final response should be identical to the non-streaming version
## Technical Details
Key components of the implementation:
1. **Streaming Callback**: Implemented `StreamingStdOutCallbackHandler` for real-time token processing.
2. **LLM Configuration**: Added `streaming=True` to `HuggingFaceEndpoint` setup.
3. **Asynchronous Streaming**: Created `process_stream()` function to handle token-by-token response generation.
4. **Real-Time Updates**: Modified main loop to yield updates as they become available.
@@ -217,29 +217,37 @@ async def chat(query,history,sources,reports,subtype,year):
|
|
217 |
|
218 |
##-----------------------getting inference endpoints------------------------------
|
219 |
|
|
|
220 |
callback = StreamingStdOutCallbackHandler()
|
221 |
|
|
|
222 |
llm_qa = HuggingFaceEndpoint(
|
223 |
endpoint_url=model_config.get('reader', 'ENDPOINT'),
|
224 |
max_new_tokens=512,
|
225 |
repetition_penalty=1.03,
|
226 |
timeout=70,
|
227 |
huggingfacehub_api_token=HF_token,
|
228 |
-
streaming=True,
|
229 |
-
callbacks=[callback]
|
230 |
)
|
231 |
|
|
|
232 |
chat_model = ChatHuggingFace(llm=llm_qa)
|
233 |
|
|
|
234 |
docs_html = []
|
235 |
for i, d in enumerate(context_retrieved, 1):
|
236 |
docs_html.append(make_html_source(d, i))
|
237 |
docs_html = "".join(docs_html)
|
238 |
|
|
|
239 |
answer_yet = ""
|
240 |
|
|
|
241 |
async def process_stream():
|
242 |
-
nonlocal answer_yet
|
|
|
|
|
243 |
async for chunk in chat_model.astream(messages):
|
244 |
token = chunk.content
|
245 |
answer_yet += token
|
@@ -247,9 +255,10 @@ async def chat(query,history,sources,reports,subtype,year):
|
|
247 |
history[-1] = (query, parsed_answer)
|
248 |
yield [tuple(x) for x in history], docs_html
|
249 |
|
|
|
250 |
async for update in process_stream():
|
251 |
yield update
|
252 |
-
|
253 |
# #callbacks = [StreamingStdOutCallbackHandler()]
|
254 |
# llm_qa = HuggingFaceEndpoint(
|
255 |
# endpoint_url= model_config.get('reader','ENDPOINT'),
|
@@ -508,11 +517,13 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
|
|
508 |
# https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
|
509 |
(textbox
|
510 |
.submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
|
|
|
511 |
.then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
|
512 |
.then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
|
513 |
|
514 |
(examples_hidden
|
515 |
.change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
|
|
|
516 |
.then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
|
517 |
.then(finish_chat, None, [textbox], api_name="finish_chat_examples")
|
518 |
)
|
|
|
217 |
|
218 |
##-----------------------getting inference endpoints------------------------------
|
219 |
|
220 |
+
# Set up the streaming callback handler
|
221 |
callback = StreamingStdOutCallbackHandler()
|
222 |
|
223 |
+
# Initialize the HuggingFaceEndpoint with streaming enabled
|
224 |
llm_qa = HuggingFaceEndpoint(
|
225 |
endpoint_url=model_config.get('reader', 'ENDPOINT'),
|
226 |
max_new_tokens=512,
|
227 |
repetition_penalty=1.03,
|
228 |
timeout=70,
|
229 |
huggingfacehub_api_token=HF_token,
|
230 |
+
streaming=True, # Enable streaming for real-time token generation
|
231 |
+
callbacks=[callback] # Add the streaming callback handler
|
232 |
)
|
233 |
|
234 |
+
# Create a ChatHuggingFace instance with the streaming-enabled endpoint
|
235 |
chat_model = ChatHuggingFace(llm=llm_qa)
|
236 |
|
237 |
+
# Prepare the HTML for displaying source documents
|
238 |
docs_html = []
|
239 |
for i, d in enumerate(context_retrieved, 1):
|
240 |
docs_html.append(make_html_source(d, i))
|
241 |
docs_html = "".join(docs_html)
|
242 |
|
243 |
+
# Initialize the variable to store the accumulated answer
|
244 |
answer_yet = ""
|
245 |
|
246 |
+
# Define an asynchronous generator function to process the streaming response
|
247 |
async def process_stream():
|
248 |
+
# Without nonlocal, Python would create a new local variable answer_yet inside process_stream(), instead of modifying the one from the outer scope.
|
249 |
+
nonlocal answer_yet # Use the outer scope's answer_yet variable
|
250 |
+
# Iterate over the streaming response chunks
|
251 |
async for chunk in chat_model.astream(messages):
|
252 |
token = chunk.content
|
253 |
answer_yet += token
|
|
|
255 |
history[-1] = (query, parsed_answer)
|
256 |
yield [tuple(x) for x in history], docs_html
|
257 |
|
258 |
+
# Stream the response updates
|
259 |
async for update in process_stream():
|
260 |
yield update
|
261 |
+
|
262 |
# #callbacks = [StreamingStdOutCallbackHandler()]
|
263 |
# llm_qa = HuggingFaceEndpoint(
|
264 |
# endpoint_url= model_config.get('reader','ENDPOINT'),
|
|
|
517 |
# https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
|
518 |
(textbox
|
519 |
.submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
|
520 |
+
# queue must be set as False (default) so the process is not waiting for another to be finished
|
521 |
.then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
|
522 |
.then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
|
523 |
|
524 |
(examples_hidden
|
525 |
.change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
|
526 |
+
# queue must be set as False (default) so the process is not waiting for another to be finished
|
527 |
.then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
|
528 |
.then(finish_chat, None, [textbox], api_name="finish_chat_examples")
|
529 |
)
|