|
import json
import logging
import os
from datetime import datetime
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from buffalo_rag.embeddings.chunker import DocumentChunker
from buffalo_rag.model.rag import BuffaloRAG
from buffalo_rag.scraper.scraper import BuffaloScraper
from buffalo_rag.vector_store.db import VectorStore
|
|
|
|
|
# The ASGI application object; every route below is registered on it.
# The keyword arguments are the API's title / description / version metadata.
app = FastAPI(
    version="1.0.0",
    title="BuffaloRAG API",
    description="API for BuffaloRAG - AI Assistant for International Students at University at Buffalo",
)
|
|
|
|
|
# CORS policy: any origin may call the API with any method and any header.
#
# Credentials (cookies / Authorization on cross-origin requests) are disabled
# on purpose: combining allow_credentials=True with wildcard origins is a
# documented misconfiguration, because it lets any website issue
# credentialed requests on behalf of a visitor. None of the routes in this
# module read cookies or auth headers, so nothing here depends on it.
# NOTE(review): if a cross-origin front-end ever needs credentials, replace
# the "*" origin with an explicit list before turning this back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Shared retrieval state, created once when the module is imported.
# run_scraper() and refresh_index() rebind both names (via `global`) after
# the embeddings have been rebuilt, so handlers must look them up at call
# time rather than caching a reference.
vector_store: VectorStore = VectorStore()
rag: BuffaloRAG = BuffaloRAG(vector_store=vector_store)
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body accepted by POST /api/ask."""

    # The user's question; forwarded to rag.answer() as `query`.
    query: str
    # Forwarded to rag.answer() as `k` (defaults to 5).
    # Presumably the number of chunks to retrieve - confirm against BuffaloRAG.
    k: int = 5
    # Forwarded to rag.answer() as `filter_categories`; None sends no filter.
    categories: Optional[List[str]] = None
|
|
|
class QueryResponse(BaseModel):
    """Response schema of POST /api/ask."""

    # The question this response belongs to.
    query: str
    # The generated answer text.
    response: str
    # Source records backing the answer; the per-item schema is defined by
    # BuffaloRAG and is not visible from this module.
    sources: List[Dict[str, Any]]
    # ISO-8601 string added by the /api/ask handler (datetime.now().isoformat()).
    timestamp: str
|
|
|
class ScrapeRequest(BaseModel):
    """Request body accepted by POST /api/scrape."""

    # Start URL handed to BuffaloScraper(seed_url=...).
    seed_url: str = "https://www.buffalo.edu/international-student-services.html"
    # Handed to BuffaloScraper.scrape(max_pages=...).
    max_pages: int = 100
|
|
|
class ScrapeResponse(BaseModel):
    """Acknowledgement returned by /api/scrape and /api/refresh-index."""

    # Outcome flag; the handlers in this module always send "success".
    status: str
    # Human-readable description of the background job that was queued.
    message: str
|
|
|
|
|
def run_scraper(seed_url: str, max_pages: int):
    """Scrape the site, then rebuild the index from the scraped data.

    Intended to run as a background task (it is queued by POST /api/scrape).

    Args:
        seed_url: URL passed to ``BuffaloScraper`` as its starting point.
        max_pages: Value passed to ``BuffaloScraper.scrape(max_pages=...)``.
    """
    scraper = BuffaloScraper(seed_url=seed_url)
    scraper.scrape(max_pages=max_pages)

    # The chunk -> embed -> reload sequence is exactly what refresh_index()
    # performs (including rebinding the module-level `vector_store` and
    # `rag`), so delegate instead of keeping a second copy that can drift.
    refresh_index()
|
|
|
def refresh_index():
    """Rebuild chunks and embeddings, then reload the shared store and RAG.

    Intended to run as a background task (it is queued by
    POST /api/refresh-index). Rebinds the module-level ``vector_store`` and
    ``rag`` so later requests use the freshly built objects.
    """
    global vector_store, rag

    # Regenerate the chunks and their embeddings.
    doc_chunker = DocumentChunker()
    fresh_chunks = doc_chunker.create_chunks()
    doc_chunker.create_embeddings(fresh_chunks)

    # Re-create the store, then point a new RAG instance at it.
    vector_store = VectorStore()
    rag = BuffaloRAG(vector_store=vector_store)
|
|
|
|
|
# Front-end assets live in a "static" folder that sits next to this module.
_module_dir = os.path.dirname(__file__)
static_dir = os.path.join(_module_dir, "static")

# Create the folder up front so the mount below always has a directory to
# point at, even on a fresh checkout.
os.makedirs(static_dir, exist_ok=True)

# Expose the folder's contents under the /static URL prefix.
_static_files = StaticFiles(directory=static_dir)
app.mount("/static", _static_files, name="static")
|
|
|
|
|
@app.post("/api/ask", response_model=QueryResponse)
async def ask(request: QueryRequest):
    """Answer a question with the RAG system and append it to the query log.

    Args:
        request: The question plus the ``k`` and ``categories`` options, which
            are forwarded to ``rag.answer``.

    Returns:
        The mapping returned by ``rag.answer`` with a ``timestamp`` key added
        (local time, ISO-8601).

    Raises:
        HTTPException: 500 with the error text if answering the query fails.
    """
    try:
        response = rag.answer(
            query=request.query,
            k=request.k,
            filter_categories=request.categories,
        )
        # rag.answer's result is item-assigned here, so it must be a mutable
        # mapping; anything else surfaces as a 500 like other failures.
        response['timestamp'] = datetime.now().isoformat()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The query log is best-effort bookkeeping. A missing "data" directory or
    # a full disk must not turn a successfully computed answer into a 500,
    # so failures here are logged and the response is still returned.
    log_path = "data/query_log.jsonl"
    try:
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(response) + "\n")
    except (OSError, TypeError, ValueError):
        # OSError: filesystem problems; TypeError/ValueError: json.dumps could
        # not serialise the response.
        logging.getLogger(__name__).exception(
            "Could not append to query log %s", log_path
        )

    return response
|
|
|
@app.post("/api/scrape", response_model=ScrapeResponse)
async def scrape(request: ScrapeRequest, background_tasks: BackgroundTasks):
    """Queue a scrape as a background task and acknowledge immediately.

    The scrape itself (``run_scraper``) runs after the response has been
    sent; the returned message only confirms that the job was queued.

    Raises:
        HTTPException: 500 with the error text if queuing the task fails.
    """
    try:
        background_tasks.add_task(
            run_scraper,
            request.seed_url,
            request.max_pages,
        )
        acknowledgement = dict(
            status="success",
            message=f"Started scraping from {request.seed_url} (max {request.max_pages} pages)",
        )
        return acknowledgement
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
@app.post("/api/refresh-index", response_model=ScrapeResponse)
async def refresh(background_tasks: BackgroundTasks):
    """Queue a vector-index rebuild as a background task.

    The rebuild (``refresh_index``) runs after the response has been sent;
    the returned message only confirms that the job was queued.

    Raises:
        HTTPException: 500 with the error text if queuing the task fails.
    """
    try:
        background_tasks.add_task(refresh_index)
        acknowledgement = dict(
            status="success",
            message="Started refreshing the vector index",
        )
        return acknowledgement
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.get("/", include_in_schema=False)
async def serve_frontend():
    """Serve the front-end entry page, ``index.html`` from the static folder."""
    index_page = os.path.join(static_dir, "index.html")
    return FileResponse(index_page)
|
|
|
@app.get("/{path:path}", include_in_schema=False)
async def serve_frontend_paths(path: str):
    """Serve a file from the static folder, falling back to ``index.html``.

    Catch-all route for the front-end: if ``path`` names a regular file
    inside the static folder that file is returned, otherwise the entry page
    is returned.

    Args:
        path: The URL path after the leading slash. It is client-controlled
            and therefore treated as untrusted.
    """
    index_page = os.path.join(static_dir, "index.html")

    # `path` comes straight from the URL. Joining it blindly would let
    # "../" segments, or an absolute path such as "/etc/passwd" (which makes
    # os.path.join discard static_dir entirely), read arbitrary files.
    # Resolve the candidate and require it to stay inside the static folder.
    try:
        static_root = os.path.realpath(static_dir)
        candidate = os.path.realpath(os.path.join(static_root, path))
        inside_static = os.path.commonpath([static_root, candidate]) == static_root
    except ValueError:
        # Raised for embedded NUL bytes or, on Windows, paths on another drive.
        return FileResponse(index_page)

    if inside_static and os.path.isfile(candidate):
        return FileResponse(candidate)

    # Unknown path, or an attempt to leave the static folder: serve the
    # entry page.
    return FileResponse(index_page)
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly starts a local uvicorn server with
    # auto-reload enabled. uvicorn is imported here so that merely importing
    # the module does not require it.
    import uvicorn

    uvicorn.run(
        "buffalo_rag.api.main:app",
        host="localhost",
        port=8000,
        reload=True,
    )