|
import asyncio
import datetime
import io
import json
import os
from datetime import datetime as dt
from datetime import timedelta

import pandas as pd
from bson import ObjectId
from fastapi import FastAPI, WebSocket, Depends, HTTPException, status, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from pydantic import ValidationError, BaseModel
|
|
|
|
|
from .auth import ( |
|
get_current_user, |
|
create_access_token, |
|
verify_password, |
|
get_password_hash |
|
) |
|
from .db.models import User, Token, Opportunity |
|
from .db.database import get_user_by_username, create_user, save_file, create_opportunity, get_opportunities, get_opportunity_count |
|
from .websocket import handle_websocket |
|
from .llm_models import invoke_general_model, invoke_customer_search |
|
|
|
app = FastAPI()

# Cross-origin policy: every origin, method and header is allowed, with credentials.
# NOTE(review): per the CORS spec, browsers refuse a literal "*" origin on
# credentialed requests, so this combination is either ineffective or (if the
# middleware reflects the caller's Origin instead) overly permissive. Confirm
# the intended frontend origins and list them explicitly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
# Directory that holds this module, with any symlinks resolved.
current_dir = os.path.split(os.path.realpath(__file__))[0]

# Frontend build output: <two levels above this module>/frontend/dist.
frontend_dir = os.path.join(os.path.join(current_dir, "..", ".."), "frontend", "dist")
|
|
|
@app.get("/")
async def serve_react_root() -> FileResponse:
    """Serve ``index.html`` from the frontend build directory at the site root."""
    index_path = os.path.join(frontend_dir, "index.html")
    return FileResponse(index_path)
|
|
|
@app.post("/register", response_model=Token)
async def register(user: User) -> Token:
    """Create a new account and immediately issue a 30-minute bearer token.

    Args:
        user: Registration payload; its ``password`` is replaced in place by
            its hash before being handed to ``create_user``.

    Returns:
        A ``Token`` whose ``access_token`` carries ``{"sub": username}``.

    Raises:
        HTTPException: 400 if the username already exists; 500 if
            ``create_user`` reports failure.
    """
    existing = await get_user_by_username(user.username)
    if existing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )

    # Never hand the plaintext password to the persistence layer.
    user.password = get_password_hash(user.password)

    created = await create_user(user)
    if not created:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not create user",
        )

    token_lifetime = timedelta(minutes=30)
    token = create_access_token(data={"sub": user.username}, expires_delta=token_lifetime)
    return Token(access_token=token, token_type="bearer")
|
|
|
@app.post("/token", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> Token:
    """Exchange OAuth2 password-form credentials for a 30-minute bearer token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            username is unknown or the password does not match.
    """
    user = await get_user_by_username(form_data.username)

    # NOTE(review): for an unknown username verify_password is skipped, so the
    # response time can reveal whether a username exists. Consider verifying
    # against a dummy hash in that case.
    credentials_ok = bool(user) and verify_password(form_data.password, user.password)
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token_lifetime = timedelta(minutes=30)
    token = create_access_token(data={"sub": user.username}, expires_delta=token_lifetime)
    return Token(access_token=token, token_type="bearer")
|
|
|
@app.post("/upload")
async def upload_file(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Parse an uploaded CSV file and store its rows for the current user.

    The upload is decoded as UTF-8, with a leading byte-order mark (as written
    by some spreadsheet tools) stripped. It is parsed with pandas and turned
    into a list of JSON-compatible row dicts, which are passed to ``save_file``.

    Args:
        file: The uploaded CSV file.
        current_user: Authenticated owner of the upload.

    Returns:
        ``{"message": "File uploaded successfully"}`` on success.

    Raises:
        HTTPException: 400 if the file is not valid UTF-8 or is not a parsable,
            non-empty CSV; 500 if ``save_file`` reports failure.
    """
    contents = await file.read()

    try:
        # "utf-8-sig" decodes plain UTF-8 identically, but also drops a leading
        # BOM that would otherwise be glued onto the first column name.
        text = contents.decode("utf-8-sig")
        df = pd.read_csv(io.StringIO(text))
    except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as exc:
        # These are defects in the client's file, not server faults.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file must be a non-empty, UTF-8 encoded CSV",
        ) from exc

    # Round-trip through pandas' JSON writer so missing values become null and
    # numpy scalars become plain Python values.
    records = json.loads(df.to_json(orient="records"))

    if not await save_file(current_user.username, records, file.filename):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not save file",
        )

    return {"message": "File uploaded successfully"}
|
|
|
@app.post("/api/save_opportunity")
async def save_opportunity(opportunity_data: dict, current_user: User = Depends(get_current_user)) -> dict:
    """Validate and persist an opportunity owned by the current user.

    The server-controlled fields ``username``, ``created_at`` and ``updated_at``
    are merged in after the client payload, so the client cannot override them.

    Args:
        opportunity_data: Raw JSON body describing the opportunity.
        current_user: Authenticated user recorded as the owner.

    Returns:
        ``{"message": "Opportunity saved successfully"}`` on success.

    Raises:
        HTTPException: 400 if the payload does not validate as an
            ``Opportunity``; 500 if ``create_opportunity`` reports failure.
    """
    # Take a single timestamp so a brand-new record has created_at == updated_at.
    now = datetime.datetime.now(datetime.UTC)
    payload = {
        **opportunity_data,
        "username": current_user.username,
        "created_at": now,
        "updated_at": now,
    }

    # Only model construction should translate into a client (400) error.
    try:
        opportunity = Opportunity(**payload)
    except ValidationError as e:
        print(f"Validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid opportunity data",
        ) from e

    if not await create_opportunity(opportunity):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not save opportunity",
        )

    return {"message": "Opportunity saved successfully"}
|
|
|
@app.get("/api/opportunities")
async def retrieve_opportunities(
    request: Request,
    page: int = 1,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
) -> JSONResponse:
    """Retrieve paginated opportunities for the current user.

    Args:
        request: Incoming request (not read here; kept for signature compatibility).
        page: 1-based page number.
        limit: Page size; must be at least 1.
        current_user: Authenticated user whose opportunities are listed.

    Returns:
        A 200 ``JSONResponse`` of the form ``{"success": True, "data": {...}}``,
        where ``data`` holds the ``PaginatedResponse`` fields.

    Raises:
        HTTPException: 400 if ``page`` or ``limit`` is below 1; 500 if loading
            or serialising the records fails.
    """
    # Validate before the try block so this 400 is not rewritten into a 500.
    # page < 1 would produce a negative skip, and limit < 1 a division by zero
    # in the page-count calculation below.
    if page < 1 or limit < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="page and limit must be positive integers",
        )

    class JSONEncoder(json.JSONEncoder):
        """Serialise datetimes as ISO-8601 strings and ObjectIds as strings."""

        def default(self, obj):
            if isinstance(obj, dt):
                return obj.isoformat()
            if isinstance(obj, ObjectId):
                return str(obj)
            return super().default(obj)

    try:
        skip = (page - 1) * limit
        records = await get_opportunities(
            username=current_user.username,
            skip=skip,
            limit=limit,
        )

        all_records = []
        for record in records:
            items = getattr(record, "content", None)
            if isinstance(items, (list, tuple)):
                # Records with a list/tuple ``content`` are flattened: their
                # items are returned in place of the record itself.
                # NOTE(review): total_records/total_pages below come from
                # get_opportunity_count, not from the flattened item count —
                # confirm this is intended.
                all_records.extend(items)
                continue

            record_dict = record.model_dump(by_alias=True)
            if "_id" in record_dict:
                record_dict["_id"] = str(record_dict["_id"])
            all_records.append(record_dict)

        total_count = await get_opportunity_count(current_user.username)

        response_data = PaginatedResponse(
            page=page,
            limit=limit,
            total_records=total_count,
            total_pages=-(-total_count // limit),  # ceiling division
            has_more=(skip + limit) < total_count,
            records=all_records,
        )

        # dumps/loads round trip converts datetimes and ObjectIds (via the
        # encoder above) into JSON-native values for JSONResponse.
        payload = json.loads(
            json.dumps(
                {"success": True, "data": response_data.model_dump()},
                cls=JSONEncoder,
            )
        )
        return JSONResponse(content=payload, status_code=200)

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error retrieving opportunities: {e!s}")
        # Keep internal exception details server-side; do not echo them to clients.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred while retrieving opportunities",
        ) from e
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Accept connections on ``/ws`` and delegate the session to ``handle_websocket``."""
    handler = handle_websocket
    await handler(websocket)
|
|
|
@app.post("/api/message")
async def message(obj: dict, current_user: User = Depends(get_current_user)) -> JSONResponse:
    """Endpoint to handle general incoming messages from the frontend.

    Expects a JSON body with a ``message`` field, passes it to
    ``invoke_general_model`` and returns the model's answer (as JSON-native
    data) under the ``"message"`` key.

    Raises:
        HTTPException: 400 if the body has no ``message`` field.
    """
    if "message" not in obj:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must contain a 'message' field",
        )

    # invoke_general_model is synchronous (its result is used without await).
    # Run it in a worker thread so a slow call (presumably an LLM request)
    # does not block the event loop for every other client.
    answer = await asyncio.to_thread(invoke_general_model, obj["message"])

    # Round-trip through the model's JSON form so the content is JSON-native.
    return JSONResponse(content={"message": json.loads(answer.model_dump_json())})
|
|
|
@app.post("/api/customer_insights")
async def customer_insights(obj: dict) -> JSONResponse:
    """Endpoint to launch a customer insight search.

    Expects a JSON body with a ``message`` field and passes it to
    ``invoke_customer_search``. The answer is returned under ``"AIMessage"``
    as a JSON-encoded *string* (the output of ``model_dump_json``).

    NOTE(review): unlike /api/message, this endpoint has no get_current_user
    dependency — confirm it is meant to be unauthenticated.
    NOTE(review): because "AIMessage" is a JSON string inside the JSON body,
    clients must decode it a second time; kept as-is for compatibility.

    Raises:
        HTTPException: 400 if the body has no ``message`` field.
    """
    if "message" not in obj:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must contain a 'message' field",
        )

    # invoke_customer_search is synchronous; run it off the event loop.
    answer = await asyncio.to_thread(invoke_customer_search, obj["message"])
    return JSONResponse(content={"AIMessage": answer.model_dump_json()})
|
|
|
# Serve files under <frontend_dir>/assets at the /assets URL prefix.
app.mount(
    path="/assets",
    app=StaticFiles(directory=os.path.join(frontend_dir, "assets")),
    name="static",
)
|
|
|
class PaginatedResponse(BaseModel):
    """One page of results, as built by ``retrieve_opportunities``.

    Fields:
        page: 1-based page number that was requested.
        limit: Page size that was requested.
        total_records: Total number of records available.
        total_pages: ``ceil(total_records / limit)``.
        has_more: Whether another page exists after this one.
        records: The items on this page.
    """

    page: int
    limit: int
    total_records: int
    total_pages: int
    has_more: bool
    records: list

    # Pydantic v2 configuration (the class-based ``Config`` is deprecated in v2).
    # The datetime key must be the ``datetime`` *class* (imported as ``dt``);
    # the bare name ``datetime`` in this module is the module object, which no
    # value is ever an instance of, so that encoder never fired.
    # NOTE(review): in pydantic v2 json_encoders attach to declared field types;
    # confirm they take effect for values nested inside the untyped ``records``.
    model_config = {
        "json_encoders": {
            dt: lambda v: v.isoformat(),
            ObjectId: lambda v: str(v),
        },
    }
|
if __name__ == "__main__":
    # Manual smoke test. Because this module uses relative imports, run it as
    # part of its package (``python -m <package>.<module>``), not as a script.
    from fastapi.testclient import TestClient

    client = TestClient(app)

    def test_message_endpoint():
        """POST a question to /api/message and check the response shape."""
        # /api/message depends on get_current_user. The handler never reads
        # current_user, so a None stand-in bypasses auth for this local check.
        app.dependency_overrides[get_current_user] = lambda: None
        try:
            response = client.post("/api/message", json={"message": "What is MEDDPICC?"})
        finally:
            app.dependency_overrides.pop(get_current_user, None)

        # ``response`` is a Response object, not a dict: use its attributes.
        assert response.status_code == 200, response.text
        body = response.json()
        print(body)
        # The endpoint returns its answer under the "message" key.
        assert "message" in body

    test_message_endpoint()