Travel_AI / api_main.py
ayayaya12's picture
Update chatbot to use Asia/Bangkok timezone
04dc4d1
import os
import sys
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.responses import PlainTextResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, Union
from datetime import datetime, timezone, timedelta, date
import zoneinfo
import psycopg2
from psycopg2 import pool as psycopg2_pool
from jose import JWTError, jwt
import uvicorn
from dotenv import load_dotenv
import time
import uuid
from src.recommendation_api import (
TourRecommendationRequest,
TourRecommendationResponse,
get_tour_recommendations
)
load_dotenv()
# Import the project-internal modules. When any of them cannot be imported the
# module still loads, but every project-provided name is replaced by an inert
# placeholder so the failure surfaces per request instead of at import time.
try:
    from src.config import DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM
    from src.database import conn_pool
    from src.graph_builder import graph_app
    from src.embedding import embedding_model
    from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
    from src.tools import extract_entities_tool, search_tours_tool
except ImportError as e:
    print(f"Error importing from src: {e}. Using placeholders. API will likely fail at runtime until this is fixed.")
    DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM = [None]*9
    conn_pool = None
    graph_app = None
    embedding_model = None

    class HumanMessage:
        """Minimal stand-in exposing only the `content` attribute."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in exposing only the `content` attribute."""

        def __init__(self, content):
            self.content = content

    BaseMessage = Dict

    def _tools_unavailable(*args, **kwargs):
        """Placeholder for the src.tools callables: always answers 503."""
        raise HTTPException(status_code=503, detail="Tour search tools not initialized. Check src.tools.")

    # Without these two names the tour-search endpoint would crash with a
    # NameError; a callable that raises 503 matches how the other
    # uninitialized services in this file are reported.
    extract_entities_tool = _tools_unavailable
    search_tours_tool = _tools_unavailable
# FastAPI application object; the route decorators in this file register on it.
app = FastAPI(title="Travel Chatbot, Tour Search & Recommendation API", version="1.0.0")

# HTTP Bearer security scheme; get_current_user depends on it to obtain the token.
reusable_oauth2 = HTTPBearer(scheme_name="Bearer")
# Request body accepted by POST /api/embed.
class EmbeddingRequest(BaseModel):
    # Text to embed: either a single string or a list of strings.
    text: Union[str, List[str]]
# Response body returned by POST /api/embed.
class EmbeddingResponse(BaseModel):
    # Vectors as returned by embedding_model.get_embedding.
    embeddings: List[List[float]]
    # Filled from embedding_model.model_name by the endpoint.
    model: str
    # Filled with the length of the first returned vector by the endpoint.
    dimensions: int
# Request body accepted by POST /api/tours/search.
class TourSearchRequest(BaseModel):
    # Required free-text query.
    query: str = Field(
        default=...,
        description="The search query for tours, e.g., 'tôi muốn đi đà nẵng'",
    )
    # 1-based page index; values below 1 are rejected by validation.
    page: int = Field(
        default=1,
        ge=1,
        description="Current page number for pagination",
    )
    # Page size, validated to the range 1..100.
    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Number of items per page",
    )
# One tour entry in a search result page. search_tours_api builds these from
# the dicts returned by search_tours_tool, reading the dict key of the same
# name for every field.
class TourSummary(BaseModel):
    tour_id: Any
    title: str
    duration: Optional[str] = None
    departure_location: Optional[str] = None
    # Always a list of strings when present; the endpoint wraps scalars in a list.
    destination: Optional[List[str]] = None
    region: Optional[str] = None
    itinerary: Optional[str] = None
    max_participants: Optional[int] = None
    departure_id: Optional[Any] = None
    start_date: Optional[Union[datetime, date]] = None
    price_adult: Optional[float] = None
    # NOTE(review): the numeric suffixes presumably denote a child height band — confirm.
    price_child_120_140: Optional[float] = None
    price_child_100_120: Optional[float] = None
    promotion_name: Optional[str] = None
    promotion_type: Optional[str] = None
    # Passed through from the search result without conversion.
    promotion_discount: Optional[Any] = None
# Response body of POST /api/tours/search: one page of tours plus paging metadata.
class PaginatedTourResponse(BaseModel):
    # Echo of the requested page number.
    currentPage: int
    # Echo of the requested page size.
    itemsPerPage: int
    # Number of tours found before pagination.
    totalItems: int
    # totalItems divided by itemsPerPage, rounded up (0 when nothing was found).
    totalPages: int
    # True when currentPage < totalPages.
    hasNextPage: bool
    # True when currentPage > 1.
    hasPrevPage: bool
    # The tours on the current page.
    tours: List[TourSummary]
# Container for the user id carried by a token.
# NOTE(review): not referenced by any code visible in this file — confirm it is still needed.
class TokenData(BaseModel):
    user_id: Optional[int] = None
async def get_current_user(token: HTTPAuthorizationCredentials = Depends(reusable_oauth2)) -> int:
    """Resolve the caller's user id from the bearer JWT of the request.

    The token is decoded with JWT_SECRET_KEY and ALGORITHM. The user id is read
    from the "id" claim, falling back to the "userId" claim.

    Args:
        token: Bearer credentials extracted by the HTTPBearer scheme.

    Returns:
        The user id claim exactly as stored in the token.

    Raises:
        HTTPException: 401 when the token cannot be decoded or carries neither
            an "id" nor a "userId" claim.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode call sits inside the try: raising the 401 for a missing
    # claim from within it would be re-caught by the broad handler below and
    # logged as an "unexpected error during JWT decoding".
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        print(f"JWTError: {e}")
        raise credentials_exception from e
    except Exception as e:
        print(f"An unexpected error occurred during JWT decoding: {e}")
        raise credentials_exception from e

    user_id = payload.get("id")
    if user_id is None:
        user_id = payload.get("userId")
    if user_id is None:
        print(f"JWT payload does not contain 'id' or 'userId' field. Payload: {payload}")
        raise credentials_exception
    return user_id
# Request body accepted by POST /api/chat/.
class ChatMessageInput(BaseModel):
    # Required user message.
    message: str = Field(
        default=...,
        description="The text message sent by the user to the chatbot.",
    )
    # Optional session key; the chat endpoint generates one when it is absent.
    session_id: Optional[str] = Field(
        default=None,
        description="An optional identifier for a specific chat session.",
    )
# Response body returned by POST /api/chat/.
class ChatResponseOutput(BaseModel):
    user_id: int = Field(..., description="The ID of the user (from JWT token).")
    response: str = Field(..., description="The chatbot's generated textual response.")
    # The chat endpoint always fills this in: it generates a UUID when the
    # request carried no session_id, so the description must not promise
    # "mirrored only".
    session_id: Optional[str] = Field(None, description="The session identifier: mirrored from the input, or newly generated when none was provided.")
    # The chat endpoint stamps responses with Asia/Bangkok time, not UTC.
    timestamp: datetime = Field(..., description="Timestamp (Asia/Bangkok timezone) of when the response was generated.")
def get_db_connection():
    """FastAPI dependency that lends out one pooled database connection.

    Yields:
        A connection taken from conn_pool. It is handed back to the pool once
        the generator is resumed or closed.

    Raises:
        HTTPException: 503 when conn_pool is None.
    """
    if conn_pool is None:
        print("conn_pool is None in get_db_connection. Database module likely not initialized.")
        raise HTTPException(status_code=503, detail="Database connection pool not initialized. Check src.database and .env configuration.")

    # Bind the name before the try: if getconn() itself raises, the finally
    # clause must not fail on an unbound local and hide the real error.
    conn = None
    try:
        conn = conn_pool.getconn()
        yield conn
    finally:
        if conn is not None:
            conn_pool.putconn(conn)
def fetch_conversation_history(db_conn, user_id: int, session_id: Optional[str] = None) -> List[BaseMessage]:
    """Load a user's earlier chat turns as an ordered list of messages.

    Rows are read from ChatbotHistory ordered by interaction_time ascending.
    Each row contributes a HumanMessage (its `message` column) followed by an
    AIMessage (its `response` column); empty columns are skipped.

    Args:
        db_conn: Open database connection.
        user_id: Owner of the history.
        session_id: Session to load. When falsy, only rows whose session_id
            is NULL are returned.

    Returns:
        The messages read so far. This is best effort: on any database error
        the error is printed and whatever was collected (possibly nothing) is
        returned.
    """
    # `import psycopg2` alone does not load the `extras` submodule, so
    # `psycopg2.extras.DictCursor` only resolves if some other module happened
    # to import it first. Import it explicitly here.
    from psycopg2.extras import DictCursor

    if session_id:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id = %s ORDER BY interaction_time ASC"
        params = (user_id, session_id)
    else:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id IS NULL ORDER BY interaction_time ASC"
        params = (user_id,)

    history: List[BaseMessage] = []
    try:
        with db_conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(query, params)
            for record in cursor.fetchall():
                if record["message"]:
                    history.append(HumanMessage(content=record["message"]))
                if record["response"]:
                    history.append(AIMessage(content=record["response"]))
    except Exception as e:
        print(f"Error fetching conversation history for user_id {user_id}, session_id {session_id}: {e}")
        # A failed statement leaves the transaction aborted; roll back so
        # later statements on this connection (e.g. saving the new turn) can
        # still run. The rollback is itself best effort.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Error rolling back after history fetch failure: {rollback_error}")
    return history
def save_interaction_to_history(db_conn, user_id: int, user_message: str, chatbot_response: str, session_id: Optional[str] = None):
    """Insert one user/chatbot exchange into ChatbotHistory and commit.

    The row is stamped with the current Asia/Bangkok time. This is best
    effort: any error is printed and the transaction is rolled back; nothing
    is raised to the caller for a failed insert.

    Args:
        db_conn: Open database connection.
        user_id: Owner of the exchange.
        user_message: Text the user sent.
        chatbot_response: Text the chatbot answered.
        session_id: Session the exchange belongs to, or None.
    """
    insert_sql = "INSERT INTO ChatbotHistory (user_id, message, response, interaction_time, session_id) VALUES (%s, %s, %s, %s, %s)"
    try:
        with db_conn.cursor() as cur:
            bangkok_now = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok"))
            row = (user_id, user_message, chatbot_response, bangkok_now, session_id)
            cur.execute(insert_sql, row)
        db_conn.commit()
    except Exception as err:
        print(f"Error saving interaction to history for user_id {user_id}, session_id {session_id}: {err}")
        db_conn.rollback()
@app.on_event("startup")
async def startup_event():
    """Load the embedding model at application startup, if one is configured.

    A failure while loading is printed and does not abort startup.
    """
    try:
        if not embedding_model:
            return
        embedding_model.load_model()
    except Exception as exc:
        print(f"Failed to load embedding model on startup: {str(exc)}")
@app.get(
    "/",
    include_in_schema=False,
    response_class=PlainTextResponse,
)
async def root():
    """Return a fixed plain-text message confirming that the API is up."""
    status_text = "API is running"
    return status_text
@app.post("/api/chat/", response_model=ChatResponseOutput, tags=["Chat"], summary="Chat with the Travel Chatbot")
async def chat_endpoint(payload: ChatMessageInput, current_user_id: int = Depends(get_current_user), db_conn = Depends(get_db_connection)):
    """Send one message to the travel chatbot and receive its reply.

    Earlier messages of the same session are replayed to the chatbot as
    context. When no `session_id` is supplied a new one is generated and
    returned; send it back on later requests to continue the conversation.
    """
    # NOTE(review): graph_app.invoke and the history helpers are synchronous
    # calls made from an async handler — confirm that blocking the event loop
    # for the duration of a chatbot run is acceptable here.
    if graph_app is None:
        print("graph_app is None in chat_endpoint. Graph_builder module likely not initialized.")
        raise HTTPException(status_code=503, detail="Chatbot graph not initialized. Check src.graph_builder.")

    user_id = current_user_id
    user_message_content = payload.message
    session_id = payload.session_id or str(uuid.uuid4())

    history = fetch_conversation_history(db_conn, user_id, session_id)
    all_messages = history + [HumanMessage(content=user_message_content)]
    inputs = {
        "messages": all_messages,
        "user_query": user_message_content,
        "current_date": None,
        "available_locations": None,
        "extracted_entities": None,
        "search_results": None,
        "final_response": None,
        "error": None,
        "routing_decision": None
    }

    full_response_content = ""
    try:
        result = graph_app.invoke(inputs)
        if isinstance(result, dict):
            # `final_response` is seeded as None in the inputs, so the key can
            # be present yet empty. Testing only for the key would then skip
            # the message fallback, so test the value instead.
            full_response_content = result.get("final_response") or ""
            if not full_response_content and result.get("messages"):
                last_message = result["messages"][-1]
                if isinstance(last_message, AIMessage):
                    full_response_content = last_message.content
        if not full_response_content:
            full_response_content = "Sorry, I could not process your request at this moment."
    except Exception as e:
        print(f"Error during graph invocation for user_id {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing message with chatbot: {str(e)}") from e

    save_interaction_to_history(db_conn, user_id, user_message_content, full_response_content, session_id)
    return ChatResponseOutput(
        user_id=user_id,
        response=full_response_content,
        session_id=session_id,
        timestamp=datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok"))
    )
def _empty_tour_page(request: TourSearchRequest) -> PaginatedTourResponse:
    """Build the response used whenever a search yields no tours."""
    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=0,
        totalPages=0,
        hasNextPage=False,
        hasPrevPage=False,
        tours=[]
    )


def _cast_optional(value, cast):
    """Apply *cast* to *value*, passing None through unchanged."""
    return None if value is None else cast(value)


def _normalize_destination(destination):
    """Coerce a raw destination value to a list of strings, or None when empty."""
    if destination is None:
        return None
    if isinstance(destination, (list, tuple)):
        return [str(dest) for dest in destination]
    # A falsy scalar (e.g. "") carries no destination; returning it as-is
    # would fail validation of the List[str] field.
    return [str(destination)] if destination else None


def _to_tour_summary(tour_data) -> TourSummary:
    """Convert one raw search-result dict into a TourSummary."""
    return TourSummary(
        tour_id=tour_data.get("tour_id"),
        # `or ""` keeps an explicit None title from becoming the text "None".
        title=str(tour_data.get("title") or ""),
        duration=_cast_optional(tour_data.get("duration"), str),
        departure_location=_cast_optional(tour_data.get("departure_location"), str),
        destination=_normalize_destination(tour_data.get("destination")),
        region=_cast_optional(tour_data.get("region"), str),
        itinerary=_cast_optional(tour_data.get("itinerary"), str),
        max_participants=_cast_optional(tour_data.get("max_participants"), int),
        departure_id=tour_data.get("departure_id"),
        start_date=tour_data.get("start_date"),
        price_adult=_cast_optional(tour_data.get("price_adult"), float),
        price_child_120_140=_cast_optional(tour_data.get("price_child_120_140"), float),
        price_child_100_120=_cast_optional(tour_data.get("price_child_100_120"), float),
        promotion_name=_cast_optional(tour_data.get("promotion_name"), str),
        promotion_type=_cast_optional(tour_data.get("promotion_type"), str),
        promotion_discount=tour_data.get("promotion_discount")
    )


@app.post("/api/tours/search", response_model=PaginatedTourResponse, tags=["Tours"], summary="Search for tours based on user query")
async def search_tours_api(request: TourSearchRequest, current_user_id: int = Depends(get_current_user)):
    """Search tours from a free-text query and return one page of results.

    An empty page (`totalItems` = 0) is returned when nothing could be
    extracted from the query or no tour matches.
    """
    # `current_user_id` is unused in the body; the dependency is what enforces
    # authentication for this route.
    # Use the Asia/Bangkok calendar date, as the chat path does for its
    # timestamps, instead of the server's local date.
    current_date_str = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")).strftime('%Y-%m-%d')

    entities = extract_entities_tool(user_query=request.query, current_date_str=current_date_str)
    if not entities or (isinstance(entities, dict) and entities.get("error")):
        return _empty_tour_page(request)

    all_found_tours = search_tours_tool(entities)
    if not all_found_tours:
        return _empty_tour_page(request)

    total_items = len(all_found_tours)
    # Ceiling division: a partially filled last page still counts as a page.
    total_pages = (total_items + request.limit - 1) // request.limit
    start_index = (request.page - 1) * request.limit
    page_of_tours = all_found_tours[start_index:start_index + request.limit]

    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=total_items,
        totalPages=total_pages,
        hasNextPage=(request.page < total_pages),
        hasPrevPage=(request.page > 1),
        tours=[_to_tour_summary(tour_data) for tour_data in page_of_tours]
    )
@app.get("/api/recommendations", response_model=TourRecommendationResponse, tags=["Recommendations"], summary="Get tour recommendations")
async def get_recommendations(
    user_id: Optional[int] = None,
    tour_id: Optional[int] = None,
    limit: int = 3
):
    """Get tour recommendations.

    `user_id` and `tour_id` are optional and are passed to the recommender as
    given. A `limit` outside the range 1-10 is replaced by 3.
    """
    try:
        if limit < 1 or limit > 10:
            limit = 3
        return get_tour_recommendations(user_id, tour_id, limit)
    except HTTPException:
        # Let a deliberate HTTP error from the recommender keep its own
        # status code instead of being rewritten as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}") from e
@app.post("/api/embed", response_model=EmbeddingResponse, tags=["Embeddings"], summary="Generate text embeddings")
async def get_embedding(request: EmbeddingRequest):
    """Generate embeddings for a string or a list of strings.

    An empty list yields an empty `embeddings` array with `dimensions` = 0.
    """
    if embedding_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Embedding service not initialized. Check src.embedding module."
        )
    try:
        # An empty list passes request validation but leaves nothing to embed.
        # Answer directly rather than indexing embeddings[0] on an empty
        # result, which would surface as a 500.
        if isinstance(request.text, list) and not request.text:
            return {
                "embeddings": [],
                "model": embedding_model.model_name,
                "dimensions": 0
            }
        embeddings = embedding_model.get_embedding(request.text)
        return {
            "embeddings": embeddings,
            "model": embedding_model.model_name,
            "dimensions": len(embeddings[0])
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}"
        ) from e