Spaces:
Running
Running
# rag_processor.py: Retrieve -> Validate (GPT-4o) -> Generate (GPT-4o) RAG pipeline.
import asyncio
import html
import time
import traceback
from typing import List, Dict, Any, Optional, Callable, Tuple

from langsmith import traceable

try:
    import config
    from services import retriever, openai_service
except ImportError:
    print("Error: Failed to import config or services in rag_processor.py")
    raise SystemExit("Failed imports in rag_processor.py")
# Label stored under the "pipeline_used" key of every pipeline result dict.
PIPELINE_VALIDATE_GENERATE_GPT4O: str = "GPT-4o Validator + GPT-4o Synthesizer"
# Shape of the status callbacks threaded through every pipeline step:
# it receives one human-readable status line and returns nothing.
StatusCallback = Callable[[str], None]
# --- Step Functions ---
async def run_retrieval_step(query: str, n_retrieve: int, update_status: StatusCallback) -> List[Dict]:
    """Fetch up to ``n_retrieve`` candidate paragraphs for ``query`` from the retriever.

    Args:
        query: The user's question text.
        n_retrieve: Maximum number of results requested (passed as ``n_results``).
        update_status: Callback that receives one status line per progress event.

    Returns:
        The documents returned by ``retriever.retrieve_documents``, or an empty
        list when it returns nothing (including ``None``).
    """
    # NOTE(review): the non-ASCII status strings in this module look like Hebrew
    # that was mis-decoded somewhere (UTF-8 bytes rendered as CJK); confirm the
    # file's real encoding before editing them.
    update_status(f"1. 诪讗讞讝专 注讚 {n_retrieve} 驻住拽讗讜转 诪-Pinecone...")
    # perf_counter is monotonic, so the measured duration cannot be skewed (or go
    # negative) if the wall clock is adjusted during the call.
    start_time = time.perf_counter()
    # Called without ``await``: the retrieval runs synchronously on the event loop.
    # ``or []`` keeps len() below from raising if the retriever returns None.
    retrieved_docs = retriever.retrieve_documents(query_text=query, n_results=n_retrieve) or []
    retrieval_time = time.perf_counter() - start_time
    status_msg = f"讗讜讞讝专讜 {len(retrieved_docs)} 驻住拽讗讜转 讘-{retrieval_time:.2f} 砖谞讬讜转."
    update_status(f"1. {status_msg}")
    if not retrieved_docs:
        update_status("1. 诇讗 讗讜转专讜 诪住诪讻讬诐.")
    return retrieved_docs
def _classify_validation_result(res: Any) -> Tuple[str, Optional[Dict]]:
    """Classify one ``asyncio.gather`` result from the validation fan-out.

    Returns ``("passed", validation_dict)``, ``("failed", None)`` or ``("error", None)``.
    """
    if isinstance(res, BaseException):
        # gather(return_exceptions=True) hands back raised exceptions as results.
        return "error", None
    if not (isinstance(res, dict) and 'validation' in res and 'paragraph_data' in res):
        return "error", None
    validation = res["validation"]
    if not isinstance(validation, dict):
        # A malformed payload is one bad document: count it as an error instead of
        # letting .get() raise and abort the whole pipeline.
        return "error", None
    if validation.get("contains_relevant_info") is True:
        return "passed", validation
    return "failed", None


async def run_gpt4o_validation_filter_step(
    docs_to_process: List[Dict], query: str, n_validate: int, update_status: StatusCallback
) -> List[Dict]:
    """Validate the first ``n_validate`` docs concurrently and keep the relevant ones.

    Each candidate is sent to ``openai_service.validate_relevance_openai``. A doc
    passes only when its result's ``validation["contains_relevant_info"]`` is
    exactly ``True``. Exceptions and malformed results are counted as errors and
    logged; they never abort the step.

    Args:
        docs_to_process: Retrieved docs. Only a prefix of at most ``n_validate``
            items is validated.
        query: The user's question, forwarded to the validator.
        n_validate: Maximum number of docs to validate. Values <= 0 validate none.
        update_status: Callback receiving human-readable progress lines.

    Returns:
        The passing docs, in their original order. Each passing doc is the
        caller's own dict, mutated in place to add a ``'validation_result'`` key.
    """
    if not docs_to_process:
        update_status("2. [GPT-4o] 讚讬诇讜讙 注诇 讗讬诪讜转 - 讗讬谉 驻住拽讗讜转.")
        return []
    # Clamp at 0: a negative n_validate would otherwise slice from the end (docs[:-k]).
    validation_count = max(0, min(len(docs_to_process), n_validate))
    update_status(f"2. [GPT-4o] 诪转讞讬诇 讗讬诪讜转 诪拽讘讬诇讬 ({validation_count} / {len(docs_to_process)} 驻住拽讗讜转)...")
    validation_start_time = time.perf_counter()
    candidates = docs_to_process[:validation_count]
    tasks = [openai_service.validate_relevance_openai(doc, query, i) for i, doc in enumerate(candidates)]
    validation_results = await asyncio.gather(*tasks, return_exceptions=True)

    passed_docs: List[Dict] = []
    failed_validation_count = 0
    error_count = 0
    update_status("3. [GPT-4o] 住讬谞讜谉 驻住拽讗讜转 诇驻讬 转讜爪讗讜转 讗讬诪讜转...")
    # gather() returns one result per task, in task order, so zip pairs each doc with its own result.
    for i, (original_doc, res) in enumerate(zip(candidates, validation_results)):
        verdict, validation = _classify_validation_result(res)
        if verdict == "passed":
            original_doc['validation_result'] = validation
            passed_docs.append(original_doc)
        elif verdict == "failed":
            failed_validation_count += 1
        else:
            error_count += 1
            if isinstance(res, BaseException):
                print(f"GPT-4o Validation Exception doc {i}: {res}")
            else:
                print(f"GPT-4o Validation Unexpected result doc {i}: {type(res)}")
    passed_count = len(passed_docs)
    validation_time = time.perf_counter() - validation_start_time
    status_msg_val = f"讗讬诪讜转 GPT-4o 讛讜砖诇诐 ({passed_count} 注讘专讜, {failed_validation_count} 谞讚讞讜, {error_count} 砖讙讬讗讜转) 讘-{validation_time:.2f} 砖谞讬讜转."
    update_status(f"2. {status_msg_val}")
    status_msg_filter = f"谞讗住驻讜 {len(passed_docs)} 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讗讞专 讗讬诪讜转 GPT-4o."
    update_status(f"3. {status_msg_filter}")
    return passed_docs
async def run_openai_generation_step(
    history: List[Dict], context_documents: List[Dict], update_status: StatusCallback, stream_callback: Callable[[str], None]
) -> Tuple[str, Optional[str]]:
    """Stream the final answer from ``openai_service.generate_openai_stream``.

    Every string chunk is forwarded to ``stream_callback`` and accumulated. A
    chunk whose stripped text starts with ``"--- Error:"`` stops the stream and
    is reported as the error. The stream is always closed before returning, so an
    early ``break`` or a mid-stream failure does not leave it open.

    Args:
        history: Chat messages, passed through as ``messages``.
        context_documents: Docs used as generation context. If this is empty,
            generation is skipped.
        update_status: Callback receiving human-readable progress lines.
        stream_callback: Called with each successfully received text chunk.

    Returns:
        ``(response_text, error)``:
        - On success, ``error`` is ``None``.
        - On an in-stream error chunk, ``response_text`` is the partial text and
          ``error`` is that chunk.
        - On an unexpected exception, ``("", "--- Error: Critical failure ... ---")``.
    """
    generator_name = "OpenAI"
    if not context_documents:
        update_status(f"4. [{generator_name}] 讚讬诇讜讙 注诇 讬爪讬专讛 - 讗讬谉 驻住拽讗讜转 诇讛拽砖专.")
        return "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.", None
    update_status(f"4. [{generator_name}] 诪讞讜诇诇 转砖讜讘讛 住讜驻讬转 诪-{len(context_documents)} 拽讟注讬 讛拽砖专...")
    start_gen_time = time.perf_counter()
    stream = None
    try:
        full_response: List[str] = []
        error_msg: Optional[str] = None
        stream = openai_service.generate_openai_stream(messages=history, context_documents=context_documents)
        async for chunk in stream:
            if not isinstance(chunk, str):
                continue  # Non-text chunks are ignored.
            if chunk.strip().startswith("--- Error:"):
                error_msg = chunk.strip()
                print(f"OpenAI stream yielded error: {error_msg}")
                break
            full_response.append(chunk)
            stream_callback(chunk)
        final_response_text = "".join(full_response)
        gen_time = time.perf_counter() - start_gen_time
        if error_msg:
            update_status(f"4. 砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
            return final_response_text, error_msg
        update_status(f"4. 讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讛讜砖诇诪讛 讘-{gen_time:.2f} 砖谞讬讜转.")
        return final_response_text, None
    except Exception as gen_err:
        gen_time = time.perf_counter() - start_gen_time
        error_msg_critical = f"--- Error: Critical failure during {generator_name} generation ({type(gen_err).__name__}): {gen_err} ---"
        update_status(f"4. 砖讙讬讗讛 拽专讬讟讬转 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
        traceback.print_exc()
        return "", error_msg_critical
    finally:
        # Breaking out of ``async for`` does not close an async generator. Close it
        # explicitly so the upstream stream is released now rather than at GC time.
        # aclose() on an already-exhausted generator is a no-op.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            try:
                await aclose()
            except Exception as close_err:  # cleanup must not mask the real outcome
                print(f"OpenAI stream close failed: {close_err!r}")
# --- Main Pipeline Orchestrator ---
def _latest_user_query(history: Any) -> str:
    """Return the content of the most recent ``role == "user"`` message, or "" if there is none."""
    if not history or not isinstance(history, list):
        return ""
    for msg_ in reversed(history):
        if isinstance(msg_, dict) and msg_.get("role") == "user":
            # Stop at the newest user message even if its content is empty.
            return str(msg_.get("content") or "")
    return ""


def _simplify_docs_for_generation(docs: List[Any]) -> List[Dict[str, Any]]:
    """Reduce validated docs to the fields the generator needs.

    Docs without ``hebrew_text`` are dropped, and non-dict items are skipped with
    a warning. ``source_name`` is included only when it is non-empty.
    """
    simplified: List[Dict[str, Any]] = []
    for doc in docs:
        if not isinstance(doc, dict):
            print(f"Warn: Skipping non-dict item: {doc}")
            continue
        hebrew_text = doc.get('hebrew_text', '')
        if not hebrew_text:
            continue
        simplified_doc: Dict[str, Any] = {'hebrew_text': hebrew_text, 'original_id': doc.get('original_id', 'unknown')}
        source_name = doc.get('source_name', '')
        if source_name:
            simplified_doc['source_name'] = source_name
        simplified.append(simplified_doc)
    return simplified


async def execute_validate_generate_pipeline(
    history: List[Dict], params: Dict[str, Any], status_callback: StatusCallback, stream_callback: Callable[[str], None]
) -> Dict[str, Any]:
    """
    Orchestrates Retrieve -> Validate (GPT-4o) -> Generate (GPT-4o) pipeline.
    Stores both full validated docs and simplified docs for generator input.

    Args:
        history: Chat messages. The content of the latest ``role == "user``
            message is used as the query.
        params: Must provide ``'n_retrieve'`` and ``'n_validate'``. A missing key
            surfaces as a critical-error result, not an exception.
        status_callback: Receives every status line; the same lines are also
            printed and collected into ``status_log``.
        stream_callback: Forwarded to the generation step for each text chunk.

    Returns:
        A dict with keys ``final_response`` (user-facing text or HTML),
        ``validated_documents_full``, ``generator_input_documents``,
        ``status_log``, ``error`` (``None`` on success) and ``pipeline_used``.
    """
    result: Dict[str, Any] = {
        "final_response": "",
        "validated_documents_full": [],
        "generator_input_documents": [],
        "status_log": [],
        "error": None,
        "pipeline_used": PIPELINE_VALIDATE_GENERATE_GPT4O,
    }
    status_log_internal: List[str] = []

    def update_status_and_log(message: str) -> None:
        # Mirror each status line to the console, the returned log and the UI.
        print(f"Status Update: {message}")
        status_log_internal.append(message)
        status_callback(message)

    def finish_with_error(message: str) -> Dict[str, Any]:
        # Shared early-exit path: record the error and wrap it for RTL display.
        result["error"] = message
        result["final_response"] = f"<div class='rtl-text'>{message}</div>"
        result["status_log"] = status_log_internal
        return result

    current_query_text = _latest_user_query(history)
    if not current_query_text:
        print("Error: Could not extract query.")
        return finish_with_error("诇讗 讝讜讛转讛 砖讗诇讛.")

    try:
        # --- 1. Retrieval ---
        retrieved_docs = await run_retrieval_step(current_query_text, params['n_retrieve'], update_status_and_log)
        if not retrieved_docs:
            return finish_with_error("诇讗 讗讜转专讜 诪拽讜专讜转.")

        # --- 2. Validation ---
        validated_docs_full = await run_gpt4o_validation_filter_step(
            retrieved_docs, current_query_text, params['n_validate'], update_status_and_log
        )
        result["validated_documents_full"] = validated_docs_full  # kept for trace/debug
        if not validated_docs_full:
            finish_with_error("诇讗 谞诪爪讗讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转.")
            # status_log already references the same list, so this line is included.
            update_status_and_log(f"4. {result['error']} 诇讗 谞讬转谉 诇讛诪砖讬讱.")
            return result

        # --- Simplify docs for generation ---
        print(f"Processor: Simplifying {len(validated_docs_full)} docs...")
        simplified_docs_for_generation = _simplify_docs_for_generation(validated_docs_full)
        result["generator_input_documents"] = simplified_docs_for_generation
        print(f"Processor: Created {len(simplified_docs_for_generation)} simplified docs.")

        # --- 3. Generation ---
        final_response_text, generation_error = await run_openai_generation_step(
            history=history,
            context_documents=simplified_docs_for_generation,
            update_status=update_status_and_log,
            stream_callback=stream_callback,
        )
        result["final_response"] = final_response_text
        result["error"] = generation_error
        if generation_error and not result["final_response"].strip().startswith(("<div", "诇讗 住讜驻拽讜")):
            # The error text can carry raw exception messages containing '<'/'&';
            # escape it so it displays literally instead of being parsed as markup.
            result["final_response"] = (
                f"<div class='rtl-text'><strong>砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛.</strong><br>"
                f"驻专讟讬诐: {html.escape(generation_error)}<br>---<br>{result['final_response']}</div>"
            )
        elif result["final_response"] == "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.":
            result["final_response"] = f"<div class='rtl-text'>{result['final_response']}</div>"
    except Exception as e:  # Top-level boundary: report any failure as a result, never raise.
        error_type = type(e).__name__
        error_msg = f"砖讙讬讗讛 拽专讬讟讬转 RAG ({error_type}): {e}"
        print(f"Critical RAG Error: {error_msg}")
        traceback.print_exc()
        result["error"] = error_msg
        # Tracebacks routinely contain '<module>' / "<class ...>"; escape them so the
        # <pre> block shows them verbatim instead of breaking the surrounding HTML.
        result["final_response"] = (
            f"<div class='rtl-text'><strong>砖讙讬讗讛 拽专讬讟讬转! ({error_type})</strong><br>谞住讛 砖讜讘."
            f"<details><summary>驻专讟讬诐</summary><pre>{html.escape(traceback.format_exc())}</pre></details></div>"
        )
        update_status_and_log(f"砖讙讬讗讛 拽专讬讟讬转: {error_type}")
    result["status_log"] = status_log_internal
    return result