# HackRX RAG API — FastAPI application.
# (Removed "Spaces: Sleeping / Sleeping" lines: Hugging Face Spaces page
# artifacts from the capture, not part of the program.)
| import os | |
| import json | |
| import tempfile | |
| import requests | |
| from fastapi import FastAPI, HTTPException, Depends, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel | |
| from typing import List, Dict, Union, Any, Optional | |
| from dotenv import load_dotenv | |
| import asyncio | |
| import httpx | |
| import time | |
| from urllib.parse import urlparse, unquote | |
| import uuid | |
| import re | |
| # Import LangChain Document and text splitter | |
| from langchain_core.documents import Document | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from processing_utility import ( | |
| extract_schema_from_file, | |
| process_document, | |
| download_and_parse_document_using_llama_index, | |
| ) | |
| # Import the new classes and functions from rag_utils | |
| from rag_utils import ( | |
| process_markdown_with_recursive_chunking, | |
| generate_answer_with_groq, | |
| generate_hypothetical_document, | |
| HybridSearchManager, | |
| EmbeddingClient, | |
| CHUNK_SIZE, | |
| CHUNK_OVERLAP, | |
| TOP_K_CHUNKS, | |
| GROQ_MODEL_NAME, | |
| ) | |
| load_dotenv() | |
| # --- FastAPI App Initialization --- | |
| app = FastAPI( | |
| title="HackRX RAG API", | |
| description="API for Retrieval-Augmented Generation from PDF documents.", | |
| version="1.0.0", | |
| ) | |
| # --- Global instance for the HybridSearchManager --- | |
| hybrid_search_manager: Optional[HybridSearchManager] = None | |
| async def startup_event(): | |
| global hybrid_search_manager | |
| hybrid_search_manager = HybridSearchManager() | |
| #initialize_llama_extract_agent() | |
| print("Application startup complete. HybridSearchManager is ready.") | |
| # --- Groq API Key Setup --- | |
| GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "NOT_FOUND") | |
| if GROQ_API_KEY == "NOT_FOUND": | |
| print( | |
| "WARNING: GROQ_API_KEY is using a placeholder or hardcoded value. Please set GROQ_API_KEY environment variable for production." | |
| ) | |
| # --- Pydantic Models for Request and Response --- | |
| class RunRequest(BaseModel): | |
| documents: str | |
| questions: List[str] | |
| class Answer(BaseModel): | |
| answer: str | |
| class RunResponse(BaseModel): | |
| answers: List[str] | |
| #step_timings: Dict[str, float] | |
| #hypothetical_documents: List[str] | |
| async def run_rag_pipeline( | |
| request: RunRequest | |
| ): | |
| """ | |
| Runs the RAG pipeline for a given PDF document (converted to Markdown internally) | |
| and a list of questions. | |
| """ | |
| pdf_url = request.documents | |
| questions = request.questions | |
| local_markdown_path = None | |
| step_timings = {} | |
| start_time_total = time.perf_counter() | |
| try: | |
| if hybrid_search_manager is None: | |
| raise HTTPException( | |
| status_code=500, detail="HybridSearchManager not initialized." | |
| ) | |
| # 1. Parsing: Download PDF and parse to Markdown | |
| start_time = time.perf_counter() | |
| markdown_content = await download_and_parse_document_using_llama_index(pdf_url) | |
| with tempfile.NamedTemporaryFile( | |
| mode="w", delete=False, encoding="utf-8", suffix=".md" | |
| ) as temp_md_file: | |
| temp_md_file.write(markdown_content) | |
| local_markdown_path = temp_md_file.name | |
| end_time = time.perf_counter() | |
| step_timings["parsing_to_markdown"] = end_time - start_time | |
| print( | |
| f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds." | |
| ) | |
| # 2. Chunk Generation: Process Markdown into chunks | |
| start_time = time.perf_counter() | |
| processed_documents = process_markdown_with_recursive_chunking( | |
| local_markdown_path, | |
| CHUNK_SIZE, | |
| CHUNK_OVERLAP, | |
| ) | |
| if not processed_documents: | |
| raise HTTPException( | |
| status_code=500, detail="Failed to process document into chunks." | |
| ) | |
| end_time = time.perf_counter() | |
| step_timings["chunk_generation"] = end_time - start_time | |
| print( | |
| f"Chunk Generation took {step_timings['chunk_generation']:.2f} seconds." | |
| ) | |
| # 3. Model Initialization and Embeddings Pre-computation | |
| start_time = time.perf_counter() | |
| await hybrid_search_manager.initialize_models(processed_documents) | |
| end_time = time.perf_counter() | |
| step_timings["model_initialization"] = end_time - start_time | |
| print( | |
| f"Model initialization took {step_timings['model_initialization']:.2f} seconds." | |
| ) | |
| # --- NEW CONCURRENT WORKFLOW --- | |
| # 4. Concurrently generate all hypothetical documents | |
| start_time_hyde = time.perf_counter() | |
| hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in questions] | |
| all_hyde_docs = await asyncio.gather(*hyde_tasks) | |
| end_time_hyde = time.perf_counter() | |
| step_timings["hyde_generation_total_time"] = end_time_hyde - start_time_hyde | |
| step_timings["hyde_generation_avg_time_per_query"] = (end_time_hyde - start_time_hyde) / len(questions) | |
| # 5. Concurrently perform initial hybrid search to get candidates for ALL queries | |
| start_time_search = time.perf_counter() | |
| candidate_retrieval_tasks = [ | |
| hybrid_search_manager.retrieve_candidates(q, hyde_doc) | |
| for q, hyde_doc in zip(questions, all_hyde_docs) | |
| ] | |
| all_candidates = await asyncio.gather(*candidate_retrieval_tasks) | |
| end_time_search = time.perf_counter() | |
| step_timings["candidate_retrieval_total_time"] = end_time_search - start_time_search | |
| # 6. Concurrently rerank the candidates for ALL queries | |
| start_time_rerank = time.perf_counter() | |
| rerank_tasks = [ | |
| hybrid_search_manager.rerank_results(q, candidates, TOP_K_CHUNKS) | |
| for q, candidates in zip(questions, all_candidates) | |
| ] | |
| reranked_results_and_times = await asyncio.gather(*rerank_tasks) | |
| end_time_rerank = time.perf_counter() | |
| step_timings["reranking_total_time"] = end_time_rerank - start_time_rerank | |
| # Unpack reranked results and timings | |
| all_retrieved_results = [item[0] for item in reranked_results_and_times] | |
| all_rerank_times = [item[1] for item in reranked_results_and_times] | |
| step_timings["reranking_avg_time_per_query"] = (end_time_rerank - start_time_rerank) / len(questions) | |
| # 7. Concurrently generate final answers | |
| start_time_generation = time.perf_counter() | |
| generation_tasks = [] | |
| for question, retrieved_results in zip(questions, all_retrieved_results): | |
| if retrieved_results: | |
| generation_tasks.append( | |
| generate_answer_with_groq( | |
| question, retrieved_results, GROQ_API_KEY | |
| ) | |
| ) | |
| else: | |
| no_info_future = asyncio.Future() | |
| no_info_future.set_result( | |
| "No relevant information found in the document to answer this question." | |
| ) | |
| generation_tasks.append(no_info_future) | |
| all_answer_texts = await asyncio.gather(*generation_tasks) | |
| end_time_generation = time.perf_counter() | |
| step_timings["generation_total_time"] = end_time_generation - start_time_generation | |
| step_timings["generation_avg_time_per_query"] = (end_time_generation - start_time_generation) / len(questions) | |
| end_time_total = time.perf_counter() | |
| total_processing_time = end_time_total - start_time_total | |
| step_timings["total_processing_time"] = total_processing_time | |
| print("All questions processed.") | |
| all_answers = [answer_text for answer_text in all_answer_texts] | |
| return RunResponse( | |
| answers=all_answers, | |
| #step_timings=step_timings, | |
| #hypothetical_documents=all_hyde_docs | |
| ) | |
| except HTTPException as e: | |
| raise e | |
| except Exception as e: | |
| print(f"An unhandled error occurred: {e}") | |
| raise HTTPException( | |
| status_code=500, detail=f"An internal server error occurred: {e}" | |
| ) | |
| finally: | |
| if local_markdown_path and os.path.exists(local_markdown_path): | |
| os.unlink(local_markdown_path) | |
| print(f"Cleaned up temporary markdown file: {local_markdown_path}") | |