| | """FastAPI backend service for RAG application.""" |
| | from fastapi import FastAPI, HTTPException, BackgroundTasks |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel, Field |
| | from typing import List, Optional, Dict |
| | import uvicorn |
| | from datetime import datetime |
| | import os |
| |
|
| | from config import settings |
| | from dataset_loader import RAGBenchLoader |
| | from vector_store import ChromaDBManager |
| | from llm_client import GroqLLMClient, RAGPipeline |
| | from trace_evaluator import TRACEEvaluator |
| |
|
| | |
# Application object; title/description/version are published in the
# generated OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.0.0",
    title="RAG Capstone API",
    description="API for RAG system with TRACE evaluation",
)

# Wide-open CORS policy: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs state that allow_origins/methods/headers
# must not be ["*"] when allow_credentials=True; list explicit origins if
# cookie- or auth-bearing cross-origin requests are ever required.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Process-wide state shared by the endpoints in this module. /load-dataset
# populates all three; the collection endpoints create vector_store lazily.
rag_pipeline: Optional[RAGPipeline] = None
vector_store: Optional[ChromaDBManager] = None
current_collection: Optional[str] = None
| |
|
| |
|
| | |
class DatasetLoadRequest(BaseModel):
    """Request model for loading a dataset into a new vector collection.

    Numeric fields are bounded so obviously invalid values are rejected
    with a 422 validation error at the API boundary. Without the bounds
    they would reach the dataset loader and chunker.
    """
    dataset_name: str = Field(..., description="Name of the dataset")
    # At least one sample is needed to build a collection.
    num_samples: int = Field(50, gt=0, description="Number of samples to load")
    chunking_strategy: str = Field("hybrid", description="Chunking strategy")
    # A non-positive chunk size cannot produce any chunks.
    chunk_size: int = Field(512, gt=0, description="Size of chunks")
    # A negative overlap is meaningless. Whether the overlap is smaller than
    # chunk_size is a cross-field rule, checked by the /load-dataset endpoint.
    overlap: int = Field(50, ge=0, description="Overlap between chunks")
    embedding_model: str = Field(..., description="Embedding model name")
    llm_model: str = Field("llama-3.1-8b-instant", description="LLM model name")
    groq_api_key: str = Field(..., description="Groq API key")
| |
|
| |
|
class QueryRequest(BaseModel):
    """Request model for querying the active RAG pipeline.

    Generation/retrieval knobs are bounded so invalid values are rejected
    with a 422 instead of failing later inside the pipeline.
    """
    query: str = Field(..., description="User query")
    # Retrieving zero or a negative number of documents is not meaningful.
    n_results: int = Field(5, gt=0, description="Number of documents to retrieve")
    max_tokens: int = Field(1024, gt=0, description="Maximum tokens to generate")
    # Sampling temperature cannot be negative.
    temperature: float = Field(0.7, ge=0.0, description="Sampling temperature")
| |
|
| |
|
class QueryResponse(BaseModel):
    """Response model for query."""
    # The user's original question.
    query: str = Field(...)
    # Text generated by the LLM.
    response: str = Field(...)
    # Documents returned by the retriever, as plain dicts.
    retrieved_documents: List[Dict] = Field(...)
    # ISO-8601 time at which the API produced the answer.
    timestamp: str = Field(...)
| |
|
| |
|
class EvaluationRequest(BaseModel):
    """Request model for evaluation.

    At least one test sample is required; a zero or negative count would
    hand an empty (or invalid) batch to the evaluator.
    """
    num_samples: int = Field(10, gt=0, description="Number of test samples")
| |
|
| |
|
class CollectionInfo(BaseModel):
    """Collection information model."""
    # Collection identifier in the vector store.
    name: str = Field(...)
    # Number of items stored in the collection.
    count: int = Field(...)
    # Arbitrary metadata attached to the collection.
    metadata: Dict = Field(...)
| |
|
| |
|
| | |
@app.get("/")
async def root():
    """Describe the API: its name, version and where the interactive docs live."""
    info = dict(message="RAG Capstone API", version="1.0.0", docs="/docs")
    return info
| |
|
| |
|
@app.get("/health")
async def health_check():
    """Liveness probe: report a healthy status plus the server's local time."""
    now_iso = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": now_iso}
| |
|
| |
|
@app.get("/datasets")
async def list_datasets():
    """Return the RAGBench datasets configured in ``settings``."""
    available = settings.ragbench_datasets
    return {"datasets": available}
| |
|
| |
|
@app.get("/models/embedding")
async def list_embedding_models():
    """Return the embedding models configured in ``settings``."""
    available = settings.embedding_models
    return {"embedding_models": available}
| |
|
| |
|
@app.get("/models/llm")
async def list_llm_models():
    """Return the LLM models configured in ``settings``."""
    available = settings.llm_models
    return {"llm_models": available}
| |
|
| |
|
@app.get("/chunking-strategies")
async def list_chunking_strategies():
    """Return the chunking strategies configured in ``settings``."""
    available = settings.chunking_strategies
    return {"chunking_strategies": available}
| |
|
| |
|
@app.get("/collections")
async def list_collections():
    """Return every vector-store collection together with how many there are.

    The shared ``vector_store`` is created on first use when it has not been
    initialised yet (i.e. is falsy).
    """
    global vector_store

    vector_store = vector_store or ChromaDBManager(settings.chroma_persist_directory)
    names = vector_store.list_collections()
    return {"collections": names, "count": len(names)}
| |
|
| |
|
@app.get("/collections/{collection_name}")
async def get_collection_info(collection_name: str):
    """Return the stats the vector store reports for ``collection_name``.

    The shared ``vector_store`` is created on first use. Any failure while
    fetching the stats is reported to the client as a 404.
    """
    global vector_store

    vector_store = vector_store or ChromaDBManager(settings.chroma_persist_directory)

    try:
        return vector_store.get_collection_stats(collection_name)
    except Exception as exc:
        raise HTTPException(status_code=404, detail=f"Collection not found: {str(exc)}")
| |
|
| |
|
@app.post("/load-dataset")
async def load_dataset(request: DatasetLoadRequest, background_tasks: BackgroundTasks):
    """Load a RAGBench dataset, index it into a vector collection and build
    the RAG pipeline that answers queries against it.

    The module-level ``vector_store``, ``rag_pipeline`` and
    ``current_collection`` are replaced only after every step has succeeded.
    A failed load therefore leaves the previously loaded pipeline intact.

    Args:
        request: Dataset, chunking, embedding and LLM settings.
        background_tasks: Accepted for interface compatibility; not used.

    Returns:
        A dict with ``status``, ``collection_name``, ``num_documents`` and
        ``message``.

    Raises:
        HTTPException: 400 if ``overlap`` is not smaller than ``chunk_size``
            or the dataset could not be loaded; 500 for any other failure.
    """
    global rag_pipeline, vector_store, current_collection

    # An overlap as large as the chunk itself cannot describe a valid
    # chunking; presumably it would stall or explode a sliding-window
    # chunker, so reject it up front.
    if request.overlap >= request.chunk_size:
        raise HTTPException(
            status_code=400,
            detail="overlap must be smaller than chunk_size",
        )

    try:
        loader = RAGBenchLoader()
        dataset = loader.load_dataset(
            request.dataset_name,
            split="train",
            max_samples=request.num_samples,
        )
        if not dataset:
            raise HTTPException(status_code=400, detail="Failed to load dataset")

        store = ChromaDBManager(settings.chroma_persist_directory)

        # Collection name: "<dataset>_<strategy>_<model>", where <model> is
        # the last path segment of the embedding model id; '-' and '.' are
        # normalised to '_'.
        model_tag = request.embedding_model.split("/")[-1]
        collection_name = f"{request.dataset_name}_{request.chunking_strategy}_{model_tag}"
        collection_name = collection_name.replace("-", "_").replace(".", "_")

        store.load_dataset_into_collection(
            collection_name=collection_name,
            embedding_model_name=request.embedding_model,
            chunking_strategy=request.chunking_strategy,
            dataset_data=dataset,
            chunk_size=request.chunk_size,
            overlap=request.overlap,
        )

        llm_client = GroqLLMClient(
            api_key=request.groq_api_key,
            model_name=request.llm_model,
            max_rpm=settings.groq_rpm_limit,
            rate_limit_delay=settings.rate_limit_delay,
        )
        pipeline = RAGPipeline(llm_client, store)
        num_documents = len(dataset)
    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) keep their status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading dataset: {str(e)}") from e

    # Publish the new state only once everything above succeeded.
    vector_store = store
    rag_pipeline = pipeline
    current_collection = collection_name

    return {
        "status": "success",
        "collection_name": collection_name,
        "num_documents": num_documents,
        "message": f"Collection '{collection_name}' created successfully",
    }
| |
|
| |
|
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Answer ``request.query`` with the active RAG pipeline.

    The pipeline's result dict is stamped with the current ISO time under
    ``"timestamp"`` before being returned.

    Raises:
        HTTPException: 400 when no dataset has been loaded yet; 500 when the
            pipeline call fails.
    """
    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first.",
        )

    call_kwargs = dict(
        query=request.query,
        n_results=request.n_results,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
    )

    try:
        answer = rag_pipeline.query(**call_kwargs)
        answer["timestamp"] = datetime.now().isoformat()
        return answer
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(exc)}")
| |
|
| |
|
@app.get("/chat-history")
async def get_chat_history():
    """Return the conversation history kept by the active RAG pipeline.

    Raises:
        HTTPException: 400 when no dataset has been loaded yet.
    """
    if rag_pipeline:
        return {"history": rag_pipeline.get_chat_history()}

    raise HTTPException(
        status_code=400,
        detail="RAG pipeline not initialized. Load a dataset first.",
    )
| |
|
| |
|
@app.delete("/chat-history")
async def clear_chat_history():
    """Wipe the conversation history kept by the active RAG pipeline.

    Raises:
        HTTPException: 400 when no dataset has been loaded yet.
    """
    if rag_pipeline:
        rag_pipeline.clear_history()
        return {"status": "success", "message": "Chat history cleared"}

    raise HTTPException(
        status_code=400,
        detail="RAG pipeline not initialized. Load a dataset first.",
    )
| |
|
| |
|
@app.post("/evaluate")
async def run_evaluation(request: EvaluationRequest):
    """Run a TRACE evaluation of the active RAG pipeline.

    The endpoint works in three steps:
    1. Fetch ``request.num_samples`` test samples for the dataset behind the
       current collection.
    2. Answer each sample's question with the pipeline, retrieving the top 5
       documents.
    3. Score the resulting cases with ``TRACEEvaluator.evaluate_batch``.

    Returns:
        ``{"status": "success", "results": <evaluate_batch output>}``.

    Raises:
        HTTPException: 400 if no dataset has been loaded; 500 if fetching
            test data, querying or scoring fails.
    """
    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first.",
        )

    # /load-dataset names collections "<dataset>_<strategy>_<model>", so the
    # dataset is taken as the text before the first underscore.
    # NOTE(review): this misreads dataset names containing '_' or '-' (the
    # latter are rewritten to '_' when the name is built); tracking the
    # loaded dataset name explicitly would be more robust.
    dataset_name = current_collection.split("_")[0] if current_collection else "hotpotqa"

    try:
        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, request.num_samples)

        test_cases = []
        for sample in test_data:
            question = sample["question"]
            result = rag_pipeline.query(question, n_results=5)
            test_cases.append({
                "query": question,
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", ""),
            })

        results = TRACEEvaluator().evaluate_batch(test_cases)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during evaluation: {str(e)}") from e

    return {
        "status": "success",
        "results": results,
    }
| |
|
| |
|
@app.delete("/collections/{collection_name}")
async def delete_collection(collection_name: str):
    """Delete ``collection_name`` from the vector store.

    If the deleted collection is the one the active RAG pipeline was built
    on, the pipeline and ``current_collection`` are cleared. The other
    endpoints then report "not initialized" instead of keeping a pipeline
    that would presumably query a removed collection.

    Raises:
        HTTPException: 500 if the vector store fails to delete the collection.
    """
    global vector_store, rag_pipeline, current_collection

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        vector_store.delete_collection(collection_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}") from e

    # Do not leave state pointing at a collection that no longer exists.
    if collection_name == current_collection:
        rag_pipeline = None
        current_collection = None

    return {
        "status": "success",
        "message": f"Collection '{collection_name}' deleted",
    }
| |
|
| |
|
@app.get("/current-collection")
async def get_current_collection():
    """Describe the collection the active pipeline uses, if any.

    Returns a ``{"collection": None, ...}`` placeholder when nothing has been
    loaded. Otherwise returns the collection name plus its vector-store stats.

    Raises:
        HTTPException: 500 if fetching the stats fails.
    """
    if not current_collection:
        return {"collection": None, "message": "No collection loaded"}

    try:
        info = vector_store.get_collection_stats(current_collection)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(exc)}")

    return {"collection": current_collection, "stats": info}
| |
|
| |
|
if __name__ == "__main__":
    # Serve ``app`` on all interfaces, port 8000, with auto-reload.
    # NOTE(review): the "api:app" import string assumes this file is api.py.
    uvicorn.run("api:app", port=8000, host="0.0.0.0", reload=True)
| |
|