# Source: april27divreyyoel / rag_processor.py (Hugging Face upload; page-chrome lines removed)
# rag_processor.py (Fixed Syntax Error AGAIN)
import time
import asyncio
import traceback
from typing import List, Dict, Any, Optional, Callable, Tuple
from langsmith import traceable
try:
import config
from services import retriever, openai_service
except ImportError:
print("Error: Failed to import config or services in rag_processor.py")
raise SystemExit("Failed imports in rag_processor.py")
# Human-readable label for the single pipeline this module implements
# (GPT-4o validates relevance, GPT-4o synthesizes the answer).
PIPELINE_VALIDATE_GENERATE_GPT4O = "GPT-4o Validator + GPT-4o Synthesizer"
# Callback type used to push progress/status messages (e.g. to a UI).
StatusCallback = Callable[[str], None]
# --- Step Functions ---
@traceable(name="rag-step-retrieve")
async def run_retrieval_step(query: str, n_retrieve: int, update_status: StatusCallback) -> List[Dict]:
    """Retrieve up to *n_retrieve* candidate passages for *query* from the vector store.

    Emits status messages via *update_status* before and after the lookup and
    returns whatever list the retriever produced (possibly empty).
    """
    update_status(f"1. 诪讗讞讝专 注讚 {n_retrieve} 驻住拽讗讜转 诪-Pinecone...")
    started = time.time()
    docs = retriever.retrieve_documents(query_text=query, n_results=n_retrieve)
    elapsed = time.time() - started
    status_msg = f"讗讜讞讝专讜 {len(docs)} 驻住拽讗讜转 讘-{elapsed:.2f} 砖谞讬讜转."
    update_status(f"1. {status_msg}")
    if not docs:
        update_status("1. 诇讗 讗讜转专讜 诪住诪讻讬诐.")
    return docs
@traceable(name="rag-step-gpt4o-filter")
async def run_gpt4o_validation_filter_step(
    docs_to_process: List[Dict], query: str, n_validate: int, update_status: StatusCallback
) -> List[Dict]:
    """Validate candidate passages against *query* with GPT-4o and keep the relevant ones.

    At most *n_validate* passages are checked (in parallel). A passage passes when
    its validation dict reports contains_relevant_info == True; the validation
    payload is attached to the doc under 'validation_result'. Exceptions and
    malformed results are counted and logged but never raised.
    """
    if not docs_to_process:
        update_status("2. [GPT-4o] 讚讬诇讜讙 注诇 讗讬诪讜转 - 讗讬谉 驻住拽讗讜转.")
        return []
    validation_count = min(len(docs_to_process), n_validate)
    update_status(f"2. [GPT-4o] 诪转讞讬诇 讗讬诪讜转 诪拽讘讬诇讬 ({validation_count} / {len(docs_to_process)} 驻住拽讗讜转)...")
    started = time.time()
    # Fan out one validation task per candidate; gather keeps per-task failures as values.
    tasks = [
        openai_service.validate_relevance_openai(doc, query, idx)
        for idx, doc in enumerate(docs_to_process[:validation_count])
    ]
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    passed_docs: List[Dict] = []
    passed_count = failed_validation_count = error_count = 0
    update_status("3. [GPT-4o] 住讬谞讜谉 驻住拽讗讜转 诇驻讬 转讜爪讗讜转 讗讬诪讜转...")
    # Outcomes are ordered, so index idx maps back to docs_to_process[idx].
    for idx, outcome in enumerate(outcomes):
        source_doc = docs_to_process[idx]
        if isinstance(outcome, Exception):
            print(f"GPT-4o Validation Exception doc {idx}: {outcome}")
            error_count += 1
        elif isinstance(outcome, dict) and 'validation' in outcome and 'paragraph_data' in outcome:
            if outcome["validation"].get("contains_relevant_info") is True:
                source_doc['validation_result'] = outcome["validation"]
                passed_docs.append(source_doc)
                passed_count += 1
            else:
                failed_validation_count += 1
        else:
            print(f"GPT-4o Validation Unexpected result doc {idx}: {type(outcome)}")
            error_count += 1
    elapsed = time.time() - started
    status_msg_val = f"讗讬诪讜转 GPT-4o 讛讜砖诇诐 ({passed_count} 注讘专讜, {failed_validation_count} 谞讚讞讜, {error_count} 砖讙讬讗讜转) 讘-{elapsed:.2f} 砖谞讬讜转."
    update_status(f"2. {status_msg_val}")
    status_msg_filter = f"谞讗住驻讜 {len(passed_docs)} 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讗讞专 讗讬诪讜转 GPT-4o."
    update_status(f"3. {status_msg_filter}")
    return passed_docs  # Full docs (with attached validation) that passed the filter
@traceable(name="rag-step-openai-generate")
async def run_openai_generation_step(
    history: List[Dict], context_documents: List[Dict], update_status: StatusCallback, stream_callback: Callable[[str], None]
) -> Tuple[str, Optional[str]]:
    """Stream the final OpenAI answer built from *context_documents*.

    Each text chunk is forwarded to *stream_callback* as it arrives. Returns
    (answer_text, error_message) where error_message is None on success.
    A chunk starting with "--- Error:" aborts the stream and is reported as
    the error; unexpected exceptions are caught and reported the same way.
    """
    generator_name = "OpenAI"
    if not context_documents:
        update_status(f"4. [{generator_name}] 讚讬诇讜讙 注诇 讬爪讬专讛 - 讗讬谉 驻住拽讗讜转 诇讛拽砖专.")
        return "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.", None
    update_status(f"4. [{generator_name}] 诪讞讜诇诇 转砖讜讘讛 住讜驻讬转 诪-{len(context_documents)} 拽讟注讬 讛拽砖专...")
    started = time.time()
    try:
        chunks: List[str] = []
        error_msg: Optional[str] = None
        stream = openai_service.generate_openai_stream(messages=history, context_documents=context_documents)
        async for piece in stream:
            if isinstance(piece, str) and piece.strip().startswith("--- Error:"):
                # Record only the first error sentinel, then stop consuming the stream.
                if not error_msg:
                    error_msg = piece.strip()
                print(f"OpenAI stream yielded error: {piece.strip()}")
                break
            elif isinstance(piece, str):
                chunks.append(piece)
                stream_callback(piece)
        final_text = "".join(chunks)
        gen_time = time.time() - started
        if error_msg:
            update_status(f"4. 砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
            return final_text, error_msg
        update_status(f"4. 讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讛讜砖诇诪讛 讘-{gen_time:.2f} 砖谞讬讜转.")
        return final_text, None
    except Exception as gen_err:
        gen_time = time.time() - started
        error_msg_critical = f"--- Error: Critical failure during {generator_name} generation ({type(gen_err).__name__}): {gen_err} ---"
        update_status(f"4. 砖讙讬讗讛 拽专讬讟讬转 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
        traceback.print_exc()
        return "", error_msg_critical
# --- Main Pipeline Orchestrator ---
@traceable(name="rag-execute-validate-generate-gpt4o-pipeline")
async def execute_validate_generate_pipeline(
    history: List[Dict], params: Dict[str, Any], status_callback: StatusCallback, stream_callback: Callable[[str], None]
) -> Dict[str, Any]:
    """
    Orchestrates the Retrieve -> Validate (GPT-4o) -> Generate (GPT-4o) pipeline.

    Stores both the full validated docs (for trace/debug) and a simplified copy
    of each doc (text + identifying fields only) that is fed to the generator.

    Args:
        history: Chat messages; the most recent entry with role == "user"
            supplies the query text.
        params: Expects 'n_retrieve' and 'n_validate' keys (passed to the
            retrieval and validation steps).
        status_callback: Called with each (Hebrew) status message for the UI.
        stream_callback: Called with each streamed answer chunk.

    Returns:
        Dict with keys "final_response", "validated_documents_full",
        "generator_input_documents", "status_log", "error" (None on success),
        and "pipeline_used". Errors are reported in the dict, never raised.
    """
    result = { "final_response": "", "validated_documents_full": [], "generator_input_documents": [], "status_log": [], "error": None, "pipeline_used": PIPELINE_VALIDATE_GENERATE_GPT4O }
    # Accumulates every status message; copied into result["status_log"] on all exit paths.
    status_log_internal: List[str] = []
    # Helper that fans a status message out to console, internal log, and the UI callback.
    def update_status_and_log(message: str):
        print(f"Status Update: {message}") # Log status to console
        status_log_internal.append(message)
        status_callback(message) # Update UI
    # Extract the query: newest message with role == "user" (scanned in reverse).
    current_query_text = ""
    if history and isinstance(history, list):
        for msg_ in reversed(history):
            if isinstance(msg_, dict) and msg_.get("role") == "user": current_query_text = str(msg_.get("content") or ""); break
    if not current_query_text: print("Error: Could not extract query."); result["error"] = "诇讗 讝讜讛转讛 砖讗诇讛."; result["final_response"] = f"<div class='rtl-text'>{result['error']}</div>"; result["status_log"] = status_log_internal; return result
    try:
        # --- 1. Retrieval ---
        retrieved_docs = await run_retrieval_step(current_query_text, params['n_retrieve'], update_status_and_log)
        if not retrieved_docs: result["error"] = "诇讗 讗讜转专讜 诪拽讜专讜转."; result["final_response"] = f"<div class='rtl-text'>{result['error']}</div>"; result["status_log"] = status_log_internal; return result
        # --- 2. Validation (GPT-4o relevance filter) ---
        validated_docs_full = await run_gpt4o_validation_filter_step(retrieved_docs, current_query_text, params['n_validate'], update_status_and_log)
        result["validated_documents_full"] = validated_docs_full # Store full docs for trace/debug
        if not validated_docs_full: result["error"] = "诇讗 谞诪爪讗讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转."; result["final_response"] = f"<div class='rtl-text'>{result['error']}</div>"; result["status_log"] = status_log_internal; update_status_and_log(f"4. {result['error']} 诇讗 谞讬转谉 诇讛诪砖讬讱."); return result
        # --- Simplify docs for generation: keep only the text and identifying fields ---
        simplified_docs_for_generation = []
        print(f"Processor: Simplifying {len(validated_docs_full)} docs...");
        for doc in validated_docs_full:
            if isinstance(doc, dict):
                hebrew_text = doc.get('hebrew_text', '')
                if hebrew_text:
                    simplified_doc = {'hebrew_text': hebrew_text, 'original_id': doc.get('original_id', 'unknown'), 'source_name': doc.get('source_name', '')}
                    # Drop an empty source_name so the generator input stays minimal.
                    if not simplified_doc['source_name']: del simplified_doc['source_name']
                    simplified_docs_for_generation.append(simplified_doc)
            else: print(f"Warn: Skipping non-dict item: {doc}")
        result["generator_input_documents"] = simplified_docs_for_generation # Store simplified for UI
        print(f"Processor: Created {len(simplified_docs_for_generation)} simplified docs.")
        # --- 3. Generation (streams the final answer via stream_callback) ---
        final_response_text, generation_error = await run_openai_generation_step(
            history=history, context_documents=simplified_docs_for_generation, # Pass simplified list
            update_status=update_status_and_log, stream_callback=stream_callback
        )
        result["final_response"] = final_response_text; result["error"] = generation_error
        # Wrap errors (and the "no relevant passages" sentinel) in an RTL div for display,
        # unless the response already starts with markup or the sentinel prefix.
        if generation_error and not result["final_response"].strip().startswith(("<div", "诇讗 住讜驻拽讜")):
            result["final_response"] = f"<div class='rtl-text'><strong>砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛.</strong><br>驻专讟讬诐: {generation_error}<br>---<br>{result['final_response']}</div>"
        elif result["final_response"] == "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.":
            result["final_response"] = f"<div class='rtl-text'>{result['final_response']}</div>"
    except Exception as e: # Top-level boundary: report the failure in the result instead of raising to the UI
        error_type = type(e).__name__; error_msg = f"砖讙讬讗讛 拽专讬讟讬转 RAG ({error_type}): {e}"; print(f"Critical RAG Error: {error_msg}"); traceback.print_exc()
        result["error"] = error_msg; result["final_response"] = f"<div class='rtl-text'><strong>砖讙讬讗讛 拽专讬讟讬转! ({error_type})</strong><br>谞住讛 砖讜讘.<details><summary>驻专讟讬诐</summary><pre>{traceback.format_exc()}</pre></details></div>"
        update_status_and_log(f"砖讙讬讗讛 拽专讬讟讬转: {error_type}")
    result["status_log"] = status_log_internal
    return result