mtyrrell committed
Commit ec4377c · 1 Parent(s): f346328

updated for streaming

Files changed (2)
  1. app.py +13 -11
  2. utils/generator.py +77 -10
app.py CHANGED
@@ -1,13 +1,11 @@
 import gradio as gr
-from utils.generator import generate
+from .generator import generate, generate_streaming

 # ---------------------------------------------------------------------
-# Gradio Interface with MCP support
+# Gradio Interface with MCP support and streaming
 # ---------------------------------------------------------------------
-
-
 ui = gr.Interface(
-    fn=generate,
+    fn=generate_streaming,  # Use streaming function
     inputs=[
         gr.Textbox(
             label="Query",
@@ -22,10 +20,15 @@ ui = gr.Interface(
             info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should except anything"
         ),
     ],
-    outputs=[gr.Text(label="Generated Answer", lines=6, show_copy_button=True)],
-    title="ChatFed Generation Module",
-    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by semantic retriever service).",
-    api_name="generate"
+    outputs=gr.Textbox(
+        label="Generated Answer",
+        lines=6,
+        show_copy_button=True,
+        streaming=True  # Enable streaming in the output
+    ),
+    title="ChatFed Generation Module",
+    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by semantic retriever service).",
+    api_name="generate"
 )

 # Launch with MCP server enabled
@@ -33,7 +36,6 @@ if __name__ == "__main__":
     ui.launch(
         server_name="0.0.0.0",
         server_port=7860,
-        #mcp_server=True,
+        mcp_server=True,
         show_error=True
     )
-
 
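For reference, the behaviour app.py now relies on is Gradio's generator support: when fn is a generator function (async generators are supported in recent Gradio versions), gr.Interface re-renders the output component with every yielded value, replacing what is currently shown rather than appending to it. A minimal, self-contained sketch of that pattern follows; the fake_stream function and its token list are illustrative stand-ins for generate_streaming, not part of this repo.

import asyncio
import gradio as gr

async def fake_stream(query: str):
    # Stand-in for generate_streaming. Each yield replaces the displayed
    # output, so we yield the accumulated answer rather than single chunks.
    answer = ""
    for token in ["Streaming ", "answers ", "token ", "by ", "token."]:
        await asyncio.sleep(0.1)  # simulate model latency
        answer += token
        yield answer

demo = gr.Interface(
    fn=fake_stream,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.Textbox(label="Generated Answer"),
)

if __name__ == "__main__":
    demo.launch()

Because each yield overwrites the Textbox, yielding cumulative text is what produces the familiar "growing answer" effect in the UI.
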
utils/generator.py CHANGED
@@ -2,7 +2,7 @@ import logging
 import asyncio
 import json
 import ast
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any, Union, Generator, AsyncGenerator
 from dotenv import load_dotenv

 # LangChain imports
@@ -24,8 +24,6 @@ PROVIDER = config.get("generator", "PROVIDER")
 MODEL = config.get("generator", "MODEL")
 MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
 TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
-INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
-ORGANIZATION = config.get("generator", "ORGANIZATION")

 # Set up authentication for the selected provider
 auth_config = get_auth(PROVIDER)
@@ -41,18 +39,21 @@ def get_chat_model():
         return ChatOpenAI(
             model=MODEL,
             openai_api_key=auth_config["api_key"],
+            streaming=True,  # Enable streaming
             **common_params
         )
     elif PROVIDER == "anthropic":
         return ChatAnthropic(
             model=MODEL,
             anthropic_api_key=auth_config["api_key"],
+            streaming=True,  # Enable streaming
             **common_params
         )
     elif PROVIDER == "cohere":
         return ChatCohere(
             model=MODEL,
             cohere_api_key=auth_config["api_key"],
+            streaming=True,  # Enable streaming
             **common_params
         )
     elif PROVIDER == "huggingface":
@@ -61,10 +62,9 @@ def get_chat_model():
             repo_id=MODEL,
             huggingfacehub_api_token=auth_config["api_key"],
             task="text-generation",
-            provider=INFERENCE_PROVIDER,
-            server_kwargs={"bill_to": ORGANIZATION},
             temperature=TEMPERATURE,
-            max_new_tokens=MAX_TOKENS
+            max_new_tokens=MAX_TOKENS,
+            streaming=True  # Enable streaming
         )
         return ChatHuggingFace(llm=llm)
     else:
@@ -143,7 +143,7 @@ def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
 # ---------------------------------------------------------------------
 async def _call_llm(messages: list) -> str:
     """
-    Provider-agnostic LLM call using LangChain.
+    Provider-agnostic LLM call using LangChain (non-streaming).

     Args:
         messages: List of LangChain message objects
@@ -159,6 +159,25 @@ async def _call_llm(messages: list) -> str:
         logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
         raise

+async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
+    """
+    Provider-agnostic streaming LLM call using LangChain.
+
+    Args:
+        messages: List of LangChain message objects
+
+    Yields:
+        Generated response chunks as strings
+    """
+    try:
+        # Use async stream for streaming responses
+        async for chunk in chat_model.astream(messages):
+            if hasattr(chunk, 'content') and chunk.content:
+                yield chunk.content
+    except Exception as e:
+        logging.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+        yield f"Error: {str(e)}"
+
 def build_messages(question: str, context: str) -> list:
     """
     Build messages in LangChain format.
@@ -222,9 +241,57 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str
     try:
         messages = build_messages(query, formatted_context)
         answer = await _call_llm(messages)
-
         return answer
-
     except Exception as e:
         logging.exception("Generation failed")
-        return f"Error: {str(e)}"
+        return f"Error: {str(e)}"
+
+async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]]) -> AsyncGenerator[str, None]:
+    """
+    Generate a streaming answer to a query using provided context through RAG.
+
+    This function takes a user query and relevant context, then uses a language model
+    to generate a streaming answer based on the provided information.
+
+    Args:
+        query (str): User query
+        context (Union[str, List[Dict[str, Any]]]): Context as string or list of retrieval results
+
+    Yields:
+        str: Streaming chunks of the generated answer
+    """
+    if not query.strip():
+        yield "Error: Query cannot be empty"
+        return
+
+    # Handle both string context (for Gradio UI) and list context (from retriever)
+    if isinstance(context, list):
+        if not context:
+            yield "Error: No retrieval results provided"
+            return
+
+        # Process the retrieval results
+        processed_results = extract_relevant_fields(context)
+        formatted_context = format_context_from_results(processed_results)
+
+        if not formatted_context.strip():
+            yield "Error: No valid content found in retrieval results"
+            return
+
+    elif isinstance(context, str):
+        if not context.strip():
+            yield "Error: Context cannot be empty"
+            return
+        formatted_context = context
+
+    else:
+        yield "Error: Context must be either a string or list of retrieval results"
+        return
+
+    try:
+        messages = build_messages(query, formatted_context)
+        async for chunk in _call_llm_streaming(messages):
+            yield chunk
+    except Exception as e:
+        logging.exception("Streaming generation failed")
+        yield f"Error: {str(e)}"
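
Outside the Gradio UI, the new generate_streaming coroutine can be consumed directly as an async generator. A minimal sketch, assuming the project root is on sys.path so that utils.generator is importable and that a provider and model are configured for the generator; the query and context strings are placeholders, not data from this repo.

import asyncio

from utils.generator import generate_streaming  # assumes the repo root is on sys.path

async def main():
    context = "The 2023 annual report commits to a 40% emissions reduction by 2030."
    query = "What emissions target does the report set?"
    async for chunk in generate_streaming(query, context):
        # Chunks arrive as the underlying chat model streams tokens.
        print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    asyncio.run(main())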