from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import logging
import sys
from dotenv import load_dotenv
from .config import DATASET_CONFIGS, load_prompt_template
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
import json
# Load environment variables
load_dotenv()
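# Environment variables read by this module (values below are placeholders, not real keys):
#   OPENROUTER_API_KEY=sk-or-...   # required; used to create the OpenRouter chat client
#   GOOGLE_API_KEY=...             # optional here; only logged as present/absent
#   PORT=8000                      # usually set by the hosting platform; 8000 is the assumed default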
# Lazy imports to avoid blocking startup
# from .pipeline import RAGPipeline # Will import when needed
# import umap # Will import when needed for visualization
# import plotly.express as px # Will import when needed for visualization
# import plotly.graph_objects as go # Will import when needed for visualization
# from plotly.subplots import make_subplots # Will import when needed for visualization
# import numpy as np # Will import when needed for visualization
# from sklearn.preprocessing import normalize # Will import when needed for visualization
# import pandas as pd # Will import when needed for visualization
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
app = FastAPI(title="RAG Pipeline API", description="Multi-dataset RAG API", version="1.0.0")
# Initialize OpenRouter client
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not openrouter_api_key:
    raise ValueError("OPENROUTER_API_KEY environment variable is not set")
openrouter_client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=openrouter_api_key
)
# Model configuration
MODEL_NAME = "z-ai/glm-4.5-air:free"
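# Note: OpenRouter model IDs use the "<org>/<model>[:variant]" form; the ":free" suffix
# here requests the free-tier variant of GLM-4.5 Air (assumption based on OpenRouter naming).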
# Initialize pipelines for all datasets
pipelines = {}
google_api_key = os.getenv("GOOGLE_API_KEY")
logger.info(f"Starting RAG Pipeline API")
logger.info(f"Port from env: {os.getenv('PORT', 'Not set - will use 8000')}")
logger.info(f"Google API Key present: {'Yes' if google_api_key else 'No'}")
logger.info(f"Available datasets: {list(DATASET_CONFIGS.keys())}")
# Define tools for the GLM model
def rag_qa(question: str, dataset: str = "developer-portfolio") -> str:
"""
Get answers from the RAG pipeline for specific questions about the dataset.
Args:
question: The question to answer using the RAG pipeline
dataset: The dataset to search in (default: developer-portfolio)
Returns:
Answer from the RAG pipeline
"""
try:
# Check if pipelines are loaded
if not pipelines:
return "RAG Pipeline is running but datasets are still loading in the background. Please try again in a moment."
# Select the appropriate pipeline based on dataset
if dataset not in pipelines:
return f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}"
selected_pipeline = pipelines[dataset]
answer = selected_pipeline.answer_question(question)
return answer
except Exception as e:
return f"Error accessing RAG pipeline: {str(e)}"
# Tool definitions for GLM
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "rag_qa",
            "description": "Get answers from the RAG pipeline for specific questions about datasets",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "The question to answer using the RAG pipeline"
                    },
                    "dataset": {
                        "type": "string",
                        "description": "The dataset to search in (default: developer-portfolio)",
                        "default": "developer-portfolio"
                    }
                },
                "required": ["question"]
            }
        }
    }
]
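# Illustrative only (not captured from a real run): the arguments string the model is expected
# to emit for a rag_qa tool call, matching the JSON schema above:
#   {"question": "What projects are listed in the portfolio?", "dataset": "developer-portfolio"}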
# Don't load datasets during startup - do it asynchronously after server starts
logger.info("RAG Pipeline API is ready to serve requests - datasets will load in background")
# Visualization function disabled to speed up startup
# def create_3d_visualization(pipeline):
# ... (commented out for faster startup)
class Question(BaseModel):
    text: str
    dataset: str = "developer-portfolio"  # Default dataset


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: list[ChatMessage]
    dataset: str = "developer-portfolio"  # Default dataset
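# Example /chat request body (illustrative values):
#   {
#     "messages": [{"role": "user", "content": "What does the developer work on?"}],
#     "dataset": "developer-portfolio"
#   }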
@app.post("/chat")
async def chat_with_ai(request: ChatRequest):
"""
Chat with the AI assistant. The AI will use the RAG pipeline when needed to answer questions about the datasets.
"""
try:
# Convert messages to OpenAI format with proper typing
messages: list[ChatCompletionMessageParam] = [
{"role": msg.role, "content": msg.content} # type: ignore
for msg in request.messages
]
# Add system message to guide the AI
if request.dataset == "developer-portfolio":
system_message: ChatCompletionMessageParam = {
"role": "system",
"content": load_prompt_template("system-instruction.txt")
}
else:
system_message: ChatCompletionMessageParam = {
"role": "system",
"content": load_prompt_template("generic-system-instruction.txt")
}
messages.insert(0, system_message)
# Make the API call with tools
response = openrouter_client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
tools=TOOLS, # type: ignore
tool_choice="auto"
)
message = response.choices[0].message
finish_reason = response.choices[0].finish_reason
# Handle tool calls
if finish_reason == "tool_calls" and hasattr(message, 'tool_calls') and message.tool_calls:
tool_results = []
# Execute tool calls
for tool_call in message.tool_calls:
if tool_call.function.name == "rag_qa":
# Parse arguments
args = json.loads(tool_call.function.arguments)
question = args.get("question")
dataset = args.get("dataset", request.dataset)
# Call the rag_qa function
result = rag_qa(question, dataset)
tool_results.append({
"tool_call_id": tool_call.id,
"result": result
})
# Add tool results to conversation and get final response
assistant_message: ChatCompletionMessageParam = {
"role": "assistant",
"content": message.content or "",
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
}
for tc in message.tool_calls
]
}
messages.append(assistant_message)
for tool_result in tool_results:
tool_message: ChatCompletionMessageParam = {
"role": "tool",
"tool_call_id": tool_result["tool_call_id"],
"content": tool_result["result"]
}
messages.append(tool_message)
# Get final response
final_response = openrouter_client.chat.completions.create(
model=MODEL_NAME,
messages=messages
)
return {
"response": final_response.choices[0].message.content,
"tool_calls": [
{
"name": tc.function.name,
"arguments": tc.function.arguments
}
for tc in message.tool_calls
]
}
else:
# Direct response without tool calls
return {
"response": message.content,
"tool_calls": None
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
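# Response shape (illustrative): when a tool was used the endpoint returns
#   {"response": "...", "tool_calls": [{"name": "rag_qa", "arguments": "{...}"}]}
# otherwise {"response": "...", "tool_calls": null} in the serialized JSON.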
# /answer endpoint removed - use /chat for all interactions
@app.get("/datasets")
async def list_datasets():
"""List all available datasets"""
return {"datasets": list(pipelines.keys())}
@app.get("/questions")
async def list_questions(dataset: str = "developer-portfolio"):
"""List all questions for a given dataset"""
if dataset not in pipelines:
raise HTTPException(status_code=400, detail=f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}")
selected_pipeline = pipelines[dataset]
questions = [doc.meta['question'] for doc in selected_pipeline.documents if 'question' in doc.meta]
return {"dataset": dataset, "questions": questions}
async def load_datasets_background():
"""Load datasets in background after server starts"""
global pipelines
# Import RAGPipeline only when needed
from .pipeline import RAGPipeline
# Only load developer-portfolio to save memory
dataset_name = "developer-portfolio"
try:
logger.info(f"Loading dataset: {dataset_name}")
pipeline = RAGPipeline.from_preset(preset_name=dataset_name)
pipelines[dataset_name] = pipeline
logger.info(f"Successfully loaded {dataset_name}")
except Exception as e:
logger.error(f"Failed to load {dataset_name}: {e}")
logger.info(f"Background loading complete - {len(pipelines)} datasets loaded")
@app.on_event("startup")
async def startup_event():
logger.info("FastAPI application startup complete")
logger.info(f"Server should be running on port: {os.getenv('PORT', '8000')}")
# Start loading datasets in background (non-blocking)
import asyncio
asyncio.create_task(load_datasets_background())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("FastAPI application shutting down")
@app.get("/")
async def root():
"""Root endpoint"""
return {"status": "ok", "message": "RAG Pipeline API", "version": "1.0.0", "datasets": list(pipelines.keys())}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
logger.info("Health check called")
loading_status = "complete" if "developer-portfolio" in pipelines else "loading"
return {
"status": "healthy",
"datasets_loaded": len(pipelines),
"total_datasets": 1, # Only loading developer-portfolio
"loading_status": loading_status,
"port": os.getenv('PORT', '8000')
}
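# Local run sketch (assumptions: this file is importable as a package module, e.g. `app.main`,
# and uvicorn is installed; adjust the module path to your project layout):
#   uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-8000}
# Quick check once the server is up:
#   curl http://localhost:8000/health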