"""Travel chatbot, tour search, recommendation and embedding API (FastAPI).

Project modules under ``src`` are imported inside a ``try`` block. If that
import fails, placeholders (``None`` / minimal message classes) are installed
so this module still loads; the endpoints check for those placeholders and
raise HTTP errors instead of failing with ``NameError``.
"""
import os
import sys
import time
import uuid
import zoneinfo
from datetime import date, datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional, Union

import psycopg2
import psycopg2.extras  # provides DictCursor; a bare `import psycopg2` does not load this submodule
import uvicorn
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import PlainTextResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from psycopg2 import pool as psycopg2_pool
from pydantic import BaseModel, Field

from src.recommendation_api import (
    TourRecommendationRequest,
    TourRecommendationResponse,
    get_tour_recommendations,
)

load_dotenv()

try:
    from src.config import (
        DB_USER,
        DB_PASSWORD,
        DB_HOST,
        DB_PORT,
        DB_NAME,
        DB_ENDPOINT_ID,
        GOOGLE_API_KEY,
        JWT_SECRET_KEY,
        ALGORITHM,
    )
    from src.database import conn_pool
    from src.graph_builder import graph_app
    from src.embedding import embedding_model
    from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
    from src.tools import extract_entities_tool, search_tours_tool
except ImportError as e:
    print(f"Error importing from src: {e}. Using placeholders. API will likely fail at runtime until this is fixed.")
    DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM = [None] * 9
    conn_pool = None
    graph_app = None
    embedding_model = None
    # search_tours_api checks these for None; without them it would die with NameError.
    extract_entities_tool = None
    search_tours_tool = None

    class HumanMessage:
        """Minimal stand-in for langchain's HumanMessage; only stores ``content``."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in for langchain's AIMessage; only stores ``content``."""

        def __init__(self, content):
            self.content = content

    BaseMessage = Dict


# Time zone for stored interaction times, response timestamps and the "today"
# handed to the entity extractor.
_LOCAL_TZ_NAME = "Asia/Bangkok"


def _now_local() -> datetime:
    """Return the current time as an aware datetime in ``_LOCAL_TZ_NAME``."""
    return datetime.now(zoneinfo.ZoneInfo(_LOCAL_TZ_NAME))


app = FastAPI(
    title="Travel Chatbot, Tour Search & Recommendation API",
    version="1.0.0",
)

reusable_oauth2 = HTTPBearer(
    scheme_name="Bearer",
)


class EmbeddingRequest(BaseModel):
    """Input for /api/embed: a single text or a list of texts."""

    text: Union[str, List[str]]


class EmbeddingResponse(BaseModel):
    """Output of /api/embed."""

    embeddings: List[List[float]]
    model: str
    dimensions: int


class TourSearchRequest(BaseModel):
    """Free-text tour search with page/limit pagination."""

    query: str = Field(..., description="The search query for tours, e.g., 'tôi muốn đi đà nẵng'")
    page: int = Field(1, ge=1, description="Current page number for pagination")
    limit: int = Field(10, ge=1, le=100, description="Number of items per page")


class TourSummary(BaseModel):
    """One tour row as returned by /api/tours/search."""

    tour_id: Any
    title: str
    duration: Optional[str] = None
    departure_location: Optional[str] = None
    destination: Optional[List[str]] = None
    region: Optional[str] = None
    itinerary: Optional[str] = None
    max_participants: Optional[int] = None
    departure_id: Optional[Any] = None
    start_date: Optional[Union[datetime, date]] = None
    price_adult: Optional[float] = None
    price_child_120_140: Optional[float] = None
    price_child_100_120: Optional[float] = None
    promotion_name: Optional[str] = None
    promotion_type: Optional[str] = None
    promotion_discount: Optional[Any] = None


class PaginatedTourResponse(BaseModel):
    """A page of tour summaries plus pagination metadata."""

    currentPage: int
    itemsPerPage: int
    totalItems: int
    totalPages: int
    hasNextPage: bool
    hasPrevPage: bool
    tours: List[TourSummary]


class TokenData(BaseModel):
    user_id: Optional[int] = None


async def get_current_user(token: HTTPAuthorizationCredentials = Depends(reusable_oauth2)) -> int:
    """Decode the bearer JWT and return the caller's integer user id.

    The id is read from the ``id`` claim, falling back to ``userId``.

    Raises:
        HTTPException: 401 if the token cannot be decoded, carries neither
            claim, or the claim is not convertible to ``int``.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Only the decode step is guarded, so our own credentials_exception below
    # is no longer misreported as an "unexpected error".
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        print(f"JWTError: {e}")
        raise credentials_exception from e
    except Exception as e:
        print(f"An unexpected error occurred during JWT decoding: {e}")
        raise credentials_exception from e

    user_id = payload.get("id")
    if user_id is None:
        user_id = payload.get("userId")
    if user_id is None:
        print(f"JWT payload does not contain 'id' or 'userId' field. Payload: {payload}")
        raise credentials_exception
    try:
        # Tokens may carry the id as a string; callers rely on an int.
        return int(user_id)
    except (TypeError, ValueError) as e:
        print(f"JWT user id is not an integer: {user_id!r}")
        raise credentials_exception from e


class ChatMessageInput(BaseModel):
    """Request body for /api/chat/."""

    message: str = Field(..., description="The text message sent by the user to the chatbot.")
    session_id: Optional[str] = Field(None, description="An optional identifier for a specific chat session.")


class ChatResponseOutput(BaseModel):
    """Response body for /api/chat/."""

    user_id: int = Field(..., description="The ID of the user (from JWT token).")
    response: str = Field(..., description="The chatbot's generated textual response.")
    session_id: Optional[str] = Field(None, description="The session identifier, mirrored if provided in input.")
    timestamp: datetime = Field(..., description="Timezone-aware (Asia/Bangkok) timestamp of when the response was generated.")


def get_db_connection():
    """FastAPI dependency that lends a pooled psycopg2 connection for one request.

    Yields:
        A connection from ``conn_pool``; it is always returned to the pool.

    Raises:
        HTTPException: 503 if the pool is missing or no connection can be obtained.
    """
    if conn_pool is None:
        print("conn_pool is None in get_db_connection. Database module likely not initialized.")
        raise HTTPException(status_code=503, detail="Database connection pool not initialized. Check src.database and .env configuration.")
    # Acquire outside the yield's try: if getconn() fails there is nothing to
    # return, and the old code hit an unbound `conn` in its finally clause.
    try:
        conn = conn_pool.getconn()
    except psycopg2.Error as e:  # psycopg2.pool.PoolError subclasses psycopg2.Error
        print(f"Could not get a connection from conn_pool: {e}")
        raise HTTPException(status_code=503, detail="Could not obtain a database connection.") from e
    try:
        yield conn
    finally:
        conn_pool.putconn(conn)


def _safe_rollback(db_conn) -> None:
    """Roll back the current transaction, logging instead of raising if that also fails."""
    try:
        db_conn.rollback()
    except Exception as e:
        print(f"Rollback failed: {e}")


def fetch_conversation_history(db_conn, user_id: int, session_id: Optional[str] = None) -> List[BaseMessage]:
    """Load a user's stored exchanges as alternating Human/AI messages, oldest first.

    With a ``session_id`` only that session's rows are read; otherwise rows
    whose ``session_id`` is NULL. Best effort: on a database error the error
    is logged, the transaction is rolled back (so later statements on the same
    connection are not rejected as "transaction aborted") and ``[]`` is returned.
    """
    if session_id:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id = %s ORDER BY interaction_time ASC"
        params = (user_id, session_id)
    else:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id IS NULL ORDER BY interaction_time ASC"
        params = (user_id,)

    try:
        with db_conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
            cursor.execute(query, params)
            records = cursor.fetchall()
    except Exception as e:
        print(f"Error fetching conversation history for user_id {user_id}, session_id {session_id}: {e}")
        _safe_rollback(db_conn)
        return []

    history: List[BaseMessage] = []
    for record in records:
        # Empty/NULL halves of an exchange are skipped rather than sent as blank messages.
        if record["message"]:
            history.append(HumanMessage(content=record["message"]))
        if record["response"]:
            history.append(AIMessage(content=record["response"]))
    return history


def save_interaction_to_history(db_conn, user_id: int, user_message: str, chatbot_response: str, session_id: Optional[str] = None):
    """Insert one user/bot exchange into ChatbotHistory and commit.

    Best effort: failures are logged and rolled back, never raised.
    """
    try:
        with db_conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO ChatbotHistory (user_id, message, response, interaction_time, session_id) VALUES (%s, %s, %s, %s, %s)",
                (user_id, user_message, chatbot_response, _now_local(), session_id),
            )
        db_conn.commit()
    except Exception as e:
        print(f"Error saving interaction to history for user_id {user_id}, session_id {session_id}: {e}")
        _safe_rollback(db_conn)


@app.on_event("startup")
async def startup_event():
    """Load the embedding model once at startup; failures are logged, not fatal."""
    try:
        if embedding_model:
            embedding_model.load_model()
    except Exception as e:
        print(f"Failed to load embedding model on startup: {str(e)}")


@app.get("/", include_in_schema=False, response_class=PlainTextResponse)
async def root():
    """Plain-text liveness check."""
    return "API is running"


def _extract_graph_response(result: Any) -> str:
    """Return the reply text from a graph result, or "" if none is found.

    Prefers a non-empty ``final_response``. Otherwise it falls back to the last
    message when that is an ``AIMessage``. The value is checked, not just the
    key, because the input state always carries ``final_response`` (initially None).
    """
    if not isinstance(result, dict):
        return ""
    final_response = result.get("final_response")
    if final_response:
        return final_response
    messages = result.get("messages")
    if messages and isinstance(messages[-1], AIMessage):
        return messages[-1].content
    return ""


@app.post("/api/chat/", response_model=ChatResponseOutput, tags=["Chat"], summary="Chat with the Travel Chatbot")
async def chat_endpoint(payload: ChatMessageInput, current_user_id: int = Depends(get_current_user), db_conn=Depends(get_db_connection)):
    """Run one chatbot turn for the authenticated user and persist it.

    Without a ``session_id`` a new UUID session is started. Prior turns of the
    session are replayed to the graph as message history.

    Raises:
        HTTPException: 503 if the graph is not initialised, 500 if invocation fails.
    """
    if graph_app is None:
        print("graph_app is None in chat_endpoint. Graph_builder module likely not initialized.")
        raise HTTPException(status_code=503, detail="Chatbot graph not initialized. Check src.graph_builder.")

    user_id = current_user_id
    user_message_content = payload.message
    session_id = payload.session_id
    # Blocking psycopg2 / graph calls run in the threadpool so they do not stall the event loop.
    if session_id:
        history = await run_in_threadpool(fetch_conversation_history, db_conn, user_id, session_id)
    else:
        # A freshly generated session cannot have stored history; skip the query.
        session_id = str(uuid.uuid4())
        history = []

    inputs = {
        "messages": history + [HumanMessage(content=user_message_content)],
        "user_query": user_message_content,
        "current_date": None,
        "available_locations": None,
        "extracted_entities": None,
        "search_results": None,
        "final_response": None,
        "error": None,
        "routing_decision": None,
    }

    try:
        result = await run_in_threadpool(graph_app.invoke, inputs)
        full_response_content = _extract_graph_response(result)
    except Exception as e:
        print(f"Error during graph invocation for user_id {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing message with chatbot: {str(e)}") from e
    if not full_response_content:
        full_response_content = "Sorry, I could not process your request at this moment."

    await run_in_threadpool(save_interaction_to_history, db_conn, user_id, user_message_content, full_response_content, session_id)

    return ChatResponseOutput(
        user_id=user_id,
        response=full_response_content,
        session_id=session_id,
        timestamp=_now_local(),
    )


def _optional(value: Any, cast: Callable[[Any], Any]) -> Any:
    """Apply ``cast`` to ``value`` unless it is ``None``."""
    return None if value is None else cast(value)


def _normalize_destination(destination: Any) -> Optional[List[str]]:
    """Coerce a row's destination into a list of strings (or ``None``).

    Lists and tuples are converted element-wise. Any other truthy value becomes
    a one-item list. ``None`` and other falsy scalars (e.g. "") become ``None``.
    """
    if isinstance(destination, (list, tuple)):
        return [str(dest) for dest in destination]
    if not destination:
        return None
    return [str(destination)]


def _tour_summary_from_row(tour_data: Dict[str, Any]) -> TourSummary:
    """Build a TourSummary from one search_tours_tool row, coercing field types."""
    title = tour_data.get("title")
    return TourSummary(
        tour_id=tour_data.get("tour_id"),
        title="" if title is None else str(title),
        duration=_optional(tour_data.get("duration"), str),
        departure_location=_optional(tour_data.get("departure_location"), str),
        destination=_normalize_destination(tour_data.get("destination")),
        region=_optional(tour_data.get("region"), str),
        itinerary=_optional(tour_data.get("itinerary"), str),
        max_participants=_optional(tour_data.get("max_participants"), int),
        departure_id=tour_data.get("departure_id"),
        start_date=tour_data.get("start_date"),
        price_adult=_optional(tour_data.get("price_adult"), float),
        price_child_120_140=_optional(tour_data.get("price_child_120_140"), float),
        price_child_100_120=_optional(tour_data.get("price_child_100_120"), float),
        promotion_name=_optional(tour_data.get("promotion_name"), str),
        promotion_type=_optional(tour_data.get("promotion_type"), str),
        promotion_discount=tour_data.get("promotion_discount"),
    )


def _build_tour_page(request: TourSearchRequest, total_items: int, tours: List[TourSummary]) -> PaginatedTourResponse:
    """Wrap one page of tours in pagination metadata derived from ``request``."""
    total_pages = (total_items + request.limit - 1) // request.limit
    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=total_items,
        totalPages=total_pages,
        hasNextPage=request.page < total_pages,
        # With no results there is no previous page, whatever page was requested.
        hasPrevPage=total_pages > 0 and request.page > 1,
        tours=tours,
    )


@app.post("/api/tours/search", response_model=PaginatedTourResponse, tags=["Tours"], summary="Search for tours based on user query")
async def search_tours_api(request: TourSearchRequest, current_user_id: int = Depends(get_current_user)):
    """Extract search entities from the query, search tours and return one page.

    An entity-extraction error or an empty result yields an empty page
    (not an HTTP error).

    Raises:
        HTTPException: 503 if the search tools failed to import.
    """
    if extract_entities_tool is None or search_tours_tool is None:
        print("Tour search tools are None in search_tours_api. Tools module likely not initialized.")
        raise HTTPException(status_code=503, detail="Tour search tools not initialized. Check src.tools.")

    # "Today" in the service's time zone rather than the server's local zone.
    current_date_str = _now_local().strftime('%Y-%m-%d')
    entities = await run_in_threadpool(extract_entities_tool, user_query=request.query, current_date_str=current_date_str)
    if not entities or (isinstance(entities, dict) and entities.get("error")):
        return _build_tour_page(request, 0, [])

    all_found_tours = await run_in_threadpool(search_tours_tool, entities)
    if not all_found_tours:
        return _build_tour_page(request, 0, [])
    if isinstance(all_found_tours, dict):
        # A dict cannot be sliced into pages; presumably an error payload, like the entities case above.
        print(f"search_tours_tool returned a dict instead of a list of tours: {all_found_tours}")
        return _build_tour_page(request, 0, [])

    start_index = (request.page - 1) * request.limit
    page_rows = all_found_tours[start_index:start_index + request.limit]
    return _build_tour_page(
        request,
        len(all_found_tours),
        [_tour_summary_from_row(tour_data) for tour_data in page_rows],
    )


# NOTE(review): unlike chat and search, this endpoint has no get_current_user
# dependency, so any caller can request recommendations for any user_id — confirm intent.
@app.get("/api/recommendations", response_model=TourRecommendationResponse, tags=["Recommendations"], summary="Get tour recommendations")
async def get_recommendations(
    user_id: Optional[int] = None,
    tour_id: Optional[int] = None,
    limit: int = 3,
):
    """Return up to ``limit`` recommendations; a limit outside 1..10 is reset to 3.

    Raises:
        HTTPException: passed through unchanged if raised by the recommender;
            any other error becomes a 500.
    """
    if limit < 1 or limit > 10:
        limit = 3
    try:
        return await run_in_threadpool(get_tour_recommendations, user_id, tour_id, limit)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}") from e


@app.post("/api/embed", response_model=EmbeddingResponse, tags=["Embeddings"], summary="Generate text embeddings")
async def get_embedding(request: EmbeddingRequest):
    """Embed the given text(s) with ``embedding_model``.

    ``dimensions`` is the length of the first vector, or 0 when no vectors are returned.

    Raises:
        HTTPException: 503 if the model is not initialised, 500 on embedding failure.
    """
    if embedding_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Embedding service not initialized. Check src.embedding module.",
        )
    try:
        embeddings = await run_in_threadpool(embedding_model.get_embedding, request.text)
        # len() rather than truthiness so array-like results also work.
        dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0
        return {
            "embeddings": embeddings,
            "model": embedding_model.model_name,
            "dimensions": dimensions,
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}",
        ) from e