# april28ragdivreyyoel / rag_processor.py
# Uploaded by ABE101 ("Upload 5 files"), commit ae4184d (verified).
import asyncio
import contextvars
import html
import time
import traceback
from typing import List, Dict, Any, Optional, Callable, Tuple

from langsmith import traceable

try:
    import config
    from services import retriever, openai_service
except ImportError:
    print("Error: Failed to import config or services in rag_processor.py")
    raise SystemExit("Failed imports in rag_processor.py")
# Signature of the progress callback that every pipeline step invokes with a status message.
StatusCallback = Callable[[str], None]
# Label recorded under "pipeline_used" in the pipeline's result dict.
PIPELINE_VALIDATE_GENERATE_GPT4O = "GPT-4o Validator + GPT-4o Synthesizer"
# --- Step Functions ---
@traceable(name="rag-step-retrieve")
async def run_retrieval_step(query: str, n_retrieve: int, update_status: StatusCallback) -> List[Dict]:
    """Retrieve up to *n_retrieve* candidate passages for *query*.

    ``retriever.retrieve_documents`` is a synchronous call, so it runs in the
    event loop's default executor instead of blocking the loop while it waits.
    The current ``contextvars`` context is copied into the worker thread (the
    same thing ``asyncio.to_thread`` does), so context-local state such as
    tracing parents is preserved.

    Args:
        query: Text of the user's question.
        n_retrieve: Maximum number of passages to request from the retriever.
        update_status: Progress callback; receives "1. ..." status messages.

    Returns:
        The retrieved documents. The result is an empty list if the retriever
        returned nothing (including ``None``).
    """
    update_status(f"1. 诪讗讞讝专 注讚 {n_retrieve} 驻住拽讗讜转 诪-Pinecone...")
    start_time = time.perf_counter()
    loop = asyncio.get_running_loop()
    context = contextvars.copy_context()
    retrieved_docs = await loop.run_in_executor(
        None,
        lambda: context.run(retriever.retrieve_documents, query_text=query, n_results=n_retrieve),
    )
    # Normalize a None/falsy result so len() below and callers' checks are safe.
    retrieved_docs = retrieved_docs or []
    retrieval_time = time.perf_counter() - start_time
    status_msg = f"讗讜讞讝专讜 {len(retrieved_docs)} 驻住拽讗讜转 讘-{retrieval_time:.2f} 砖谞讬讜转."
    update_status(f"1. {status_msg}")
    if not retrieved_docs:
        update_status("1. 诇讗 讗讜转专讜 诪住诪讻讬诐.")
    return retrieved_docs
@traceable(name="rag-step-gpt4o-filter")
async def run_gpt4o_validation_filter_step(
    docs_to_process: List[Dict], query: str, n_validate: int, update_status: StatusCallback
) -> List[Dict]:
    """Validate the first *n_validate* docs with GPT-4o and keep the relevant ones.

    Each doc is passed to ``openai_service.validate_relevance_openai`` and the
    calls are awaited together via ``asyncio.gather``. A doc passes when its
    result is a dict whose ``"validation"`` entry is itself a dict with a truthy
    ``"contains_relevant_info"``. Docs beyond the first *n_validate* are not
    validated and are not returned.

    Args:
        docs_to_process: Retrieved documents, in retrieval order.
        query: The user's question, forwarded to the validator.
        n_validate: Maximum number of docs to validate. Negative values are
            treated as 0.
        update_status: Progress callback; receives "2. ..." / "3. ..." messages.

    Returns:
        Shallow copies of the passing docs, each with ``"validation_result"``
        set to its validation dict. The input dicts are not modified.
    """
    if not docs_to_process:
        update_status("2. [GPT-4o] 讚讬诇讜讙 注诇 讗讬诪讜转 - 讗讬谉 驻住拽讗讜转.")
        return []
    # Clamp at 0: a negative n_validate would otherwise make the slice below
    # select "all but the last k" docs instead of none.
    validation_count = max(0, min(len(docs_to_process), n_validate))
    update_status(f"2. [GPT-4o] 诪转讞讬诇 讗讬诪讜转 诪拽讘讬诇讬 ({validation_count} / {len(docs_to_process)} 驻住拽讗讜转)...")
    validation_start_time = time.perf_counter()
    docs_to_validate = docs_to_process[:validation_count]
    tasks = [openai_service.validate_relevance_openai(doc, query, i)
             for i, doc in enumerate(docs_to_validate)]
    # return_exceptions=True: a failing call is counted as an error below
    # instead of aborting the whole step.
    validation_results = await asyncio.gather(*tasks, return_exceptions=True)
    passed_docs: List[Dict] = []
    passed_count = failed_validation_count = error_count = 0
    update_status("3. [GPT-4o] 住讬谞讜谉 驻住拽讗讜转 诇驻讬 转讜爪讗讜转 讗讬诪讜转...")
    for i, (original_doc, res) in enumerate(zip(docs_to_validate, validation_results)):
        if isinstance(res, Exception):
            print(f"GPT-4o Validation Exception doc {i}: {res}")
            error_count += 1
            continue
        validation = res.get('validation') if isinstance(res, dict) else None
        if not isinstance(validation, dict):
            # Missing or malformed payload (e.g. "validation": None) is an error,
            # not a crash of the whole pipeline.
            print(f"GPT-4o Validation Unexpected result doc {i}: {type(res)}")
            error_count += 1
        elif validation.get('contains_relevant_info'):
            # Copy rather than mutate the caller's retrieved document.
            passed_docs.append({**original_doc, 'validation_result': validation})
            passed_count += 1
        else:
            failed_validation_count += 1
    validation_time = time.perf_counter() - validation_start_time
    status_msg_val = (f"讗讬诪讜转 GPT-4o 讛讜砖诇诐 ({passed_count} 注讘专讜, "
                      f"{failed_validation_count} 谞讚讞讜, {error_count} 砖讙讬讗讜转) "
                      f"讘-{validation_time:.2f} 砖谞讬讜转.")
    update_status(f"2. {status_msg_val}")
    status_msg_filter = f"谞讗住驻讜 {len(passed_docs)} 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讗讞专 讗讬诪讜转 GPT-4o."
    update_status(f"3. {status_msg_filter}")
    return passed_docs
async def _close_stream(stream: Any) -> None:
    """Best-effort ``aclose()`` of an async stream that may have been abandoned mid-iteration."""
    aclose = getattr(stream, "aclose", None)
    if aclose is None:
        return
    try:
        await aclose()
    except Exception:
        traceback.print_exc()


@traceable(name="rag-step-openai-generate")
async def run_openai_generation_step(
    history: List[Dict], context_documents: List[Dict],
    update_status: StatusCallback, stream_callback: Callable[[str], None]
) -> Tuple[str, Optional[str]]:
    """Stream the final answer from OpenAI, grounded in *context_documents*.

    Each text chunk is forwarded to *stream_callback* as it arrives. Non-string
    chunks are ignored. A chunk whose stripped text starts with
    ``"--- Error:"`` is treated as an in-band error from the service: streaming
    stops and that chunk becomes the returned error. The stream is always
    closed, even if iteration is abandoned early.

    Args:
        history: Chat messages forwarded to ``generate_openai_stream``.
        context_documents: Simplified passages to ground the answer in.
        update_status: Progress callback; receives "4. ..." status messages.
        stream_callback: Receives each streamed text chunk.

    Returns:
        ``(response_text, error)``, where *error* is ``None`` on success. On
        failure, *response_text* holds whatever was streamed before the failure.
        If there are no context documents, a fixed "no passages" message is
        returned with a ``None`` error.
    """
    generator_name = "OpenAI"
    if not context_documents:
        update_status(f"4. [{generator_name}] 讚讬诇讜讙 注诇 讬爪讬专讛 - 讗讬谉 驻住拽讗讜转 诇讛拽砖专.")
        return "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.", None
    update_status(f"4. [{generator_name}] 诪讞讜诇诇 转砖讜讘讛 住讜驻讬转 诪-{len(context_documents)} 拽讟注讬 讛拽砖专...")
    start_gen_time = time.perf_counter()
    # Defined outside the try so the failure path can return the partial text.
    full_response: List[str] = []
    error_msg: Optional[str] = None
    stream = None
    try:
        stream = openai_service.generate_openai_stream(
            messages=history, context_documents=context_documents
        )
        async for chunk in stream:
            if not isinstance(chunk, str):
                continue
            if chunk.strip().startswith("--- Error:"):
                error_msg = chunk.strip()
                print(f"OpenAI stream yielded error: {error_msg}")
                break
            full_response.append(chunk)
            stream_callback(chunk)
        final_response_text = "".join(full_response)
        gen_time = time.perf_counter() - start_gen_time
        if error_msg:
            update_status(f"4. 砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
            return final_response_text, error_msg
        update_status(f"4. 讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讛讜砖诇诪讛 讘-{gen_time:.2f} 砖谞讬讜转.")
        return final_response_text, None
    except Exception as gen_err:
        gen_time = time.perf_counter() - start_gen_time
        error_msg_critical = (f"--- Error: Critical failure during {generator_name} generation "
                              f"({type(gen_err).__name__}): {gen_err} ---")
        update_status(f"4. 砖讙讬讗讛 拽专讬讟讬转 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
        traceback.print_exc()
        # Keep what was already streamed, consistent with the in-band error path.
        return "".join(full_response), error_msg_critical
    finally:
        # Release the stream if we broke out early or a callback raised mid-stream.
        await _close_stream(stream)
def _rtl_div(text: str) -> str:
    """Wrap *text* in the ``rtl-text`` div used for this module's user-facing messages."""
    return f"<div class='rtl-text'>{text}</div>"


def _latest_user_query(history: List[Dict]) -> str:
    """Return the content of the most recent ``role == "user"`` message in *history*, or ``""``."""
    if not history or not isinstance(history, list):
        return ""
    for message in reversed(history):
        if isinstance(message, dict) and message.get("role") == "user":
            return str(message.get("content") or "")
    return ""


def _simplify_docs_for_generation(validated_docs: List[Dict]) -> List[Dict[str, Any]]:
    """Reduce validated docs to the fields the generator uses: text, id, optional source and verdict."""
    simplified: List[Dict[str, Any]] = []
    print(f"Processor: Simplifying {len(validated_docs)} docs...")
    for doc in validated_docs:
        if not isinstance(doc, dict):
            print(f"Warn: Skipping non-dict item: {doc}")
            continue
        hebrew_text = doc.get('hebrew_text', '')
        if not hebrew_text:
            continue  # nothing for the generator to quote
        simplified_doc: Dict[str, Any] = {
            'hebrew_text': hebrew_text,
            'original_id': doc.get('original_id', 'unknown'),
        }
        if doc.get('source_name'):
            simplified_doc['source_name'] = doc.get('source_name')
        validation = doc.get('validation_result')
        if validation is not None:
            simplified_doc['validation_result'] = validation  # include the validator's judgment
        simplified.append(simplified_doc)
    print(f"Processor: Created {len(simplified)} simplified docs with validation results.")
    return simplified


@traceable(name="rag-execute-validate-generate-gpt4o-pipeline")
async def execute_validate_generate_pipeline(
    history: List[Dict], params: Dict[str, Any],
    status_callback: StatusCallback, stream_callback: Callable[[str], None]
) -> Dict[str, Any]:
    """Run retrieve -> GPT-4o validate -> OpenAI generate for the latest user message.

    Args:
        history: Chat messages. The last ``role == "user"`` entry is the query.
        params: Must contain ``"n_retrieve"`` and ``"n_validate"``.
        status_callback: Receives every progress message (also printed and logged).
        stream_callback: Receives streamed answer chunks from the generator.

    Returns:
        A dict with these keys:

        - ``final_response``: answer text, or an ``rtl-text`` HTML div for
          error and empty cases.
        - ``validated_documents_full``
        - ``generator_input_documents``
        - ``status_log``: every status message emitted, on every return path.
        - ``error``: ``None`` on success.
        - ``pipeline_used``

        Exceptions raised inside the pipeline stages are caught and reported
        through ``error`` / ``final_response`` rather than propagated.
    """
    status_log_internal: List[str] = []
    result: Dict[str, Any] = {
        "final_response": "",
        "validated_documents_full": [],
        "generator_input_documents": [],
        # Bound to the live list up front so every return path, including the
        # early exits, carries the messages logged so far.
        "status_log": status_log_internal,
        "error": None,
        "pipeline_used": PIPELINE_VALIDATE_GENERATE_GPT4O,
    }

    def update_status_and_log(message: str) -> None:
        print(f"Status Update: {message}")
        status_log_internal.append(message)
        status_callback(message)

    current_query_text = _latest_user_query(history)
    if not current_query_text.strip():  # whitespace-only is not a question either
        result["error"] = "诇讗 讝讜讛转讛 砖讗诇讛."
        result["final_response"] = _rtl_div(result["error"])
        return result

    try:
        # 1. Retrieval
        retrieved_docs = await run_retrieval_step(
            current_query_text, params['n_retrieve'], update_status_and_log
        )
        if not retrieved_docs:
            result["error"] = "诇讗 讗讜转专讜 诪拽讜专讜转."
            result["final_response"] = _rtl_div(result["error"])
            return result

        # 2. Validation
        validated_docs_full = await run_gpt4o_validation_filter_step(
            retrieved_docs, current_query_text, params['n_validate'], update_status_and_log
        )
        result["validated_documents_full"] = validated_docs_full
        if not validated_docs_full:
            result["error"] = "诇讗 谞诪爪讗讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转."
            result["final_response"] = _rtl_div(result["error"])
            update_status_and_log(f"4. {result['error']} 诇讗 谞讬转谉 诇讛诪砖讬讱.")
            return result

        simplified_docs = _simplify_docs_for_generation(validated_docs_full)
        result["generator_input_documents"] = simplified_docs

        # 3. Generation
        final_response_text, generation_error = await run_openai_generation_step(
            history=history,
            context_documents=simplified_docs,
            update_status=update_status_and_log,
            stream_callback=stream_callback,
        )
        result["final_response"] = final_response_text
        result["error"] = generation_error
        if generation_error and not final_response_text.strip().startswith(("<div", "诇讗 住讜驻拽讜")):
            # Escape the error text: it embeds exception messages that may contain '<'.
            result["final_response"] = (
                f"<div class='rtl-text'><strong>砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛.</strong><br>"
                f"驻专讟讬诐: {html.escape(generation_error)}<br>---<br>{final_response_text}</div>"
            )
        elif final_response_text == "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.":
            result["final_response"] = _rtl_div(final_response_text)
    except Exception as e:
        error_type = type(e).__name__
        error_msg = f"砖讙讬讗讛 拽专讬讟讬转 RAG ({error_type}): {e}"
        print(f"Critical RAG Error: {error_msg}")
        traceback.print_exc()
        result["error"] = error_msg
        # Tracebacks contain frames like "<module>" / "<lambda>" that would be
        # swallowed as HTML tags unless escaped.
        result["final_response"] = (
            f"<div class='rtl-text'><strong>砖讙讬讗讛 拽专讬讟讬转! ({error_type})</strong><br>谞住讛 砖讜讘."
            f"<details><summary>驻专讟讬诐</summary><pre>{html.escape(traceback.format_exc())}</pre></details></div>"
        )
        update_status_and_log(f"砖讙讬讗讛 拽专讬讟讬转: {error_type}")
    return result