from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel
from typing import Optional, Dict, List, Annotated, Any
import csv
import os
import shutil
import sqlite3
import uuid

import lancedb

# LlamaIndex imports
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.schema import TextNode
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore

router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")


# Database connection dependency
def get_db_connection(db_path: str = "./lancedb/dev"):
    return lancedb.connect(db_path)


def get_db():
    conn = sqlite3.connect('./data/tablesv2.db')
    conn.row_factory = sqlite3.Row
    return conn


def init_db():
    db = get_db()
    db.execute('''
        CREATE TABLE IF NOT EXISTS tables (
            id INTEGER PRIMARY KEY,
            user_id TEXT NOT NULL,
            table_id TEXT NOT NULL,
            table_name TEXT NOT NULL,
            created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    db.execute('''
        CREATE TABLE IF NOT EXISTS table_files (
            id INTEGER PRIMARY KEY,
            table_id TEXT NOT NULL,
            filename TEXT NOT NULL,
            file_path TEXT NOT NULL,
            FOREIGN KEY (table_id) REFERENCES tables (table_id),
            UNIQUE(table_id, filename)
        )
    ''')
    db.commit()


# Pydantic models
class CreateTableResponse(BaseModel):
    table_id: str
    message: str
    status: str
    table_name: str


class QueryTableResponse(BaseModel):
    results: Dict[str, Any]
    total_results: int
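

# Illustrative only (not used elsewhere in this module): the payloads the two
# response models above carry. Field names come from the models; the values
# are made up for the example.
def _example_responses() -> Dict[str, Any]:
    return {
        "create": CreateTableResponse(
            table_id="2f6c1c9e-0000-0000-0000-000000000000",  # hypothetical UUID
            message="Success",
            status="success",
            table_name="knowledge-base-2f6c",
        ),
        "query": QueryTableResponse(
            results={"data": [{"text": "sample chunk", "score": 0.83}]},
            total_results=1,
        ),
    }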


async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    try:
        db = get_db()
        table_id = table_id or str(uuid.uuid4())
        table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"
        # Check if table exists
        existing = db.execute(
            'SELECT id FROM tables WHERE user_id = ? AND table_id = ?',
            (user_id, table_id)
        ).fetchone()
        directory_path = f"./data/{table_id}"
        os.makedirs(directory_path, exist_ok=True)
        # Validate and save the uploaded files to disk
        for file in files:
            if not file.filename:
                raise HTTPException(status_code=400, detail="Invalid filename")
            if os.path.splitext(file.filename)[1].lower() not in {".pdf", ".docx", ".csv", ".txt", ".md"}:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            file_path = os.path.join(directory_path, file.filename)
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid"
        )
        # Wire the LanceDB table in through a StorageContext so the embeddings
        # are written to LanceDB rather than the default in-memory vector store
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        documents = SimpleDirectoryReader(directory_path).load_data()
        index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")
        if not existing:
            db.execute(
                'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                (user_id, table_id, table_name)
            )
        for file in files:
            db.execute(
                'INSERT OR REPLACE INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                (table_id, file.filename, f"./data/{table_id}/{file.filename}")
            )
        db.commit()
        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )
    except HTTPException:
        # Re-raise the validation errors (400) as-is instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
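

# Note: this file defines the coroutines but does not register them on `router`.
# A minimal wiring sketch (the paths and HTTP verbs below are assumptions, not
# taken from the original code) would look like:
#
#   router.post("/tables", response_model=CreateTableResponse)(create_embedding_table)
#   router.post("/tables/{table_id}/query", response_model=QueryTableResponse)(query_table)
#   router.get("/tables")(get_tables)
#   router.delete("/tables/{table_id}")(delete_table)
#   router.get("/health")(health_check)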


async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    #db: Annotated[Any, Depends(get_db_connection)],
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex."""
    try:
        table_name = table_id  # f"{user_id}__table__{table_id}"
        # Load the persisted index, re-attaching the LanceDB table that holds the
        # embeddings (the persist_dir only stores the docstore/index metadata)
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            query_type="hybrid"
        )
        storage_context = StorageContext.from_defaults(
            persist_dir=f"./lancedb/index/{table_name}",
            vector_store=vector_store
        )
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)
        # Get response
        response = retriever.retrieve(query)
        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]
        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
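

# Illustrative, unused helper: shows how query_table is awaited directly and how
# its result payload is consumed. The "digiyatra" table is the one seeded in
# startup() below; the query string is made up for the example.
async def _example_query_usage() -> List[str]:
    response = await query_table(
        table_id="digiyatra",
        query="How does Digi Yatra verify a passenger's identity?",
        user_id="digiyatra",
        limit=3,
    )
    # Each hit carries the retrieved chunk text and its similarity score
    return [hit["text"] for hit in response.results["data"]]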


async def get_tables(user_id: str):
    db = get_db()
    tables = db.execute('''
        SELECT
            t.table_id,
            t.table_name,
            t.created_time as created_at,
            GROUP_CONCAT(tf.filename) as filenames
        FROM tables t
        LEFT JOIN table_files tf ON t.table_id = tf.table_id
        WHERE t.user_id = ?
        GROUP BY t.table_id
    ''', (user_id,)).fetchall()
    result = []
    for table in tables:
        table_dict = dict(table)
        result.append({
            'table_id': table_dict['table_id'],
            'table_name': table_dict['table_name'],
            'created_at': table_dict['created_at'],
            'documents': (
                [filename for filename in table_dict['filenames'].split(',') if filename]
                if table_dict['filenames'] else []
            )
        })
    return result
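

# Illustrative only: the list-of-dicts shape get_tables() returns. The keys
# match the dict built above; the values are made up for the example.
_EXAMPLE_TABLES_PAYLOAD: List[Dict[str, Any]] = [{
    "table_id": "2f6c1c9e-0000-0000-0000-000000000000",  # hypothetical UUID
    "table_name": "knowledge-base-2f6c",
    "created_at": "2024-01-01 00:00:00",
    "documents": ["handbook.pdf", "faq.md"],
}]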


async def delete_table(table_id: str, user_id: str):
    try:
        db = get_db()
        # Verify user owns the table
        table = db.execute(
            'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
            (table_id, user_id)
        ).fetchone()
        if not table:
            raise HTTPException(status_code=404, detail="Table not found or unauthorized")
        # Delete files from filesystem
        table_path = f"./data/{table_id}"
        index_path = f"./lancedb/index/{table_id}"
        if os.path.exists(table_path):
            shutil.rmtree(table_path)
        if os.path.exists(index_path):
            shutil.rmtree(index_path)
        # Delete from database
        db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
        db.execute('DELETE FROM tables WHERE table_id = ?', (table_id,))
        db.commit()
        return {"message": "Table deleted successfully"}
    except HTTPException:
        # Keep the 404 from the ownership check instead of turning it into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


async def health_check():
    return {"status": "healthy"}


async def startup():
    init_db()
    print("RAG Router started")
    # Seed a default "digiyatra" knowledge base from the bundled CSV on first run
    table_name = "digiyatra"
    user_id = "digiyatra"
    db = get_db()
    # Check if table already exists
    existing = db.execute('SELECT id FROM tables WHERE table_id = ?', (table_name,)).fetchone()
    if not existing:
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            mode="overwrite",
            query_type="hybrid"
        )
        # As in create_embedding_table, attach the vector store via the storage context
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        with open('combined_digi_yatra.csv', newline='') as f:
            # Skip the header row and index each remaining CSV row as one text node
            nodes = [TextNode(text=str(row), id_=str(uuid.uuid4()))
                     for row in list(csv.reader(f))[1:]]
        index = VectorStoreIndex(nodes, storage_context=storage_context)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
        )
        db.commit()


async def shutdown():
    print("RAG Router shutdown")