mtyrrell committed
Commit f256208 · Parent(s): d049b68
Files changed (1):
  1. utils/generator.py +129 -342

utils/generator.py CHANGED
@@ -3,7 +3,7 @@ import asyncio
  import json
  import ast
  import re
- from typing import List, Dict, Any, Union, Generator, AsyncGenerator
  from dotenv import load_dotenv

  # LangChain imports
@@ -17,188 +17,150 @@ from langchain_core.messages import SystemMessage, HumanMessage
  from .utils import getconfig, get_auth

  # ---------------------------------------------------------------------
- # Model / client initialization (non exaustive list of providers)
  # ---------------------------------------------------------------------
  config = getconfig("params.cfg")
-
  PROVIDER = config.get("generator", "PROVIDER")
  MODEL = config.get("generator", "MODEL")
  MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
  TEMPERATURE = float(config.get("generator", "TEMPERATURE"))

- # Set up authentication for the selected provider
  auth_config = get_auth(PROVIDER)

- def get_chat_model():
      """Initialize the appropriate LangChain chat model based on provider"""
-     common_params = {
-         "temperature": TEMPERATURE,
-         "max_tokens": MAX_TOKENS,
      }

-     if PROVIDER == "openai":
-         return ChatOpenAI(
-             model=MODEL,
-             openai_api_key=auth_config["api_key"],
-             streaming=True, # Enable streaming
-             **common_params
-         )
-     elif PROVIDER == "anthropic":
-         return ChatAnthropic(
-             model=MODEL,
-             anthropic_api_key=auth_config["api_key"],
-             streaming=True, # Enable streaming
-             **common_params
-         )
-     elif PROVIDER == "cohere":
-         return ChatCohere(
-             model=MODEL,
-             cohere_api_key=auth_config["api_key"],
-             streaming=True, # Enable streaming
-             **common_params
-         )
-     elif PROVIDER == "huggingface":
-         # Initialize HuggingFaceEndpoint with explicit parameters
-         llm = HuggingFaceEndpoint(
-             repo_id=MODEL,
-             huggingfacehub_api_token=auth_config["api_key"],
-             task="text-generation",
-             temperature=TEMPERATURE,
-             max_new_tokens=MAX_TOKENS,
-             streaming=True # Enable streaming
-         )
-         return ChatHuggingFace(llm=llm)
-     else:
          raise ValueError(f"Unsupported provider: {PROVIDER}")
-
- # Initialize provider-agnostic chat model
- chat_model = get_chat_model()

  # ---------------------------------------------------------------------
- # Citation parsing and source filtering
  # ---------------------------------------------------------------------
- def parse_citations_from_response(response: str) -> List[int]:
-     """
-     Parse citation numbers from the generated response.
-
-     Args:
-         response: The generated response text
-
-     Returns:
-         List of unique citation numbers found in the response
-     """
-     # Find all citation patterns like [1], [2], [1][2], etc.
      citation_pattern = r'\[(\d+)\]'
      matches = re.findall(citation_pattern, response)
-
-     # Convert to integers and return unique values
-     citation_numbers = [int(match) for match in matches]
-     return sorted(list(set(citation_numbers)))

- def filter_sources_by_citations(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
-     """
-     Filter sources to only include those that were cited in the response.
-
-     Args:
-         processed_results: All processed retrieval results
-         cited_numbers: List of citation numbers found in the response
-
-     Returns:
-         List of sources that were actually cited
-     """
      if not cited_numbers:
          return []

-     # Filter sources based on citation numbers (1-indexed)
      cited_sources = []
      for citation_num in cited_numbers:
-         # Convert to 0-indexed for list access
          source_index = citation_num - 1
          if 0 <= source_index < len(processed_results):
              cited_sources.append(processed_results[source_index])

      return cited_sources

- # ---------------------------------------------------------------------
- # Context processing - may need further refinement (i.e. to manage other data sources)
- # ---------------------------------------------------------------------
- def extract_relevant_fields(retrieval_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-     """
-     Extract only relevant fields from retrieval results.
-
-     Args:
-         retrieval_results: List of JSON objects from retriever
-
-     Returns:
-         List of processed objects with only relevant fields
-     """
-     if isinstance(retrieval_results, str):
-         retrieval_results = ast.literal_eval(retrieval_results)
-
      processed_results = []

-     for result in retrieval_results:
-         # Extract the answer content
-         answer = result.get('answer', '')

-         # Extract document identification from metadata
-         metadata = result.get('answer_metadata', {})
-         doc_info = {
-             'answer': answer,
-             'filename': metadata.get('filename', 'Unknown'),
-             'page': metadata.get('page', 'Unknown'),
-             'year': metadata.get('year', 'Unknown'),
-             'source': metadata.get('source', 'Unknown'),
-             'document_id': metadata.get('_id', 'Unknown')
-         }

-         processed_results.append(doc_info)

-     return processed_results

- def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
-     """
-     Format processed retrieval results into a context string for the LLM.
-
-     Args:
-         processed_results: List of processed objects with relevant fields
-
-     Returns:
-         Formatted context string
-     """
-     if not processed_results:
-         return ""
-
-     context_parts = []
-
-     for i, result in enumerate(processed_results, 1):
-         doc_reference = f"[Document {i}: {result['filename']}"
-         if result['page'] != 'Unknown':
-             doc_reference += f", Page {result['page']}"
-         if result['year'] != 'Unknown':
-             doc_reference += f", Year {result['year']}"
-         doc_reference += "]"

-         context_part = f"{doc_reference}\n{result['answer']}\n"
-         context_parts.append(context_part)

-     return "\n".join(context_parts)

  # ---------------------------------------------------------------------
- # Core generation function for both Gradio UI and MCP
  # ---------------------------------------------------------------------
  async def _call_llm(messages: list) -> str:
-     """
-     Provider-agnostic LLM call using LangChain (non-streaming).
-
-     Args:
-         messages: List of LangChain message objects
-
-     Returns:
-         Generated response content as string
-     """
      try:
-         # Use async invoke for better performance
          response = await chat_model.ainvoke(messages)
          return response.content.strip()
      except Exception as e:
@@ -206,17 +168,8 @@ async def _call_llm(messages: list) -> str:
          raise

  async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
-     """
-     Provider-agnostic streaming LLM call using LangChain.
-
-     Args:
-         messages: List of LangChain message objects
-
-     Yields:
-         Generated response chunks as strings
-     """
      try:
-         # Use async stream for streaming responses
          async for chunk in chat_model.astream(messages):
              if hasattr(chunk, 'content') and chunk.content:
                  yield chunk.content
@@ -224,191 +177,50 @@ async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
          logging.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
          yield f"Error: {str(e)}"

- def build_messages(question: str, context: str) -> list:
-     """
-     Build messages in LangChain format.
-
-     Args:
-         question: The user's question
-         context: The relevant context for answering
-
-     Returns:
-         List of LangChain message objects
-     """
-     system_content = """
-     You are AuditQ&A, an AI Assistant created by Auditors and Data Scientist. \
-     You are given a question and extracted passages of the consolidated/departmental/thematic focus audit reports.\
-     Provide a clear and structured answer based on the passages/context provided and the guidelines.
-     Guidelines:
-     - If the passages have useful facts or numbers, use them in your answer.
-     - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
-     - If it makes sense, use bullet points and lists to make your answers easier to understand.
-     - You do not need to use every passage. Only use the ones that help answer the question.
-     - Answer the USER question using only the CONTEXT provided.
-     - When referencing information from the context, use inline citations in square brackets like [1], [2], etc. to reference the document numbers shown in the context.
-     - Use multiple citations when information comes from multiple documents, like [1][2].
-     - Do not use the sentence 'Doc x says ...' to say where information came from, but rather just include the citation at the end of the sentence.
-     - If the context is insufficient, say "I don't have sufficient information to answer the question. Please try rephrasing your query."
-     """
-
-     user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
-
-     return [
-         SystemMessage(content=system_content),
-         HumanMessage(content=user_content)
-     ]
-
  async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
-     """
-     Generate an answer to a query using provided context through RAG.
-
-     This function takes a user query and relevant context, then uses a language model
-     to generate a comprehensive answer based on the provided information.
-
-     Args:
-         query (str): User query
-         context (Union[str, List[Dict[str, Any]]]): Context as string or list of retrieval results
-         chatui_format (bool): If True, return ChatUI format with sources
-
-     Returns:
-         Union[str, Dict]: The generated answer or ChatUI format response
-     """
      if not query.strip():
-         return {"error": "Query cannot be empty"} if chatui_format else "Error: Query cannot be empty"
-
-     processed_results = []
-
-     # Handle both string context (for Gradio UI) and list context (from retriever)
-     if isinstance(context, list):
-         if not context:
-             return {"error": "No retrieval results provided"} if chatui_format else "Error: No retrieval results provided"
-
-         # Process the retrieval results
-         processed_results = extract_relevant_fields(context)
-         formatted_context = format_context_from_results(processed_results)
-
-         if not formatted_context.strip():
-             return {"error": "No valid content found in retrieval results"} if chatui_format else "Error: No valid content found in retrieval results"
-
-     elif isinstance(context, str):
-         if not context.strip():
-             return {"error": "Context cannot be empty"} if chatui_format else "Error: Context cannot be empty"
-         formatted_context = context
-
-     else:
-         return {"error": "Context must be either a string or list of retrieval results"} if chatui_format else "Error: Context must be either a string or list of retrieval results"

      try:
-         messages = build_messages(query, formatted_context)
          answer = await _call_llm(messages)

          if chatui_format:
-             # Return ChatUI format
              result = {"answer": answer}
              if processed_results:
-                 # Parse citations from the response
-                 cited_numbers = parse_citations_from_response(answer)
-
-                 # Filter sources to only include cited ones
-                 cited_sources = filter_sources_by_citations(processed_results, cited_numbers)
-
-                 # Extract sources for ChatUI
-                 sources = []
-                 for result_item in cited_sources: # Only cited sources
-                     filename = result_item.get('filename', 'Unknown')
-                     page = result_item.get('page', 'Unknown')
-                     year = result_item.get('year', 'Unknown')
-
-                     # Create link using doc:// scheme
-                     link = f"doc://{filename}"
-
-                     # Create descriptive title
-                     title_parts = [filename]
-                     if page != 'Unknown':
-                         title_parts.append(f"Page {page}")
-                     if year != 'Unknown':
-                         title_parts.append(f"({year})")
-
-                     title = " - ".join(title_parts)
-
-                     sources.append({
-                         "link": link,
-                         "title": title
-                     })
-
-                 result["sources"] = sources
              return result
          else:
              return answer

      except Exception as e:
          logging.exception("Generation failed")
-         return {"error": str(e)} if chatui_format else f"Error: {str(e)}"

  async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
-     """
-     Generate a streaming answer to a query using provided context through RAG.
-
-     This function takes a user query and relevant context, then uses a language model
-     to generate a streaming answer based on the provided information.
-
-     Args:
-         query (str): User query
-         context (Union[str, List[Dict[str, Any]]]): Context as string or list of retrieval results
-         chatui_format (bool): If True, yield ChatUI format events
-
-     Yields:
-         Union[str, Dict]: Streaming chunks or ChatUI format events
-     """
      if not query.strip():
          if chatui_format:
-             yield {"event": "error", "data": {"error": "Query cannot be empty"}}
          else:
-             yield "Error: Query cannot be empty"
-         return
-
-     processed_results = []
-
-     # Handle both string context (for Gradio UI) and list context (from retriever)
-     if isinstance(context, list):
-         if not context:
-             if chatui_format:
-                 yield {"event": "error", "data": {"error": "No retrieval results provided"}}
-             else:
-                 yield "Error: No retrieval results provided"
-             return
-
-         # Process the retrieval results
-         processed_results = extract_relevant_fields(context)
-         formatted_context = format_context_from_results(processed_results)
-
-         if not formatted_context.strip():
-             if chatui_format:
-                 yield {"event": "error", "data": {"error": "No valid content found in retrieval results"}}
-             else:
-                 yield "Error: No valid content found in retrieval results"
-             return
-
-     elif isinstance(context, str):
-         if not context.strip():
-             if chatui_format:
-                 yield {"event": "error", "data": {"error": "Context cannot be empty"}}
-             else:
-                 yield "Error: Context cannot be empty"
-             return
-         formatted_context = context
-
-     else:
-         if chatui_format:
-             yield {"event": "error", "data": {"error": "Context must be either a string or list of retrieval results"}}
-         else:
-             yield "Error: Context must be either a string or list of retrieval results"
          return

      try:
-         messages = build_messages(query, formatted_context)

-         # Stream the text response and accumulate it for citation parsing
          accumulated_response = ""
          async for chunk in _call_llm_streaming(messages):
              accumulated_response += chunk
@@ -419,44 +231,19 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]

          # Send sources at the end if available and in ChatUI format
          if chatui_format and processed_results:
-             # Parse citations from the complete response
-             cited_numbers = parse_citations_from_response(accumulated_response)
-
-             # Filter sources to only include cited ones
-             cited_sources = filter_sources_by_citations(processed_results, cited_numbers)
-
-             sources = []
-             for result in cited_sources: # Only cited sources
-                 filename = result.get('filename', 'Unknown')
-                 page = result.get('page', 'Unknown')
-                 year = result.get('year', 'Unknown')
-
-                 # Create link using doc:// scheme
-                 link = f"doc://{filename}"
-
-                 # Create descriptive title
-                 title_parts = [filename]
-                 if page != 'Unknown':
-                     title_parts.append(f"Page {page}")
-                 if year != 'Unknown':
-                     title_parts.append(f"({year})")
-
-                 title = " - ".join(title_parts)
-
-                 sources.append({
-                     "link": link,
-                     "title": title
-                 })
-
              yield {"event": "sources", "data": {"sources": sources}}

-         # Send end event for ChatUI format
          if chatui_format:
              yield {"event": "end", "data": {}}

      except Exception as e:
          logging.exception("Streaming generation failed")
          if chatui_format:
-             yield {"event": "error", "data": {"error": str(e)}}
          else:
-             yield f"Error: {str(e)}"

utils/generator.py (updated file)
  import json
  import ast
  import re
+ from typing import List, Dict, Any, Union, AsyncGenerator
  from dotenv import load_dotenv

  # LangChain imports

  from .utils import getconfig, get_auth

  # ---------------------------------------------------------------------
+ # Configuration and Model Initialization
  # ---------------------------------------------------------------------
  config = getconfig("params.cfg")
  PROVIDER = config.get("generator", "PROVIDER")
  MODEL = config.get("generator", "MODEL")
  MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
  TEMPERATURE = float(config.get("generator", "TEMPERATURE"))

+ # Initialize chat model
  auth_config = get_auth(PROVIDER)
+ chat_model = _get_chat_model()

+ def _get_chat_model():
      """Initialize the appropriate LangChain chat model based on provider"""
+     common_params = {"temperature": TEMPERATURE, "max_tokens": MAX_TOKENS}
+
+     providers = {
+         "openai": lambda: ChatOpenAI(model=MODEL, openai_api_key=auth_config["api_key"], streaming=True, **common_params),
+         "anthropic": lambda: ChatAnthropic(model=MODEL, anthropic_api_key=auth_config["api_key"], streaming=True, **common_params),
+         "cohere": lambda: ChatCohere(model=MODEL, cohere_api_key=auth_config["api_key"], streaming=True, **common_params),
+         "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
+             repo_id=MODEL, huggingfacehub_api_token=auth_config["api_key"],
+             task="text-generation", temperature=TEMPERATURE, max_new_tokens=MAX_TOKENS, streaming=True
+         ))
      }

+     if PROVIDER not in providers:
          raise ValueError(f"Unsupported provider: {PROVIDER}")
+
+     return providers[PROVIDER]()
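
The new _get_chat_model replaces the old if/elif chain with a dict of zero-argument lambdas, so no SDK client is constructed until the configured provider's factory is actually called. A standalone sketch of that dispatch pattern, with stand-in factories and an illustrative provider value rather than anything read from params.cfg:

PROVIDER = "openai"  # illustrative; the real module reads this from params.cfg

def make_openai_client():
    return "openai client"      # stand-in for ChatOpenAI(...)

def make_anthropic_client():
    return "anthropic client"   # stand-in for ChatAnthropic(...)

providers = {
    "openai": make_openai_client,        # factories are stored, not called
    "anthropic": make_anthropic_client,
}

if PROVIDER not in providers:
    raise ValueError(f"Unsupported provider: {PROVIDER}")

chat_model = providers[PROVIDER]()       # only the selected factory runs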
 
  # ---------------------------------------------------------------------
+ # Core Processing Functions
  # ---------------------------------------------------------------------
+ def _parse_citations(response: str) -> List[int]:
+     """Parse citation numbers from response text"""
      citation_pattern = r'\[(\d+)\]'
      matches = re.findall(citation_pattern, response)
+     return sorted(list(set(int(match) for match in matches)))

+ def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
+     """Extract sources that were cited in the response"""
      if not cited_numbers:
          return []

      cited_sources = []
      for citation_num in cited_numbers:
          source_index = citation_num - 1
          if 0 <= source_index < len(processed_results):
              cited_sources.append(processed_results[source_index])

      return cited_sources
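
To see the two helpers work together, here is a small self-contained example that mirrors the same regex and 1-indexed lookup; the response text and source list are invented for illustration:

import re

response = "Arrears increased in FY2021 [1][3]. Controls remained weak [2]."
cited = sorted(set(int(m) for m in re.findall(r'\[(\d+)\]', response)))   # [1, 2, 3]

sources = [{"filename": "report_a.pdf"}, {"filename": "report_b.pdf"}, {"filename": "report_c.pdf"}]
cited_sources = [sources[n - 1] for n in cited if 0 <= n - 1 < len(sources)]
# cited_sources keeps only the documents the answer actually referenced, in citation order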
+ def _process_context(context: Union[str, List[Dict[str, Any]]]) -> tuple[str, List[Dict[str, Any]]]:
+     """Process context and return formatted context string and processed results"""
      processed_results = []

+     if isinstance(context, list):
+         if not context:
+             raise ValueError("No retrieval results provided")
+
+         # Extract relevant fields from retrieval results
+         for result in context:
+             if isinstance(result, str):
+                 result = ast.literal_eval(result)
+
+             metadata = result.get('answer_metadata', {})
+             doc_info = {
+                 'answer': result.get('answer', ''),
+                 'filename': metadata.get('filename', 'Unknown'),
+                 'page': metadata.get('page', 'Unknown'),
+                 'year': metadata.get('year', 'Unknown'),
+                 'source': metadata.get('source', 'Unknown'),
+                 'document_id': metadata.get('_id', 'Unknown')
+             }
+             processed_results.append(doc_info)
+
+         # Format context string
+         context_parts = []
+         for i, result in enumerate(processed_results, 1):
+             doc_ref = f"[Document {i}: {result['filename']}"
+             if result['page'] != 'Unknown':
+                 doc_ref += f", Page {result['page']}"
+             if result['year'] != 'Unknown':
+                 doc_ref += f", Year {result['year']}"
+             doc_ref += "]"
+             context_parts.append(f"{doc_ref}\n{result['answer']}\n")
+
+         formatted_context = "\n".join(context_parts)

+     elif isinstance(context, str):
+         if not context.strip():
+             raise ValueError("Context cannot be empty")
+         formatted_context = context
+     else:
+         raise ValueError("Context must be either a string or list of retrieval results")
+
+     return formatted_context, processed_results
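
For orientation, a retrieval item consumed by _process_context is assumed to look roughly like the dict below; the field names come from the diff, while the values are invented:

# Hypothetical retriever item (illustrative values only)
item = {
    "answer": "The audit noted unreconciled balances of UGX 1.2bn.",
    "answer_metadata": {
        "filename": "consolidated_report_2022.pdf",
        "page": 14,
        "year": 2022,
        "source": "OAG",
        "_id": "abc123",
    },
}
# For a single such item the formatted context would read:
# [Document 1: consolidated_report_2022.pdf, Page 14, Year 2022]
# The audit noted unreconciled balances of UGX 1.2bn.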
+
+ def _build_messages(question: str, context: str) -> list:
+     """Build messages in LangChain format"""
+     system_content = """You are AuditQ&A, an AI Assistant created by Auditors and Data Scientist. \
+     You are given a question and extracted passages of the consolidated/departmental/thematic focus audit reports.\
+     Provide a clear and structured answer based on the passages/context provided and the guidelines.
+     Guidelines:
+     - If the passages have useful facts or numbers, use them in your answer.
+     - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
+     - If it makes sense, use bullet points and lists to make your answers easier to understand.
+     - You do not need to use every passage. Only use the ones that help answer the question.
+     - Answer the USER question using only the CONTEXT provided.
+     - When referencing information from the context, use inline citations in square brackets like [1], [2], etc. to reference the document numbers shown in the context.
+     - Use multiple citations when information comes from multiple documents, like [1][2].
+     - Do not use the sentence 'Doc x says ...' to say where information came from, but rather just include the citation at the end of the sentence.
+     - If the context is insufficient, say "I don't have sufficient information to answer the question. Please try rephrasing your query."
+     """

+     user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
+     return [SystemMessage(content=system_content), HumanMessage(content=user_content)]

+ def _create_sources_list(cited_sources: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+     """Create sources list for ChatUI format"""
+     sources = []
+     for result in cited_sources:
+         filename = result.get('filename', 'Unknown')
+         page = result.get('page', 'Unknown')
+         year = result.get('year', 'Unknown')

+         link = f"doc://{filename}"
+         title_parts = [filename]
+         if page != 'Unknown':
+             title_parts.append(f"Page {page}")
+         if year != 'Unknown':
+             title_parts.append(f"({year})")

+         sources.append({"link": link, "title": " - ".join(title_parts)})

+     return sources
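
A single cited document therefore surfaces in the ChatUI payload as a link/title pair along these lines (illustrative filename and metadata, not taken from the repo):

{"link": "doc://consolidated_report_2022.pdf",
 "title": "consolidated_report_2022.pdf - Page 14 - (2022)"}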

  # ---------------------------------------------------------------------
+ # LLM Call Functions
  # ---------------------------------------------------------------------
  async def _call_llm(messages: list) -> str:
+     """Provider-agnostic LLM call using LangChain (non-streaming)"""
      try:
          response = await chat_model.ainvoke(messages)
          return response.content.strip()
      except Exception as e:

          raise

  async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
+     """Provider-agnostic streaming LLM call using LangChain"""
      try:
          async for chunk in chat_model.astream(messages):
              if hasattr(chunk, 'content') and chunk.content:
                  yield chunk.content

          logging.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
          yield f"Error: {str(e)}"

+ # ---------------------------------------------------------------------
+ # Main Generation Functions
+ # ---------------------------------------------------------------------
  async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
+     """Generate an answer to a query using provided context through RAG"""
      if not query.strip():
+         error_msg = "Query cannot be empty"
+         return {"error": error_msg} if chatui_format else f"Error: {error_msg}"

      try:
+         formatted_context, processed_results = _process_context(context)
+         messages = _build_messages(query, formatted_context)
          answer = await _call_llm(messages)

          if chatui_format:
              result = {"answer": answer}
              if processed_results:
+                 cited_numbers = _parse_citations(answer)
+                 cited_sources = _extract_sources(processed_results, cited_numbers)
+                 result["sources"] = _create_sources_list(cited_sources)
              return result
          else:
              return answer

      except Exception as e:
          logging.exception("Generation failed")
+         error_msg = str(e)
+         return {"error": error_msg} if chatui_format else f"Error: {error_msg}"

  async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
+     """Generate a streaming answer to a query using provided context through RAG"""
      if not query.strip():
+         error_msg = "Query cannot be empty"
          if chatui_format:
+             yield {"event": "error", "data": {"error": error_msg}}
          else:
+             yield f"Error: {error_msg}"
          return

      try:
+         formatted_context, processed_results = _process_context(context)
+         messages = _build_messages(query, formatted_context)

+         # Stream the response and accumulate for citation parsing (filter out any sources that were not cited)
          accumulated_response = ""
          async for chunk in _call_llm_streaming(messages):
              accumulated_response += chunk

          # Send sources at the end if available and in ChatUI format
          if chatui_format and processed_results:
+             cited_numbers = _parse_citations(accumulated_response)
+             cited_sources = _extract_sources(processed_results, cited_numbers)
+             sources = _create_sources_list(cited_sources)
              yield {"event": "sources", "data": {"sources": sources}}

+         # Send END event for ChatUI format
          if chatui_format:
              yield {"event": "end", "data": {}}

      except Exception as e:
          logging.exception("Streaming generation failed")
+         error_msg = str(e)
          if chatui_format:
+             yield {"event": "error", "data": {"error": error_msg}}
          else:
+             yield f"Error: {error_msg}"
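
Assuming the module is importable as utils.generator and the provider credentials expected by params.cfg are in place, the refactored entry points might be exercised roughly like this (untested sketch; the query and context strings are invented):

import asyncio
from utils.generator import generate, generate_streaming  # assumes the repo's package layout

async def main():
    # Plain-string context, plain-string answer
    answer = await generate("What were the main audit findings?", "Sample context passage.")
    print(answer)

    # ChatUI-style streaming: text chunks followed by "sources" and "end" events
    async for event in generate_streaming("What were the main audit findings?",
                                          "Sample context passage.", chatui_format=True):
        print(event)

asyncio.run(main())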