# rag_chat_with_analytics_aws / rag_routerv2.py
# Last commit: b6c9315 (verified) by pvanand — "Update rag_routerv2.py" (8.75 kB)
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv
import sqlite3
# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil
# Router for the retrieval-augmented-generation endpoints: every route below is
# mounted under the "/rag" prefix and grouped under the "rag" OpenAPI tag.
router = APIRouter(tags=["rag"], prefix="/rag")

# Process-wide LlamaIndex configuration: indexes built or loaded through the
# global Settings embed text with this FastEmbed model.
_EMBED_MODEL_NAME = "BAAI/bge-small-en-v1.5"
Settings.embed_model = FastEmbedEmbedding(model_name=_EMBED_MODEL_NAME)
# LanceDB connection provider (usable as a FastAPI dependency).
@lru_cache(maxsize=128)
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for ``db_path``, memoized per argument.

    Calls with the same argument reuse the cached connection object instead
    of reconnecting; up to 128 distinct arguments are cached.
    """
    connection = lancedb.connect(db_path)
    return connection
def get_db(db_path: str = "./data/tablesv2.db") -> sqlite3.Connection:
    """Open a new connection to the SQLite metadata database.

    Args:
        db_path: Path of the SQLite file. Defaults to the router's metadata
            DB. ":memory:" is accepted.

    Returns:
        A fresh ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects,
        so columns can be read by name. The caller owns the connection and
        should close it.
    """
    # sqlite3.connect cannot create missing parent directories; without this,
    # the very first startup (init_db) fails when ./data does not exist yet.
    directory = os.path.dirname(db_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Create the SQLite metadata tables if they do not already exist.

    ``tables`` maps each knowledge-base ``table_id`` to its owning
    ``user_id`` and display name; ``table_files`` lists the uploaded files
    of each table (one row per ``(table_id, filename)``).
    The connection is always closed, even if a statement fails.
    """
    db = get_db()
    try:
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id),
                UNIQUE(table_id, filename)
            )
        ''')
        db.commit()
    finally:
        # Previously the connection was leaked on every call.
        db.close()
# Pydantic models
class CreateTableResponse(BaseModel):
    """Response body returned by the create_table endpoint."""
    table_id: str  # id of the table that was created or extended
    message: str  # human-readable outcome (create_embedding_table sends "Success")
    status: str  # machine-readable outcome (create_embedding_table sends "success")
    table_name: str  # display name used for the table
class QueryTableResponse(BaseModel):
    """Response body returned by the query_table endpoint."""
    results: Dict[str, Any]  # query_table puts the list of hits under the 'data' key
    total_results: int  # number of hits in results['data']
# File extensions accepted by /create_table uploads.
_ALLOWED_UPLOAD_EXTENSIONS = {".pdf", ".docx", ".csv", ".txt", ".md"}


def _is_safe_path_component(name: str) -> bool:
    """Return True if ``name`` is safe to use as a single path segment.

    Rejects empty names, "." and "..", and anything containing a path
    separator or NUL, so joining it onto a directory cannot escape it.
    """
    return (
        bool(name)
        and name not in {".", ".."}
        and not any(sep in name for sep in ("/", "\\", "\x00"))
    )


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store uploaded files under a table and rebuild that table's index.

    Files are saved to ``./data/<table_id>/``. Every file in that directory
    (including files from earlier uploads) is re-read and indexed, and the
    index is persisted to ``./lancedb/index/<table_id>``. The table and its
    files are then recorded in the SQLite metadata DB.

    Args:
        user_id: Owner of the table.
        files: Uploads; extensions must be one of ``_ALLOWED_UPLOAD_EXTENSIONS``.
        table_id: Existing table to extend, or None to create a new uuid4 id.
        table_name: Display name; a random "knowledge-base-xxxx" if omitted.

    Raises:
        HTTPException: 400 for an invalid table_id, a missing/invalid filename
            or an unsupported extension; 403 if table_id belongs to another
            user; 500 for any other failure.
    """
    try:
        table_id = table_id or str(uuid.uuid4())
        table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"
        # table_id becomes a directory name, so it must not contain path parts.
        if not _is_safe_path_component(table_id):
            raise HTTPException(status_code=400, detail="Invalid table_id")

        # Validate every upload before writing anything, so one bad file does
        # not leave the others half-saved on disk.
        safe_names = []
        for file in files:
            if not file.filename:
                raise HTTPException(status_code=400, detail="Invalid filename")
            # Strip client-supplied directories ("../../x.pdf", "C:\\x.pdf")
            # so the write stays inside the table directory.
            name = os.path.basename(file.filename.replace("\\", "/"))
            if not _is_safe_path_component(name):
                raise HTTPException(status_code=400, detail="Invalid filename")
            if os.path.splitext(name)[1].lower() not in _ALLOWED_UPLOAD_EXTENSIONS:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            safe_names.append(name)

        db = get_db()
        try:
            existing = db.execute(
                'SELECT id FROM tables WHERE user_id = ? AND table_id = ?',
                (user_id, table_id)
            ).fetchone()
            # Files and the index are keyed by table_id alone, so letting a
            # second user reuse an id would overwrite the first user's data.
            if not existing and db.execute(
                'SELECT 1 FROM tables WHERE table_id = ?', (table_id,)
            ).fetchone():
                raise HTTPException(status_code=403, detail="Table belongs to another user")

            directory_path = f"./data/{table_id}"
            os.makedirs(directory_path, exist_ok=True)
            for file, name in zip(files, safe_names):
                with open(os.path.join(directory_path, name), "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)

            vector_store = LanceDBVectorStore(
                uri="./lancedb/dev",
                table_name=table_id,
                mode="overwrite",
                query_type="hybrid"
            )
            documents = SimpleDirectoryReader(directory_path).load_data()
            # NOTE(review): from_documents does not appear to route a
            # `vector_store=` kwarg into the storage context (that normally
            # goes via StorageContext.from_defaults(vector_store=...)). If so,
            # the index lives in the default store written by persist() below,
            # which is what query_table reloads. Confirm before switching.
            index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
            index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

            if not existing:
                db.execute(
                    'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                    (user_id, table_id, table_name)
                )
            for name in safe_names:
                db.execute(
                    'INSERT OR REPLACE INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                    (table_id, name, f"./data/{table_id}/{name}")
                )
            db.commit()
        finally:
            db.close()

        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )
    except HTTPException:
        # Keep deliberate 4xx responses instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Retrieve the chunks of a table that are most similar to ``query``.

    Loads the LlamaIndex index persisted at ``./lancedb/index/<table_id>``
    and returns up to ``limit`` hits as ``{'text', 'score'}`` dicts under
    ``results['data']``.

    Raises:
        HTTPException: 400 for an invalid table_id or a non-positive limit;
            404 if no persisted index exists for table_id; 500 (detail
            "Query failed: ...") if loading or retrieval fails.
    """
    # NOTE(review): user_id is required but not checked against the table's
    # owner, so any caller can query any table. Confirm this is intended
    # (e.g. for the shared seed table) before adding an ownership check.
    # table_id is joined into a filesystem path; reject path components.
    if not table_id or table_id in {".", ".."} or any(
        sep in table_id for sep in ("/", "\\", "\x00")
    ):
        raise HTTPException(status_code=400, detail="Invalid table_id")
    if limit is None:
        limit = 10
    if limit < 1:
        raise HTTPException(status_code=400, detail="limit must be a positive integer")

    persist_dir = f"./lancedb/index/{table_id}"
    if not os.path.isdir(persist_dir):
        raise HTTPException(status_code=404, detail="Table not found")

    try:
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)
        hits = retriever.retrieve(query)
        results = [{'text': node.text, 'score': node.score} for node in hits]
        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e
@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """List a user's tables together with their registered filenames.

    Returns:
        A list of dicts with keys ``table_id``, ``table_name``,
        ``created_at`` and ``documents``. ``documents`` is the list of
        filenames from ``table_files``; it is empty when the table has none.
        Tables are ordered by table_id.
    """
    db = get_db()
    try:
        rows = db.execute('''
            SELECT
                t.table_id,
                t.table_name,
                t.created_time AS created_at,
                tf.filename
            FROM tables t
            LEFT JOIN table_files tf ON t.table_id = tf.table_id
            WHERE t.user_id = ?
            ORDER BY t.table_id, tf.id
        ''', (user_id,)).fetchall()
    finally:
        db.close()

    # Group in Python instead of GROUP_CONCAT: joining filenames with ','
    # and splitting them again broke any filename that contains a comma.
    tables: Dict[str, Dict[str, Any]] = {}
    for row in rows:
        entry = tables.setdefault(row['table_id'], {
            'table_id': row['table_id'],
            'table_name': row['table_name'],
            'created_at': row['created_at'],
            'documents': [],
        })
        # LEFT JOIN yields a NULL filename for tables without files.
        if row['filename']:
            entry['documents'].append(row['filename'])
    return list(tables.values())
@router.delete("/delete_table/{table_id}")
async def delete_table(table_id: str, user_id: str):
    """Delete a table's files, persisted index and metadata rows.

    Only succeeds if a ``tables`` row exists for ``(table_id, user_id)``.
    Removes ``./data/<table_id>`` and ``./lancedb/index/<table_id>``, then
    all ``table_files`` and ``tables`` rows for table_id.

    Raises:
        HTTPException: 400 for an invalid table_id; 404 if the user has no
            such table; 500 for any other failure.
    """
    # NOTE(review): the LanceDB table in ./lancedb/dev is not dropped here;
    # confirm whether anything is actually stored there before adding it.
    try:
        # table_id ends up in shutil.rmtree paths; never let it escape them.
        if not table_id or table_id in {".", ".."} or any(
            sep in table_id for sep in ("/", "\\", "\x00")
        ):
            raise HTTPException(status_code=400, detail="Invalid table_id")

        db = get_db()
        try:
            # Verify user owns the table
            table = db.execute(
                'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
                (table_id, user_id)
            ).fetchone()
            if not table:
                raise HTTPException(status_code=404, detail="Table not found or unauthorized")

            # Delete uploaded files and the persisted index from disk.
            for path in (f"./data/{table_id}", f"./lancedb/index/{table_id}"):
                if os.path.exists(path):
                    shutil.rmtree(path)

            db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
            db.execute('DELETE FROM tables WHERE table_id = ?', (table_id,))
            db.commit()
        finally:
            db.close()
        return {"message": "Table deleted successfully"}
    except HTTPException:
        # Previously the 404 above was caught below and turned into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/health")
async def health_check():
    """Liveness probe: unconditionally reports the router as healthy."""
    payload = {"status": "healthy"}
    return payload
@router.on_event("startup")
async def startup():
    """Create the metadata schema and seed the shared "digiyatra" table once.

    If no ``tables`` row exists for table_id "digiyatra", every data row of
    ``combined_digi_yatra.csv`` (header skipped) becomes one TextNode. The
    nodes are indexed, the index is persisted to
    ``./lancedb/index/digiyatra``, and the table and file are registered
    under user "digiyatra".
    """
    init_db()
    print("RAG Router started")
    table_name = "digiyatra"
    user_id = "digiyatra"
    seed_csv = 'combined_digi_yatra.csv'

    db = get_db()
    try:
        # Seed only once: skip if the table is already registered.
        existing = db.execute('SELECT id FROM tables WHERE table_id = ?', (table_name,)).fetchone()
        if existing:
            return

        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            mode="overwrite",
            query_type="hybrid"
        )
        with open(seed_csv, newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            # Node text is the Python repr of the row's field list.
            nodes = [TextNode(text=str(row), id_=str(uuid.uuid4())) for row in reader]
        # NOTE(review): a `vector_store=` kwarg may not be wired into the
        # index's storage context; the data queried later is whatever
        # persist() below writes. Confirm before relying on LanceDB here.
        index = VectorStoreIndex(nodes, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, seed_csv, seed_csv)
        )
        db.commit()
    finally:
        # Previously the startup connection was never closed.
        db.close()
@router.on_event("shutdown")
async def shutdown():
    """Announce on stdout that the RAG router is shutting down."""
    message = "RAG Router shutdown"
    print(message)