SB-PoC / service /rag_service.py
Chirapath's picture
Upload 3 files
1628024 verified
#!/usr/bin/env python3
"""
RAG (Retrieval-Augmented Generation) Backend API - Cleaned and Optimized
Integrates OCR, Azure OpenAI embeddings, and PostgreSQL vector storage
"""
import os
import uuid
import asyncio
import requests
import json
import tempfile
import traceback
import logging
from typing import Optional, List, Dict, Any, Union
from datetime import datetime
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import uvicorn
# Import unified configuration
# Resolve the service settings: prefer the shared `configs` package and fall
# back to environment variables when that package cannot be imported.
try:
    from configs import get_config
    config = get_config().rag  # RAG-specific section of the unified configuration
    unified_config = get_config()  # only bound on this (successful import) path
    print("βœ… Using unified configuration")
except ImportError:
    print("⚠️ Unified config not available, using fallback configuration")
    from dotenv import load_dotenv
    # Must run before the class body below, which calls os.getenv while the
    # class is being defined.
    load_dotenv()

    class FallbackConfig:
        """Settings read from environment variables at class-definition time."""
        HOST = os.getenv("HOST", "0.0.0.0")
        PORT = int(os.getenv("RAG_PORT", "8401"))
        # Any value other than "true" (case-insensitive) disables debug mode.
        DEBUG = os.getenv("DEBUG", "True").lower() == "true"
        # OCR Service Configuration
        OCR_SERVICE_URL = os.getenv("OCR_SERVICE_URL", "http://localhost:8400")
        # PostgreSQL Configuration
        # NOTE: the database name and SSL mode use the PG_* prefix while the
        # other connection settings use POSTGRES_*.
        PG_HOST = os.getenv("POSTGRES_HOST", "")
        PG_PORT = int(os.getenv("POSTGRES_PORT", "5432"))
        PG_DATABASE = os.getenv("PG_DATABASE", "vectorsearch")
        PG_USER = os.getenv("POSTGRES_USER", "")
        PG_PASSWORD = os.getenv("POSTGRES_PASSWORD", "")
        PG_SSL_MODE = os.getenv("PG_SSL_MODE", "require")
        # Azure OpenAI Configuration
        AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "")
        AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
        AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT", "text-embedding-3-small")
        AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
        # Chunking Configuration (sizes are measured in characters by the chunker)
        CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1000"))
        CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "200"))
        MIN_CHUNK_SIZE = int(os.getenv("MIN_CHUNK_SIZE", "50"))
        # Processing limits
        MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
        REQUEST_TIMEOUT = 300  # passed as `timeout` to the OCR HTTP requests

    config = FallbackConfig()
import asyncpg
import numpy as np
from openai import AzureOpenAI
import re
from pathlib import Path
from urllib.parse import urlparse
# Configure logging
# Root logger: INFO level with timestamp, logger name and level in each record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
    title="RAG Backend API",
    description="Retrieval-Augmented Generation service with OCR, embeddings, and vector search",
    version="2.0.0",
    debug=config.DEBUG
)
# CORS configuration
# NOTE(review): every origin, method and header is allowed together with
# credentials — confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic Models
class DocumentUploadRequest(BaseModel):
    """Optional descriptive and chunking parameters for a document upload.

    NOTE(review): the upload endpoint in this file takes multipart form
    fields with the same names rather than this model — confirm where the
    model is used.
    """
    title: Optional[str] = None
    keywords: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None
    # Per-request chunking overrides; None means "not supplied".
    chunk_size: Optional[int] = None
    chunk_overlap: Optional[int] = None
class URLProcessRequest(BaseModel):
    """JSON body accepted by the `/documents/url` endpoint."""
    url: HttpUrl  # validated by pydantic; converted to str before use
    title: Optional[str] = None  # the endpoint derives a title from the URL host when omitted
    keywords: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None  # merged into the stored document metadata
    extract_images: bool = True  # forwarded to the OCR service
    # Per-request chunking overrides; None falls back to the configured defaults.
    chunk_size: Optional[int] = None
    chunk_overlap: Optional[int] = None
class SearchRequest(BaseModel):
    """JSON body accepted by the `/search` endpoint."""
    query: str  # free-text query; the endpoint rejects blank values
    limit: int = 10  # forwarded to the vector search as the SQL LIMIT
    # Lower bound on the similarity score (1 - `<=>` distance); values <= 0
    # skip the threshold condition entirely.
    similarity_threshold: float = 0.2
    # Key/value pairs compared as text against the document metadata.
    filter_metadata: Optional[Dict[str, Any]] = None
class DocumentChunk(BaseModel):
    """One stored chunk of a document, as returned in search results."""
    id: str
    document_id: str  # id of the owning document
    content: str
    chunk_index: int  # position of the chunk within its document
    embedding: Optional[List[float]] = None  # left unset by the search endpoint
    metadata: Dict[str, Any]
    created_at: datetime
class DocumentResponse(BaseModel):
    """Document-level record; fields mirror the columns of the `documents` table."""
    id: str
    title: str
    source_type: str
    source_url: Optional[str]
    total_chunks: int
    keywords: List[str]
    metadata: Dict[str, Any]
    created_at: datetime
    processing_status: str  # the endpoints write "processing", "completed" or "failed"
class SearchResult(BaseModel):
    """A single search hit: the matching chunk, its score and its document."""
    chunk: DocumentChunk
    similarity_score: float  # computed in SQL as 1 - (embedding <=> query vector)
    # Built by the search endpoint with the keys title, source_type,
    # source_url, keywords and metadata.
    document_info: Dict[str, Any]
class SearchResponse(BaseModel):
    """Response body of the `/search` endpoint."""
    query: str  # the query exactly as received
    results: List[SearchResult]
    total_results: int  # number of entries in `results`
    processing_time: float  # wall-clock seconds measured with time.time()
# Database connection pool: created lazily by get_db_pool() and closed by the
# shutdown hook.
db_pool = None
# UUID generation method cache: filled in by detect_uuid_method() on its first
# call with one of "built-in", "uuid-ossp" or "python".
_uuid_method = None
async def detect_uuid_method(conn) -> str:
    """Work out how UUIDs should be generated and remember the answer.

    Probes, in order: the server's built-in ``gen_random_uuid()``, the
    ``uuid-ossp`` extension, and finally settles on Python-side generation.
    The first result is stored in the module-level ``_uuid_method`` and
    returned directly on every later call.

    Args:
        conn: Connection object exposing awaitable ``fetchval``/``execute``.

    Returns:
        One of ``"built-in"``, ``"uuid-ossp"`` or ``"python"``.
    """
    if _uuid_method is not None:
        return _uuid_method

    def _remember(method: str, message: str) -> str:
        # Store the decision in the module cache and announce it.
        global _uuid_method
        _uuid_method = method
        logger.info(message)
        return method

    # Probe 1: built-in gen_random_uuid() (PostgreSQL 13+)
    try:
        await conn.fetchval("SELECT gen_random_uuid()")
        return _remember("built-in", "Using built-in gen_random_uuid() for UUID generation")
    except Exception:
        pass

    # Probe 2: uuid-ossp extension
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
        await conn.fetchval("SELECT uuid_generate_v4()")
        return _remember("uuid-ossp", "Using uuid-ossp extension for UUID generation")
    except Exception as probe_error:
        reason = str(probe_error)
        if "not allow-listed" in reason or "not allowlisted" in reason.lower():
            logger.info("uuid-ossp extension not allowlisted (normal for Azure PostgreSQL)")
        else:
            logger.warning(f"uuid-ossp extension not available: {probe_error}")

    # Last resort: generate UUIDs in Python
    return _remember("python", "Using Python-generated UUIDs")
async def get_db_pool():
    """Return the shared asyncpg pool, creating it on first use.

    Raises:
        Exception: whatever ``asyncpg.create_pool`` raised, after logging it.
    """
    global db_pool

    # Fast path: the pool already exists.
    if db_pool is not None:
        return db_pool

    try:
        logger.info(f"Creating database pool with host: {config.PG_HOST}:{config.PG_PORT}")
        pool_options = dict(
            host=config.PG_HOST,
            port=config.PG_PORT,
            database=config.PG_DATABASE,
            user=config.PG_USER,
            password=config.PG_PASSWORD,
            ssl=config.PG_SSL_MODE,
            min_size=1,
            max_size=10,
            command_timeout=60,
        )
        db_pool = await asyncpg.create_pool(**pool_options)
    except Exception as pool_error:
        logger.error(f"Failed to create database pool: {pool_error}")
        raise

    return db_pool
async def get_db_connection():
    """Acquire a connection from the shared pool.

    The caller is responsible for handing it back via
    ``release_db_connection``.
    """
    shared_pool = await get_db_pool()
    connection = await shared_pool.acquire()
    return connection
async def release_db_connection(connection):
    """Hand a previously acquired connection back to the shared pool."""
    shared_pool = await get_db_pool()
    await shared_pool.release(connection)
# Azure OpenAI Client
def get_openai_client():
    """Build an ``AzureOpenAI`` client from the configured credentials.

    Raises:
        HTTPException: 500 when the endpoint or API key is empty or still
            set to its ``YOUR_...`` placeholder value.
    """
    endpoint = config.AZURE_OPENAI_ENDPOINT
    api_key = config.AZURE_OPENAI_API_KEY

    endpoint_unset = endpoint in ("", "YOUR_AZURE_OPENAI_ENDPOINT")
    api_key_unset = api_key in ("", "YOUR_AZURE_OPENAI_KEY")
    if endpoint_unset or api_key_unset:
        raise HTTPException(
            status_code=500,
            detail="Azure OpenAI credentials not configured"
        )

    client = AzureOpenAI(
        api_version=config.AZURE_OPENAI_API_VERSION,
        azure_endpoint=endpoint,
        api_key=api_key
    )
    return client
# Text Processing Functions
def clean_text(text: str) -> str:
    """Clean and normalize text before chunking.

    Removes characters that are neither word characters, whitespace nor
    basic punctuation (``. , ! ? ; : - ( )``), then collapses every run of
    whitespace into a single space and trims the ends.

    Unicode combining marks are kept: ``\\w`` does not match them, so
    without this exception decomposed accents and the vowel/tone marks of
    scripts such as Thai would be stripped and the words garbled.

    Args:
        text: Raw text, for example OCR output.

    Returns:
        The cleaned, single-spaced text (possibly empty).
    """
    import unicodedata

    def _drop_unless_mark(match):
        # Combining marks (categories Mn/Mc/Me) belong to the preceding letter.
        char = match.group(0)
        return char if unicodedata.category(char).startswith('M') else ''

    # Remove special characters but keep basic punctuation
    text = re.sub(r'[^\w\s\.\,\!\?\;\:\-\(\)]', _drop_unless_mark, text)
    # Collapse whitespace last so removed characters cannot leave double spaces
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def embedding_to_vector_string(embedding: List[float]) -> str:
    """Render an embedding as a PostgreSQL vector literal.

    Example: ``[1, 2.5]`` becomes ``'[1.0,2.5]'``.

    Raises:
        ValueError: if the embedding is missing or has no components.
    """
    if not embedding:
        raise ValueError("Embedding cannot be empty")

    components = [str(float(component)) for component in embedding]
    return "[{}]".format(",".join(components))
def create_text_chunks(text: str, chunk_size: int = None, chunk_overlap: int = None) -> List[str]:
    """Split text into overlapping chunks, preferring sentence boundaries.

    Args:
        text: The text to split.
        chunk_size: Maximum chunk length in characters; defaults to
            ``config.CHUNK_SIZE``.
        chunk_overlap: Number of characters shared between consecutive
            chunks; defaults to ``config.CHUNK_OVERLAP``.

    Returns:
        ``[text]`` unchanged when it already fits in one chunk; otherwise
        the stripped chunks that are at least ``config.MIN_CHUNK_SIZE``
        characters long.

    Raises:
        ValueError: when the text needs splitting and ``chunk_size`` is not
            positive, or ``chunk_overlap`` is negative or not smaller than
            ``chunk_size`` (either would keep the window from advancing).
    """
    if chunk_size is None:
        chunk_size = config.CHUNK_SIZE
    if chunk_overlap is None:
        chunk_overlap = config.CHUNK_OVERLAP

    text_length = len(text)
    if text_length <= chunk_size:
        return [text]

    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if chunk_overlap < 0 or chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be non-negative and smaller than chunk_size")

    min_chunk_size = config.MIN_CHUNK_SIZE
    sentence_endings = ('. ', '! ', '? ', '\n\n')
    chunks = []
    start = 0
    while start < text_length:
        end = start + chunk_size
        # Try to break at sentence boundary
        if end < text_length:
            # A boundary is only usable if the next window (end - overlap)
            # still starts after the current one; otherwise the loop would
            # stall or move backwards.
            earliest_end = start + chunk_overlap + 1
            for ending in sentence_endings:
                last_ending = text.rfind(ending, start, end)
                if last_ending == -1:
                    continue
                boundary = last_ending + len(ending)
                if boundary >= earliest_end:
                    end = boundary
                    break
        chunk = text[start:end].strip()
        if len(chunk) >= min_chunk_size:
            chunks.append(chunk)
        # Once the window reaches the end of the text we are done; another
        # pass would only re-emit a tail already contained in this chunk.
        if end >= text_length:
            break
        # Calculate next start position with overlap
        start = end - chunk_overlap
    return chunks
async def generate_embedding(text: str) -> List[float]:
    """Generate an embedding vector for ``text`` using Azure OpenAI.

    The SDK call is synchronous, so it is run in the default executor to
    keep the event loop free to serve other requests while waiting.

    Args:
        text: Text to embed; inputs longer than 8000 characters are
            truncated.

    Returns:
        The embedding as a list of floats.

    Raises:
        HTTPException: 500 when the text is empty, the client is not
            configured, or the embedding request fails or returns no data.
    """
    try:
        if not text or not text.strip():
            raise ValueError("Text cannot be empty")
        # Truncate text if it's too long
        if len(text) > 8000:
            text = text[:8000]
            logger.warning("Truncated text for embedding generation")
        payload = text.strip()
        client = get_openai_client()
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: client.embeddings.create(
                input=[payload],
                model=config.AZURE_OPENAI_DEPLOYMENT
            )
        )
        if not response.data:
            raise ValueError("No embedding data returned from Azure OpenAI")
        embedding = response.data[0].embedding
        if not embedding:
            raise ValueError("Empty embedding returned from Azure OpenAI")
        logger.debug(f"Generated embedding with {len(embedding)} dimensions")
        return embedding
    except HTTPException as http_error:
        # Already a well-formed API error (e.g. missing credentials); pass
        # it on unchanged instead of nesting it inside another one.
        logger.error(f"Failed to generate embedding: {http_error.detail}")
        raise
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        logger.error(f"Text length: {len(text) if text else 0}")
        raise HTTPException(status_code=500, detail=f"Embedding generation failed: {e}")
# OCR Integration
async def process_with_ocr(file_bytes: bytes = None, url: str = None, extract_images: bool = True, filename: str = None) -> Dict[str, Any]:
    """Extract text from an uploaded file or a URL via the OCR service.

    Plain-text uploads (``.txt``, ``.md``, ``.rst``, ``.log``) that decode
    as UTF-8 are returned directly without contacting the OCR service.
    Everything else is posted to ``{OCR_SERVICE_URL}/ocr/analyze``. The
    blocking HTTP call runs in the default executor so the event loop is
    not stalled, and the upload is sent straight from memory.

    Args:
        file_bytes: Raw content of an uploaded file.
        url: URL for the OCR service to fetch; used when ``file_bytes`` is
            empty.
        extract_images: Forwarded to the OCR service as ``"true"``/``"false"``.
        filename: Original file name, used for plain-text detection and as
            the multipart file name.

    Returns:
        The OCR service's JSON response, or an equivalent dict for
        plain-text files.

    Raises:
        HTTPException: 500 when neither input is given, the OCR service is
            unreachable, or it answers with a non-200 status.
    """
    try:
        logger.info(f"Processing with OCR service at {config.OCR_SERVICE_URL}")
        endpoint = f"{config.OCR_SERVICE_URL}/ocr/analyze"
        loop = asyncio.get_running_loop()

        if file_bytes:
            # For plain text files, bypass OCR
            text_extensions = ('.txt', '.md', '.rst', '.log')
            if filename and filename.lower().endswith(text_extensions):
                try:
                    content = file_bytes.decode('utf-8')
                except UnicodeDecodeError:
                    logger.warning(f"Failed to decode {filename} as UTF-8, sending to OCR service")
                else:
                    logger.info(f"Processing plain text file directly: {filename}")
                    if len(content.strip()) < config.MIN_CHUNK_SIZE:
                        logger.info(f"Text file {filename} is short ({len(content)} chars) but will process anyway")
                    return {
                        'success': True,
                        'content': content,
                        'pages': [{
                            'page_number': 1,
                            'content_type': 'text',
                            'text_content': content,
                            'source': 'direct_text',
                            'character_count': len(content)
                        }],
                        'source_type': 'text_file',
                        'source_url': None,
                        'error': None
                    }

            # Use OCR service
            logger.info(f"Uploading file for OCR processing ({len(file_bytes)} bytes)")
            files = {
                'file': (filename or 'document.pdf', file_bytes, 'application/octet-stream')
            }
            data = {
                'extract_images': str(extract_images).lower()
            }
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(
                    endpoint,
                    files=files,
                    data=data,
                    timeout=config.REQUEST_TIMEOUT
                )
            )
        elif url:
            # Process URL with OCR service
            logger.info(f"Processing URL for OCR: {url}")
            data = {
                'url': url,
                'extract_images': str(extract_images).lower()
            }
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(
                    endpoint,
                    data=data,
                    timeout=config.REQUEST_TIMEOUT
                )
            )
        else:
            raise ValueError("Either file_bytes or url must be provided")

        # Check response
        logger.info(f"OCR service response status: {response.status_code}")
        if response.status_code != 200:
            logger.error(f"OCR service error: {response.status_code} - {response.text}")
            raise HTTPException(
                status_code=500,
                detail=f"OCR processing failed: {response.status_code} {response.reason}"
            )
        result = response.json()
        logger.info(f"OCR processing completed successfully. Success: {result.get('success', False)}")
        return result
    except HTTPException:
        # Raised above with a precise message; do not wrap it a second time.
        raise
    except requests.RequestException as e:
        logger.error(f"OCR service request error: {e}")
        raise HTTPException(status_code=500, detail=f"OCR service connection failed: {e}")
    except Exception as e:
        logger.error(f"OCR processing error: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"OCR processing failed: {e}")
# UUID Generation Helper
async def generate_uuid(conn) -> str:
    """Produce a new UUID string using the detected generation method.

    Uses the matching SQL function when the detected method is
    ``"built-in"`` or ``"uuid-ossp"``, and ``uuid.uuid4()`` otherwise. Any
    failure along the way also falls back to ``uuid.uuid4()``.
    """
    sql_by_method = {
        "built-in": "SELECT gen_random_uuid()",
        "uuid-ossp": "SELECT uuid_generate_v4()",
    }
    try:
        method = await detect_uuid_method(conn)
        statement = sql_by_method.get(method)
        if statement is None:
            return str(uuid.uuid4())
        generated = await conn.fetchval(statement)
        return str(generated)
    except Exception as db_error:
        logger.warning(f"Database UUID generation failed, using Python fallback: {db_error}")
        return str(uuid.uuid4())
# Database Operations
async def create_document_record(
    title: str,
    source_type: str,
    source_url: str = None,
    keywords: List[str] = None,
    metadata: Dict[str, Any] = None
) -> str:
    """Insert a new row into ``documents`` with status ``"processing"``.

    Missing keywords are stored as an empty list and missing metadata as an
    empty JSON object.

    Returns:
        The id of the newly created document.
    """
    insert_sql = """
INSERT INTO documents (id, title, source_type, source_url, keywords, metadata, created_at, processing_status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"""
    conn = await get_db_connection()
    try:
        document_id = await generate_uuid(conn)
        row_values = (
            document_id,
            title,
            source_type,
            source_url,
            keywords or [],
            json.dumps(metadata or {}),
            datetime.utcnow(),
            "processing",
        )
        await conn.execute(insert_sql, *row_values)
        return document_id
    finally:
        # Always hand the connection back, even when the insert fails.
        await release_db_connection(conn)
async def store_document_chunk(
    document_id: str,
    content: str,
    chunk_index: int,
    embedding: List[float],
    metadata: Dict[str, Any] = None
) -> str:
    """Insert one chunk and its embedding into ``document_chunks``.

    The embedding is sent as a vector literal string and cast to ``vector``
    in SQL.

    Returns:
        The id of the newly stored chunk.
    """
    insert_sql = """
INSERT INTO document_chunks (id, document_id, content, chunk_index, embedding, metadata, created_at)
VALUES ($1, $2, $3, $4, $5::vector, $6, $7)
"""
    conn = await get_db_connection()
    try:
        chunk_id = await generate_uuid(conn)
        vector_literal = embedding_to_vector_string(embedding)
        row_values = (
            chunk_id,
            document_id,
            content,
            chunk_index,
            vector_literal,
            json.dumps(metadata or {}),
            datetime.utcnow(),
        )
        await conn.execute(insert_sql, *row_values)
        return chunk_id
    finally:
        # Always hand the connection back, even when the insert fails.
        await release_db_connection(conn)
async def update_document_status(document_id: str, status: str, total_chunks: int = None):
    """Set a document's ``processing_status`` and, optionally, its chunk count.

    ``total_chunks`` is only written when it is not ``None``.
    """
    if total_chunks is None:
        statement = """
UPDATE documents SET processing_status = $1 WHERE id = $2
"""
        arguments = (status, document_id)
    else:
        statement = """
UPDATE documents SET processing_status = $1, total_chunks = $2 WHERE id = $3
"""
        arguments = (status, total_chunks, document_id)

    conn = await get_db_connection()
    try:
        await conn.execute(statement, *arguments)
    finally:
        await release_db_connection(conn)
def _parse_json_metadata(raw: Any, description: str) -> Dict[str, Any]:
    """Decode a JSON column value, returning ``{}`` when it is empty or invalid."""
    if not raw:
        return {}
    if isinstance(raw, dict):
        # Already decoded (e.g. a JSON codec is registered on the connection).
        return raw
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError subclass.
        logger.warning(f"Invalid {description}")
        return {}


async def search_similar_chunks(
    query_embedding: List[float],
    limit: int = 10,
    similarity_threshold: float = 0.2,
    filter_metadata: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    """Search for similar document chunks using vector similarity.

    Only chunks with an embedding that belong to documents whose status is
    ``'completed'`` are considered. Similarity is ``1 - (embedding <=>
    query)``.

    Args:
        query_embedding: Embedding of the search query.
        limit: Maximum number of rows to return.
        similarity_threshold: Minimum similarity; values ``<= 0`` disable
            the threshold condition.
        filter_metadata: Key/value pairs that must all match the document
            metadata (values are compared as text).

    Returns:
        One dict per hit, ordered by descending similarity, with the keys
        ``chunk_id``, ``document_id``, ``content``, ``chunk_index``,
        ``chunk_metadata``, ``created_at``, ``document_title``,
        ``source_type``, ``source_url``, ``keywords``, ``document_metadata``
        and ``similarity_score``.

    Raises:
        HTTPException: 500 when the embedding is empty or a query fails.
    """
    conn = await get_db_connection()
    try:
        logger.info(f"Searching for similar chunks with threshold {similarity_threshold}, limit {limit}")
        # Validate inputs
        if not query_embedding or len(query_embedding) == 0:
            raise ValueError("Query embedding cannot be empty")
        logger.info(f"Query embedding dimensions: {len(query_embedding)}")
        # Convert query embedding to PostgreSQL vector format
        query_vector = embedding_to_vector_string(query_embedding)
        # Check if we have any chunks
        total_chunks = await conn.fetchval("""
SELECT COUNT(*) FROM document_chunks dc
JOIN documents d ON dc.document_id = d.id
WHERE d.processing_status = 'completed' AND dc.embedding IS NOT NULL
""")
        logger.info(f"Total available chunks for search: {total_chunks}")
        if total_chunks == 0:
            logger.warning("No chunks available for search")
            return []
        # Build the query. The stored embedding itself is not selected: no
        # caller reads it and it is by far the largest column.
        base_query = """
SELECT
dc.id, dc.document_id, dc.content, dc.chunk_index,
dc.metadata as chunk_metadata, dc.created_at,
d.title, d.source_type, d.source_url, d.keywords, d.metadata as doc_metadata,
1 - (dc.embedding <=> $1::vector) as similarity_score
FROM document_chunks dc
JOIN documents d ON dc.document_id = d.id
WHERE d.processing_status = 'completed'
AND dc.embedding IS NOT NULL
"""
        # Placeholders are numbered from the current length of `params`, so
        # each value is appended before its placeholder is written.
        params = [query_vector]
        # Add similarity threshold
        if similarity_threshold > 0:
            params.append(similarity_threshold)
            base_query += f" AND 1 - (dc.embedding <=> $1::vector) >= ${len(params)}"
        # Add metadata filtering (every supplied pair must match)
        if filter_metadata:
            for key, value in filter_metadata.items():
                params.extend([key, str(value)])
                base_query += f" AND d.metadata->>${len(params) - 1} = ${len(params)}"
        params.append(limit)
        base_query += f" ORDER BY similarity_score DESC LIMIT ${len(params)}"
        logger.info(f"Executing vector search query with {len(params)} parameters")
        try:
            rows = await conn.fetch(base_query, *params)
            logger.info(f"Vector search query returned {len(rows)} rows")
        except Exception as db_error:
            logger.error(f"Database query error: {db_error}")
            raise HTTPException(status_code=500, detail=f"Vector search query failed: {db_error}")
        # Debug: show similarity scores if no results
        if len(rows) == 0 and similarity_threshold > 0:
            logger.warning(f"No results found with threshold {similarity_threshold}, trying without threshold")
            debug_query = """
SELECT
dc.id, dc.content,
1 - (dc.embedding <=> $1::vector) as similarity_score
FROM document_chunks dc
JOIN documents d ON dc.document_id = d.id
WHERE d.processing_status = 'completed'
AND dc.embedding IS NOT NULL
ORDER BY similarity_score DESC
LIMIT 3
"""
            debug_rows = await conn.fetch(debug_query, query_vector)
            logger.info(f"Debug: Top 3 similarity scores: {[(r['similarity_score'], r['content'][:50]) for r in debug_rows]}")
        results = []
        for row in rows:
            try:
                # Safely parse JSON metadata
                chunk_metadata = _parse_json_metadata(
                    row['chunk_metadata'],
                    f"chunk metadata JSON for chunk {row['id']}"
                )
                doc_metadata = _parse_json_metadata(
                    row['doc_metadata'],
                    f"document metadata JSON for document {row['document_id']}"
                )
                # Convert UUID objects to strings
                chunk_id = str(row['id']) if row['id'] else None
                document_id = str(row['document_id']) if row['document_id'] else None
                results.append({
                    'chunk_id': chunk_id,
                    'document_id': document_id,
                    'content': row['content'],
                    'chunk_index': row['chunk_index'],
                    'chunk_metadata': chunk_metadata,
                    'created_at': row['created_at'],
                    'document_title': row['title'],
                    'source_type': row['source_type'],
                    'source_url': row['source_url'],
                    'keywords': row['keywords'] or [],
                    'document_metadata': doc_metadata,
                    'similarity_score': float(row['similarity_score'])
                })
            except Exception as row_error:
                # A single malformed row must not fail the whole search.
                logger.error(f"Error processing search result row: {row_error}")
                continue
        logger.info(f"Vector search returned {len(results)} results")
        if results:
            logger.info(f"Top result similarity: {results[0]['similarity_score']:.4f}")
        return results
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Vector search failed: {e}")
    finally:
        await release_db_connection(conn)
# Database initialization
async def init_database():
    """Create the extension, tables and indexes the service needs.

    Every statement is idempotent (``IF NOT EXISTS``). Enabling the
    ``vector`` extension and creating indexes are best-effort: failures are
    logged as warnings. Table creation failures propagate to the caller.
    """
    conn = await get_db_connection()
    try:
        logger.info("πŸ”„ Initializing database tables...")
        # The document_chunks table uses the `vector` column type, so make
        # sure the extension is enabled first. Managed services may require
        # it to be allow-listed, hence best-effort.
        try:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        except Exception as e:
            logger.warning(f"Could not enable vector extension (it may need to be allow-listed or installed by an administrator): {e}")
        # Create documents table
        await conn.execute("""
CREATE TABLE IF NOT EXISTS documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title VARCHAR(500) NOT NULL,
source_type VARCHAR(50) NOT NULL,
source_url TEXT,
keywords TEXT[] DEFAULT '{}',
metadata JSONB DEFAULT '{}',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processing_status VARCHAR(20) DEFAULT 'processing',
total_chunks INTEGER DEFAULT 0
);
""")
        # Create document_chunks table
        await conn.execute("""
CREATE TABLE IF NOT EXISTS document_chunks (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
content TEXT NOT NULL,
chunk_index INTEGER NOT NULL,
embedding vector(1536),
metadata JSONB DEFAULT '{}',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
        # Create indexes one statement at a time: sent together they run as
        # a single implicit transaction, so a failing vector index would
        # also discard the two ordinary indexes.
        index_statements = (
            "CREATE INDEX IF NOT EXISTS idx_documents_status ON documents(processing_status)",
            "CREATE INDEX IF NOT EXISTS idx_chunks_document ON document_chunks(document_id)",
            "CREATE INDEX IF NOT EXISTS idx_chunks_embedding ON document_chunks USING ivfflat (embedding vector_cosine_ops)",
        )
        for statement in index_statements:
            try:
                await conn.execute(statement)
            except Exception as e:
                logger.warning(f"Could not create some indexes (vector extension may not be available): {e}")
        logger.info("βœ… Database tables initialized")
    finally:
        await release_db_connection(conn)
# App Lifecycle
@app.on_event("startup")
async def startup_event():
    """Application startup hook.

    Creates the database pool, initializes the schema and checks the Azure
    OpenAI settings. Database problems abort startup (the exception is
    re-raised); an OpenAI configuration problem is only logged as a warning.
    """
    logger.info("πŸš€ Starting RAG Backend API...")
    try:
        # Test database connection
        await get_db_pool()
        logger.info("βœ… Database connection established")
        # Initialize database
        await init_database()
        # Test Azure OpenAI
        # Only constructs the client (validating the configured credentials
        # are present); no request is sent here.
        try:
            get_openai_client()
            logger.info("βœ… Azure OpenAI client configured")
        except Exception as e:
            logger.warning(f"⚠️ Azure OpenAI client configuration issue: {e}")
        logger.info("πŸŽ‰ RAG Backend API is ready!")
    except Exception as e:
        logger.error(f"❌ Startup failed: {e}")
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Application shutdown hook: close the database pool if one was created."""
    logger.info("πŸ›‘ Shutting down RAG Backend API...")
    # `db_pool` is only read here, so no `global` declaration is needed.
    if db_pool:
        await db_pool.close()
        logger.info("βœ… Database connections closed")
# API Endpoints
@app.get("/")
async def root():
    """Describe the service: version, feature flags, chunking settings and routes."""
    feature_names = (
        "document_upload",
        "url_processing",
        "vector_search",
        "ocr_integration",
        "azure_openai_embeddings",
        "postgresql_vector_storage",
    )
    features = dict.fromkeys(feature_names, True)

    configuration = {
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
        "min_chunk_size": config.MIN_CHUNK_SIZE,
        "max_file_size_mb": config.MAX_FILE_SIZE / (1024 * 1024),
    }

    endpoints = {
        "health": "/health",
        "docs": "/docs",
        "upload": "/documents/upload",
        "url_process": "/documents/url",
        "search": "/search",
        "list_documents": "/documents",
    }

    return {
        "message": "RAG Backend API",
        "version": "2.0.0",
        "status": "running",
        "features": features,
        "configuration": configuration,
        "endpoints": endpoints,
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Probes the database (with a dedicated short-lived connection), Azure
    OpenAI (with a real embedding request) and the OCR service, and reports
    each result. Blocking SDK/HTTP calls run in the default executor so the
    event loop stays responsive.

    Returns:
        A status dict. ``status`` is ``"healthy"`` when the database is
        reachable and OpenAI is configured or deliberately not configured,
        ``"degraded"`` when only the database is reachable, and
        ``"unhealthy"`` otherwise.
    """
    health_status = {
        "status": "unknown",
        "service": "RAG Backend API",
        "version": "2.0.0",
        "timestamp": datetime.utcnow().isoformat(),
        "database": "unknown",
        "openai": "unknown",
        "uuid_method": "unknown",
        "ocr_service": "unknown",
        "configuration": {
            "pg_host": config.PG_HOST,
            "pg_port": config.PG_PORT,
            "pg_database": config.PG_DATABASE,
            "ocr_service_url": config.OCR_SERVICE_URL,
            "chunk_size": config.CHUNK_SIZE
        },
        "errors": []
    }
    loop = asyncio.get_running_loop()

    # Test database connection
    test_conn = None
    try:
        test_conn = await asyncpg.connect(
            host=config.PG_HOST,
            port=config.PG_PORT,
            database=config.PG_DATABASE,
            user=config.PG_USER,
            password=config.PG_PASSWORD,
            ssl=config.PG_SSL_MODE,
            timeout=10
        )
        db_version = await test_conn.fetchval("SELECT version()")
        health_status["database"] = "connected"
        health_status["database_version"] = db_version
        # Check UUID generation method
        uuid_method = await detect_uuid_method(test_conn)
        health_status["uuid_method"] = uuid_method
    except Exception as db_error:
        health_status["database"] = "failed"
        health_status["errors"].append(f"Database connection failed: {db_error}")
    finally:
        # Close the probe connection even when a query above failed, so
        # repeated health checks cannot leak connections.
        if test_conn is not None:
            try:
                await test_conn.close()
            except Exception as close_error:
                logger.warning(f"Failed to close health check connection: {close_error}")

    # Test OpenAI
    try:
        if (config.AZURE_OPENAI_ENDPOINT == "" or
                config.AZURE_OPENAI_API_KEY == ""):
            health_status["openai"] = "not_configured"
        else:
            client = get_openai_client()
            # Test with a simple embedding request
            test_response = await loop.run_in_executor(
                None,
                lambda: client.embeddings.create(
                    input=["Health check test"],
                    model=config.AZURE_OPENAI_DEPLOYMENT
                )
            )
            if test_response.data:
                health_status["openai"] = "configured"
                health_status["embedding_dimensions"] = len(test_response.data[0].embedding)
            else:
                health_status["openai"] = "failed"
                health_status["errors"].append("OpenAI embedding test failed")
    except Exception as openai_error:
        health_status["openai"] = "failed"
        health_status["errors"].append(f"OpenAI configuration failed: {openai_error}")

    # Test OCR service
    try:
        ocr_response = await loop.run_in_executor(
            None,
            lambda: requests.get(f"{config.OCR_SERVICE_URL}/health", timeout=5)
        )
        if ocr_response.status_code == 200:
            health_status["ocr_service"] = "available"
        else:
            health_status["ocr_service"] = "unavailable"
    except Exception:
        # Any failure to reach the OCR service counts as unavailable; task
        # cancellation is deliberately not swallowed.
        health_status["ocr_service"] = "unavailable"

    # Determine overall status
    if health_status["database"] == "connected" and health_status["openai"] in ["configured", "not_configured"]:
        health_status["status"] = "healthy"
    elif health_status["database"] == "connected":
        health_status["status"] = "degraded"
    else:
        health_status["status"] = "unhealthy"
    return health_status
def _parse_upload_json_field(raw: Optional[str], field_name: str, expected_type: type, type_label: str, default: Any) -> Any:
    """Decode a JSON-encoded form field, raising a 400 on malformed input."""
    if not raw:
        return default
    try:
        value = json.loads(raw)
    except (TypeError, ValueError) as parse_error:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON in '{field_name}' form field: {parse_error}"
        )
    if not isinstance(value, expected_type):
        raise HTTPException(
            status_code=400,
            detail=f"'{field_name}' form field must be a JSON {type_label}"
        )
    return value


async def _mark_upload_failed(document_id: Optional[str]) -> None:
    """Best-effort: flag the document as failed without masking the original error."""
    if not document_id:
        return
    try:
        await update_document_status(document_id, "failed")
    except Exception as status_error:
        logger.warning(f"Could not mark document {document_id} as failed: {status_error}")


@app.post("/documents/upload")
async def upload_document(
    file: UploadFile = File(...),
    title: str = Form(None),
    keywords: str = Form(None), # JSON string of list
    metadata: str = Form(None), # JSON string
    chunk_size: int = Form(None),
    chunk_overlap: int = Form(None)
):
    """Upload a document, extract its text, chunk it and store embeddings.

    Args:
        file: The uploaded file.
        title: Optional title; defaults to the file name.
        keywords: Optional JSON array of strings.
        metadata: Optional JSON object merged into the stored metadata.
        chunk_size: Optional chunk size override.
        chunk_overlap: Optional chunk overlap override.

    Returns:
        A dict with ``success``, ``document_id``, ``title``,
        ``total_chunks`` and ``message``.

    Raises:
        HTTPException: 400 for invalid form fields, empty or oversized
            files, OCR failures reported by the service, or documents with
            no usable text; 500 when no chunk could be stored or an
            unexpected error occurs.
    """
    document_id = None
    try:
        # Parse form data (malformed client input is a 400, not a 500)
        keywords_list = _parse_upload_json_field(keywords, "keywords", list, "array", [])
        if not all(isinstance(keyword, str) for keyword in keywords_list):
            raise HTTPException(status_code=400, detail="'keywords' form field must be a JSON array of strings")
        metadata_dict = _parse_upload_json_field(metadata, "metadata", dict, "object", {})
        # Set default title
        if not title:
            title = file.filename or "Untitled Document"
        # Read file content
        logger.info(f"Processing uploaded file: {file.filename} ({file.content_type})")
        file_bytes = await file.read()
        if not file_bytes:
            raise HTTPException(status_code=400, detail="Empty file uploaded")
        if len(file_bytes) > config.MAX_FILE_SIZE:
            raise HTTPException(status_code=400, detail="File too large")
        # Process with OCR
        logger.info(f"Processing document with OCR: {title}")
        ocr_result = await process_with_ocr(file_bytes=file_bytes, filename=file.filename)
        if not ocr_result.get('success', False):
            error_msg = ocr_result.get('error', 'Unknown OCR error')
            logger.error(f"OCR processing failed: {error_msg}")
            raise HTTPException(status_code=400, detail=f"OCR processing failed: {error_msg}")
        # Extract text content
        content = ocr_result.get('content', '')
        if not content or not content.strip():
            raise HTTPException(status_code=400, detail="No text content extracted from document")
        # Clean the text
        cleaned_content = clean_text(content)
        if not cleaned_content or len(cleaned_content.strip()) == 0:
            raise HTTPException(status_code=400, detail="No text content after cleaning")
        # Allow shorter content for testing
        if len(cleaned_content.strip()) < config.MIN_CHUNK_SIZE:
            logger.warning(f"Content is short ({len(cleaned_content)} chars) but processing anyway")
        # Create document record
        document_id = await create_document_record(
            title=title,
            source_type='file_upload',
            keywords=keywords_list,
            metadata={
                **metadata_dict,
                'filename': file.filename,
                'content_type': file.content_type,
                'file_size': len(file_bytes),
                'ocr_pages': len(ocr_result.get('pages', []))
            }
        )
        # Create text chunks
        chunks = create_text_chunks(
            cleaned_content,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        if not chunks:
            raise HTTPException(status_code=400, detail="No valid chunks created from document")
        # Process chunks and generate embeddings
        logger.info(f"Processing {len(chunks)} chunks for document {document_id}")
        successful_chunks = 0
        for i, chunk_content in enumerate(chunks):
            try:
                if not chunk_content or len(chunk_content.strip()) < 10:
                    logger.warning(f"Skipping chunk {i} - too small")
                    continue
                # Generate embedding
                embedding = await generate_embedding(chunk_content)
                # Store chunk
                await store_document_chunk(
                    document_id=document_id,
                    content=chunk_content,
                    chunk_index=i,
                    embedding=embedding,
                    metadata={
                        'chunk_size': len(chunk_content),
                        'position': i
                    }
                )
                successful_chunks += 1
            except Exception as e:
                # One bad chunk should not discard the rest of the document.
                logger.error(f"Failed to process chunk {i} for document {document_id}: {e}")
                continue
        if successful_chunks == 0:
            # The HTTPException handler below records the "failed" status.
            raise HTTPException(status_code=500, detail="No chunks could be processed successfully")
        # Update document status
        await update_document_status(document_id, "completed", successful_chunks)
        logger.info(f"Document {document_id} processed successfully with {successful_chunks} chunks")
        return {
            "success": True,
            "document_id": document_id,
            "title": title,
            "total_chunks": successful_chunks,
            "message": "Document processed successfully"
        }
    except HTTPException:
        await _mark_upload_failed(document_id)
        raise
    except Exception as e:
        await _mark_upload_failed(document_id)
        logger.error(f"Unexpected error processing document: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Document processing failed: {e}")
async def _mark_url_document_failed(document_id: Optional[str]) -> None:
    """Best-effort: flag the document as failed without masking the original error."""
    if not document_id:
        return
    try:
        await update_document_status(document_id, "failed")
    except Exception as status_error:
        logger.warning(f"Could not mark document {document_id} as failed: {status_error}")


@app.post("/documents/url")
async def process_url(request: URLProcessRequest):
    """Fetch a document from a URL via the OCR service, chunk it and store embeddings.

    Args:
        request: URL plus optional title, keywords, metadata and chunking
            overrides.

    Returns:
        A dict with ``success``, ``document_id``, ``title``,
        ``total_chunks``, ``source_url`` and ``message``.

    Raises:
        HTTPException: 400 when the OCR service reports a failure or no
            usable text is produced; 500 when no chunk could be stored or
            an unexpected error occurs.
    """
    document_id = None
    try:
        url_str = str(request.url)
        # Set default title
        title = request.title or f"Document from {urlparse(url_str).netloc}"
        # Process with OCR
        logger.info(f"Processing URL with OCR: {url_str}")
        ocr_result = await process_with_ocr(url=url_str, extract_images=request.extract_images)
        if not ocr_result.get('success', False):
            error_msg = ocr_result.get('error', 'Unknown OCR error')
            logger.error(f"OCR processing failed for URL: {error_msg}")
            raise HTTPException(status_code=400, detail=f"OCR processing failed: {error_msg}")
        # Extract text content
        content = ocr_result.get('content', '')
        if not content or not content.strip():
            raise HTTPException(status_code=400, detail="No text content extracted from URL")
        # Clean the text
        cleaned_content = clean_text(content)
        if not cleaned_content or len(cleaned_content.strip()) == 0:
            raise HTTPException(status_code=400, detail="No text content after cleaning")
        # Allow shorter content for testing
        if len(cleaned_content.strip()) < config.MIN_CHUNK_SIZE:
            logger.warning(f"URL content is short ({len(cleaned_content)} chars) but processing anyway")
        # Create document record
        document_id = await create_document_record(
            title=title,
            source_type=ocr_result.get('source_type', 'url'),
            source_url=url_str,
            keywords=request.keywords or [],
            metadata={
                **(request.metadata or {}),
                'url': url_str,
                'extract_images': request.extract_images,
                'ocr_pages': len(ocr_result.get('pages', []))
            }
        )
        # Create text chunks
        chunks = create_text_chunks(
            cleaned_content,
            chunk_size=request.chunk_size,
            chunk_overlap=request.chunk_overlap
        )
        if not chunks:
            raise HTTPException(status_code=400, detail="No valid chunks created from URL content")
        # Process chunks and generate embeddings
        logger.info(f"Processing {len(chunks)} chunks for document {document_id}")
        successful_chunks = 0
        for i, chunk_content in enumerate(chunks):
            try:
                if not chunk_content or len(chunk_content.strip()) < 10:
                    logger.warning(f"Skipping chunk {i} - too small")
                    continue
                # Generate embedding
                embedding = await generate_embedding(chunk_content)
                # Store chunk
                await store_document_chunk(
                    document_id=document_id,
                    content=chunk_content,
                    chunk_index=i,
                    embedding=embedding,
                    metadata={
                        'chunk_size': len(chunk_content),
                        'position': i
                    }
                )
                successful_chunks += 1
            except Exception as e:
                # One bad chunk should not discard the rest of the document.
                logger.error(f"Failed to process chunk {i} for document {document_id}: {e}")
                continue
        if successful_chunks == 0:
            # The HTTPException handler below records the "failed" status.
            raise HTTPException(status_code=500, detail="No chunks could be processed successfully")
        # Update document status
        await update_document_status(document_id, "completed", successful_chunks)
        logger.info(f"URL document {document_id} processed successfully with {successful_chunks} chunks")
        return {
            "success": True,
            "document_id": document_id,
            "title": title,
            "total_chunks": successful_chunks,
            "source_url": url_str,
            "message": "URL processed successfully"
        }
    except HTTPException:
        await _mark_url_document_failed(document_id)
        raise
    except Exception as e:
        await _mark_url_document_failed(document_id)
        logger.error(f"Unexpected error processing URL: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"URL processing failed: {e}")
@app.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """Answer a search request by vector similarity.

    The query text is stripped, embedded with ``generate_embedding`` and the
    resulting vector is handed to ``search_similar_chunks`` together with the
    request's ``limit``, ``similarity_threshold`` and ``filter_metadata``.
    Each returned row is converted into a ``SearchResult``; rows that cannot
    be converted are logged and left out of the response.

    Raises:
        HTTPException: 400 when the query is missing or only whitespace;
            500 when embedding generation, the vector search, or any other
            unexpected step fails.
    """
    try:
        import time
        started_at = time.time()

        # Reject missing or whitespace-only queries before doing any work.
        query_text = request.query.strip() if request.query else ""
        if not query_text:
            raise HTTPException(status_code=400, detail="Query cannot be empty")

        logger.info(f"Performing vector search for query: '{query_text}'")

        # Turn the query text into an embedding vector.
        try:
            query_vector = await generate_embedding(query_text)
        except Exception as embed_error:
            logger.error(f"Failed to generate query embedding: {embed_error}")
            raise HTTPException(
                status_code=500,
                detail=f"Query embedding generation failed: {embed_error}"
            )

        # Look up the chunks closest to the query vector.
        try:
            rows = await search_similar_chunks(
                query_embedding=query_vector,
                limit=request.limit,
                similarity_threshold=request.similarity_threshold,
                filter_metadata=request.filter_metadata
            )
        except Exception as search_error:
            logger.error(f"Vector search failed: {search_error}")
            raise HTTPException(
                status_code=500,
                detail=f"Vector search failed: {search_error}"
            )

        # Convert raw rows into response models; one bad row must not sink
        # the whole response, so failures are logged and skipped.
        hits = []
        for row in rows:
            try:
                matched_chunk = DocumentChunk(
                    id=row['chunk_id'],
                    document_id=row['document_id'],
                    content=row['content'],
                    chunk_index=row['chunk_index'],
                    metadata=row['chunk_metadata'],
                    created_at=row['created_at']
                )
                score = row['similarity_score']
                parent_document = {
                    'title': row['document_title'],
                    'source_type': row['source_type'],
                    'source_url': row['source_url'],
                    'keywords': row['keywords'],
                    'metadata': row['document_metadata']
                }
                hits.append(SearchResult(
                    chunk=matched_chunk,
                    similarity_score=score,
                    document_info=parent_document
                ))
            except Exception as result_error:
                logger.error(f"Error formatting search result: {result_error}")

        elapsed = time.time() - started_at
        logger.info(f"Search completed: {len(hits)} results in {elapsed:.3f}s")

        # The response echoes the query exactly as it was submitted.
        return SearchResponse(
            query=request.query,
            results=hits,
            total_results=len(hits),
            processing_time=elapsed
        )
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as unexpected:
        logger.error(f"Search failed with unexpected error: {unexpected}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Search failed: {unexpected}")
@app.get("/documents")
async def list_documents(
    limit: int = Query(10, ge=1, le=100),
    offset: int = Query(0, ge=0),
    status: str = Query(None)
):
    """Return one page of documents, newest first, plus the total count.

    Args:
        limit: Page size (1-100, default 10).
        offset: Number of rows to skip (>= 0, default 0).
        status: When given, only documents whose ``processing_status``
            equals this value are listed and counted.

    Returns:
        A dict with ``documents`` (list of document dicts), ``total``
        (count of documents matching the filter), ``limit`` and ``offset``.
    """
    conn = await get_db_connection()
    try:
        select_sql = """
            SELECT id, title, source_type, source_url, keywords, metadata,
            created_at, processing_status, total_chunks
            FROM documents
            """

        # The optional status filter is shared by the page and count queries.
        where_sql = " WHERE processing_status = $1" if status else ""
        filter_args = [status] if status else []

        # LIMIT/OFFSET placeholders are numbered after any filter placeholder.
        used = len(filter_args)
        page_sql = (
            f"{select_sql}{where_sql}"
            f" ORDER BY created_at DESC LIMIT ${used + 1} OFFSET ${used + 2}"
        )
        rows = await conn.fetch(page_sql, *filter_args, limit, offset)

        documents = [
            {
                'id': str(row['id']),
                'title': row['title'],
                'source_type': row['source_type'],
                'source_url': row['source_url'],
                'keywords': row['keywords'],
                # Metadata is decoded from JSON text; empty/NULL becomes {}.
                'metadata': json.loads(row['metadata']) if row['metadata'] else {},
                'created_at': row['created_at'].isoformat(),
                'processing_status': row['processing_status'],
                'total_chunks': row['total_chunks']
            }
            for row in rows
        ]

        # Total count under the same filter, ignoring pagination.
        total_count = await conn.fetchval(
            "SELECT COUNT(*) FROM documents" + where_sql, *filter_args
        )

        return {
            "documents": documents,
            "total": total_count,
            "limit": limit,
            "offset": offset
        }
    finally:
        await release_db_connection(conn)
@app.get("/documents/{document_id}")
async def get_document(document_id: str):
    """Return a single document together with all of its chunks.

    Args:
        document_id: Identifier matched against ``documents.id``.

    Returns:
        A dict with the document's columns and a ``chunks`` list ordered by
        ``chunk_index``.

    Raises:
        HTTPException: 404 when no document has the given id.
    """

    def _decode(raw):
        # Metadata is decoded from JSON text; empty/NULL becomes {}.
        return json.loads(raw) if raw else {}

    conn = await get_db_connection()
    try:
        document_row = await conn.fetchrow("""
            SELECT id, title, source_type, source_url, keywords, metadata,
            created_at, processing_status, total_chunks
            FROM documents WHERE id = $1
            """, document_id)
        if not document_row:
            raise HTTPException(status_code=404, detail="Document not found")

        chunk_rows = await conn.fetch("""
            SELECT id, content, chunk_index, metadata, created_at
            FROM document_chunks
            WHERE document_id = $1
            ORDER BY chunk_index
            """, document_id)

        # Document-level fields first, then the chunk list is attached.
        payload = {
            'id': str(document_row['id']),
            'title': document_row['title'],
            'source_type': document_row['source_type'],
            'source_url': document_row['source_url'],
            'keywords': document_row['keywords'],
            'metadata': _decode(document_row['metadata']),
            'created_at': document_row['created_at'].isoformat(),
            'processing_status': document_row['processing_status'],
            'total_chunks': document_row['total_chunks']
        }

        chunk_payloads = []
        for chunk_row in chunk_rows:
            chunk_payloads.append({
                'id': str(chunk_row['id']),
                'content': chunk_row['content'],
                'chunk_index': chunk_row['chunk_index'],
                'metadata': _decode(chunk_row['metadata']),
                'created_at': chunk_row['created_at'].isoformat()
            })
        payload['chunks'] = chunk_payloads

        return payload
    finally:
        await release_db_connection(conn)
@app.delete("/documents/{document_id}")
async def delete_document(document_id: str):
    """Delete a document and all of its chunks as one atomic operation.

    Args:
        document_id: Identifier matched against ``documents.id`` and
            ``document_chunks.document_id``.

    Returns:
        ``{"message": "Document deleted successfully"}`` on success.

    Raises:
        HTTPException: 404 when no document has the given id.
    """
    conn = await get_db_connection()
    try:
        # Run the existence check and both deletes in a single transaction so
        # that a failure part-way through cannot leave a document row whose
        # chunks have already been removed.
        # NOTE(review): assumes `conn` is an asyncpg connection exposing
        # `transaction()` (consistent with the `$1` placeholders used here).
        async with conn.transaction():
            # Check if document exists
            exists = await conn.fetchval(
                "SELECT EXISTS(SELECT 1 FROM documents WHERE id = $1)",
                document_id
            )
            if not exists:
                raise HTTPException(status_code=404, detail="Document not found")

            # Delete chunks first (foreign key constraint)
            await conn.execute(
                "DELETE FROM document_chunks WHERE document_id = $1",
                document_id
            )

            # Delete document
            await conn.execute("DELETE FROM documents WHERE id = $1", document_id)

        return {"message": "Document deleted successfully"}
    finally:
        await release_db_connection(conn)
if __name__ == "__main__":
    # Print a short summary of the effective configuration before starting.
    print("πŸ”§ Loading RAG service configuration...")
    print(f"🌐 Will start server on {config.HOST}:{config.PORT}")
    print(f"πŸ—„οΈ Database: {config.PG_HOST}:{config.PG_PORT}/{config.PG_DATABASE}")

    # Azure OpenAI counts as configured when an endpoint value is set.
    azure_state = 'βœ… Configured' if config.AZURE_OPENAI_ENDPOINT else '❌ Not configured'
    print(f"πŸ€– Azure OpenAI: {azure_state}")
    print(f"πŸ” OCR Service: {config.OCR_SERVICE_URL}")

    # The app is passed as an import string; reload follows the DEBUG flag.
    server_options = {
        "host": config.HOST,
        "port": config.PORT,
        "reload": config.DEBUG,
        "log_level": "info",
    }
    uvicorn.run("rag_service:app", **server_options)