# File: main.py
# (Modified to load embedding model at startup and await async pipeline run)

import os
import tempfile
import asyncio
import time
from typing import List, Dict

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from groq import AsyncGroq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch  # used at startup to check for CUDA availability

from dotenv import load_dotenv

load_dotenv()

# Import the Pipeline class from the previous file
from pipeline_2 import Pipeline
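# As used below, the Pipeline class is expected to provide:
#   Pipeline(groq_api_key=..., pdf_path=..., embed_model=...)  - constructor
#   await pipeline.run()          - async parsing/indexing step
#   pipeline.retrieve_nodes(q)    - synchronous retrieval returning a list of
#       dicts with "content" and "document_metadata" keys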

# FastAPI application setup
app = FastAPI(
    title="Llama-Index RAG with Groq",
    description="An API to process a PDF from a URL and answer a list of questions using a Llama-Index RAG pipeline.",
)

# --- Pydantic Models for API Request and Response ---
class RunRequest(BaseModel):
    documents: HttpUrl
    questions: List[str]

class Answer(BaseModel):
    question: str
    answer: str

class RunResponse(BaseModel):
    answers: List[str]
    # Optional timing fields (currently disabled):
    # processing_time: float
    # step_timings: Dict[str, float]
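
# Illustrative request/response shapes for POST /hackrx/run (example values only):
#   Request:  {"documents": "https://example.com/document.pdf",
#              "questions": ["What is the coverage limit?"]}
#   Response: {"answers": ["The coverage limit is ..."]}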

# --- Global Configurations ---
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "gsk_...")
GROQ_MODEL_NAME = "llama3-70b-8192"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Global variable to hold the initialized embedding model
embed_model_instance = None

if GROQ_API_KEY == "gsk_...":
    print("WARNING: GROQ_API_KEY is not set. Please set it in your environment or main.py.")

@app.on_event("startup")
async def startup_event():
    """
    Loads the embedding model once when the application starts.
    This prevents re-loading it on every API call.
    """
    global embed_model_instance
    print(f"Loading embedding model '{EMBEDDING_MODEL_NAME}' at startup...")
    # Check for GPU availability and use it if possible
    # Assuming 16GB VRAM, a standard device check is sufficient
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    embed_model_instance = await asyncio.to_thread(HuggingFaceEmbedding, model_name=EMBEDDING_MODEL_NAME, device=device)
    print("Embedding model loaded successfully.")

# --- Async Groq Generation Function ---
async def generate_answer_with_groq(query: str, retrieved_results: List[dict], groq_api_key: str) -> str:
    """
    Generates an answer using the Groq API based on the query and retrieved chunks' content.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."

    client = AsyncGroq(api_key=groq_api_key)

    context_parts = []
    for i, res in enumerate(retrieved_results):
        content = res.get("content", "")
        metadata = res.get("document_metadata", {})
        
        section_heading = metadata.get("section_heading", metadata.get("file_name", "N/A"))
        
        context_parts.append(
            f"--- Context Chunk {i+1} ---\n"
            f"Document Part: {section_heading}\n"
            f"Content: {content}\n"
            f"-------------------------"
        )
    context = "\n\n".join(context_parts)

    prompt = (
        f"You are a specialized document analyzer assistant. Your task is to answer the user's question "
        f"solely based on the provided context. If the answer cannot be found in the provided context, "
        f"clearly state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."


# --- FastAPI Endpoint ---
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Runs the RAG pipeline for a given PDF document URL and a list of questions.
    The PDF is downloaded, processed, and then the questions are answered.
    """
    pdf_url = request.documents
    questions = request.questions
    local_pdf_path = None
    step_timings = {}

    start_time_total = time.perf_counter()

    if not embed_model_instance:
        raise HTTPException(
            status_code=500,
            detail="Embedding model not loaded. Application startup failed."
        )

    if not GROQ_API_KEY or GROQ_API_KEY == "gsk_...":
        raise HTTPException(
            status_code=500,
            detail="Groq API key is not configured. Please set the GROQ_API_KEY environment variable."
        )

    try:
        # 1. Download PDF
        start_time = time.perf_counter()
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(str(pdf_url), timeout=30.0, follow_redirects=True)
                response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
                doc_bytes = response.content
                print("Download successful.")
            except httpx.HTTPStatusError as e:
                raise HTTPException(status_code=e.response.status_code, detail=f"HTTP error downloading PDF: {e.response.status_code} - {e.response.text}")
            except httpx.RequestError as e:
                raise HTTPException(status_code=400, detail=f"Network error downloading PDF: {e}")
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"An unexpected error occurred during download: {e}")

        # Persist the downloaded bytes to a secure temporary file for the pipeline
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf_file:
            temp_pdf_file.write(doc_bytes)
            local_pdf_path = temp_pdf_file.name

        end_time = time.perf_counter()
        step_timings["download_pdf"] = end_time - start_time
        print(f"PDF download took {step_timings['download_pdf']:.2f} seconds.")

        # 2. Initialize and Run the Pipeline (Parsing, Node Creation, Embeddings)
        start_time = time.perf_counter()
        # The Pipeline's run() method is now async, so await it directly
        pipeline = Pipeline(groq_api_key=GROQ_API_KEY, pdf_path=local_pdf_path, embed_model=embed_model_instance)
        await pipeline.run() # Changed from asyncio.to_thread(pipeline.run)
        end_time = time.perf_counter()
        step_timings["pipeline_setup"] = end_time - start_time
        print(f"Pipeline setup took {step_timings['pipeline_setup']:.2f} seconds.")

        # 3. Concurrent Retrieval Phase
        start_time_retrieval = time.perf_counter()
        print(f"\nStarting concurrent retrieval for {len(questions)} questions...")
        
        retrieval_tasks = [asyncio.to_thread(pipeline.retrieve_nodes, q) for q in questions]
        all_retrieved_results = await asyncio.gather(*retrieval_tasks)
        
        end_time_retrieval = time.perf_counter()
        step_timings["retrieval"] = end_time_retrieval - start_time_retrieval
        print(f"Retrieval phase completed in {step_timings['retrieval']:.2f} seconds.")

        # 4. Concurrent Generation Phase
        start_time_generation = time.perf_counter()
        print(f"\nStarting concurrent answer generation for {len(questions)} questions...")
        
        generation_tasks = [
            generate_answer_with_groq(q, retrieved_results, GROQ_API_KEY)
            for q, retrieved_results in zip(questions, all_retrieved_results)
        ]

        all_answer_texts = await asyncio.gather(*generation_tasks)
        
        end_time_generation = time.perf_counter()
        step_timings["generation"] = end_time_generation - start_time_generation
        print(f"Generation phase completed in {step_timings['generation']:.2f} seconds.")

        total_processing_time = time.perf_counter() - start_time_total
        print(f"Total request processing time: {total_processing_time:.2f} seconds.")

        answers = list(all_answer_texts)

        return RunResponse(
            answers=answers
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"An unhandled error occurred: {e}")
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        )
    finally:
        if local_pdf_path and os.path.exists(local_pdf_path):
            os.unlink(local_pdf_path)
            print(f"Cleaned up temporary PDF file: {local_pdf_path}")