|
import os |
|
import sys |
|
from fastapi import FastAPI, HTTPException, Depends, status |
|
from fastapi.responses import PlainTextResponse |
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
from pydantic import BaseModel, Field |
|
from typing import Optional, List, Dict, Any, Union |
|
from datetime import datetime, timezone, timedelta, date |
|
import zoneinfo |
|
import psycopg2 |
|
from psycopg2 import pool as psycopg2_pool |
|
from jose import JWTError, jwt |
|
import uvicorn |
|
from dotenv import load_dotenv |
|
import time |
|
import uuid |
|
from src.recommendation_api import ( |
|
TourRecommendationRequest, |
|
TourRecommendationResponse, |
|
get_tour_recommendations |
|
) |
|
|
|
# Load key/value pairs from a local .env file (when one exists) into os.environ.
# override=False is python-dotenv's default: variables already set in the real
# environment take precedence over the file. This runs before the `src`
# imports below, presumably so src.config can read the values — TODO confirm.
load_dotenv(override=False)
|
|
|
try:
    from src.config import (
        DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID,
        GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM,
    )
    from src.database import conn_pool
    from src.graph_builder import graph_app
    from src.embedding import embedding_model
    from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
    from src.tools import extract_entities_tool, search_tours_tool
except ImportError as e:
    # Degraded mode: keep the module importable so the app can still start.
    # Several code paths check these placeholders for None and answer 503.
    print(
        f"Error importing from src: {e}. Using placeholders. API will likely fail at runtime until this is fixed.",
        file=sys.stderr,
    )
    (DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID,
     GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM) = [None] * 9
    conn_pool = None
    graph_app = None
    embedding_model = None
    # The tool functions need placeholders too; otherwise any reference to
    # them (e.g. from the tour search endpoint) raises NameError.
    extract_entities_tool = None
    search_tours_tool = None

    class HumanMessage:
        """Minimal stand-in for langchain_core's HumanMessage (content only)."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in for langchain_core's AIMessage (content only)."""

        def __init__(self, content):
            self.content = content

    BaseMessage = Dict
|
|
|
# The ASGI application object; title and version appear in the generated OpenAPI docs.
app = FastAPI(version="1.0.0", title="Travel Chatbot, Tour Search & Recommendation API")

# Extracts "Authorization: Bearer <token>" credentials for get_current_user.
# scheme_name is the name this security scheme is registered under in OpenAPI.
reusable_oauth2 = HTTPBearer(scheme_name="Bearer")
|
|
|
class EmbeddingRequest(BaseModel):
    # Body of POST /api/embed: either one string or a batch of strings.
    text: Union[str, list[str]]
|
|
|
class EmbeddingResponse(BaseModel):
    # Body returned by POST /api/embed.
    embeddings: list[list[float]]  # presumably one vector per input text
    model: str  # filled from embedding_model.model_name
    dimensions: int  # length of the first vector
|
|
|
class TourSearchRequest(BaseModel):
    # Body of POST /api/tours/search: a free-text query plus 1-based pagination.
    query: str = Field(
        default=...,
        description="The search query for tours, e.g., 'tôi muốn đi đà nẵng'",
    )
    page: int = Field(default=1, ge=1, description="Current page number for pagination")
    limit: int = Field(default=10, ge=1, le=100, description="Number of items per page")
|
|
|
class TourSummary(BaseModel):
    # One tour entry of PaginatedTourResponse.tours. Only tour_id and title are
    # required; every other field defaults to None.

    # Tour identity and description.
    tour_id: Any
    title: str
    duration: Union[str, None] = None
    departure_location: Union[str, None] = None
    destination: Union[list[str], None] = None
    region: Union[str, None] = None
    itinerary: Union[str, None] = None
    max_participants: Union[int, None] = None

    # Departure details.
    departure_id: Union[Any, None] = None
    start_date: Union[datetime, date, None] = None

    # Prices. The child suffixes presumably denote height brackets in cm
    # (120-140, 100-120) — TODO confirm against the pricing schema.
    price_adult: Union[float, None] = None
    price_child_120_140: Union[float, None] = None
    price_child_100_120: Union[float, None] = None

    # Promotion attached to the tour/departure, if any.
    promotion_name: Union[str, None] = None
    promotion_type: Union[str, None] = None
    promotion_discount: Union[Any, None] = None
|
|
|
class PaginatedTourResponse(BaseModel):
    # One page of tour search results plus pagination metadata.
    # The camelCase field names are emitted as-is as the JSON keys.
    currentPage: int  # page number that was requested (1-based)
    itemsPerPage: int  # page size that was requested
    totalItems: int  # number of matches across all pages
    totalPages: int
    hasNextPage: bool
    hasPrevPage: bool
    tours: list[TourSummary]
|
|
|
class TokenData(BaseModel):
    # Holder for a decoded token's user id.
    # NOTE(review): not referenced in the visible code; get_current_user returns a bare int.
    user_id: Union[int, None] = None
|
|
|
def _coerce_user_id(raw: Any) -> Optional[int]:
    """Convert a JWT id claim to an int, or return None if it is not one.

    Ints are accepted as-is (bools are rejected even though they subclass int).
    Strings are accepted when ``int()`` can parse them, e.g. ``"42"``.
    """
    if isinstance(raw, bool):
        return None
    if isinstance(raw, int):
        return raw
    if isinstance(raw, str):
        try:
            return int(raw)
        except ValueError:
            return None
    return None


async def get_current_user(token: HTTPAuthorizationCredentials = Depends(reusable_oauth2)) -> int:
    """FastAPI dependency: verify the bearer token and return the caller's user id.

    The token is decoded with ``JWT_SECRET_KEY`` and ``ALGORITHM``. The user id
    is read from the ``id`` claim, falling back to ``userId``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            cannot be decoded or verified, or carries no usable integer user id.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode step is guarded. The 401s raised further down must not be
    # re-caught by the broad handler and mislogged as "unexpected" errors.
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        print(f"JWTError: {e}")
        raise credentials_exception from e
    except Exception as e:
        print(f"An unexpected error occurred during JWT decoding: {e}")
        raise credentials_exception from e

    if not isinstance(payload, dict):
        print("JWT payload is not a JSON object.")
        raise credentials_exception

    raw_user_id = payload.get("id")
    if raw_user_id is None:
        raw_user_id = payload.get("userId")
    if raw_user_id is None:
        # Log claim names only; claim values may carry personal data.
        print(f"JWT payload does not contain 'id' or 'userId' field. Claims present: {sorted(payload)}")
        raise credentials_exception

    user_id = _coerce_user_id(raw_user_id)
    if user_id is None:
        print(f"JWT user id claim is not an integer (got {type(raw_user_id).__name__}).")
        raise credentials_exception
    return user_id
|
|
|
class ChatMessageInput(BaseModel):
    # Body of POST /api/chat/.
    message: str = Field(
        default=...,
        description="The text message sent by the user to the chatbot.",
    )
    session_id: Union[str, None] = Field(
        default=None,
        description="An optional identifier for a specific chat session.",
    )
|
|
|
class ChatResponseOutput(BaseModel):
    """Response body of POST /api/chat/."""

    user_id: int = Field(..., description="The ID of the user (from JWT token).")
    response: str = Field(..., description="The chatbot's generated textual response.")
    # The chat endpoint generates a new id when the request omits one, so this is
    # always populated there; Optional is kept for backward compatibility.
    session_id: Optional[str] = Field(
        None,
        description="The session identifier: echoed from the request, or newly generated when the request omitted it.",
    )
    # Previously documented as "UTC", but the endpoint emits an aware datetime in
    # the Asia/Bangkok zone; the description now states only what holds.
    timestamp: datetime = Field(..., description="Timezone-aware timestamp of when the response was generated.")
|
|
|
def get_db_connection():
    """FastAPI dependency that lends one pooled database connection per request.

    Yields:
        A connection taken from ``conn_pool``. It is handed back with
        ``putconn`` when the request finishes, even if the endpoint raised.

    Raises:
        HTTPException: 503 when the pool was never initialised.
        Whatever ``conn_pool.getconn()`` raises (presumably a psycopg2 PoolError
        when exhausted) propagates unchanged.
    """
    if conn_pool is None:
        print("conn_pool is None in get_db_connection. Database module likely not initialized.")
        raise HTTPException(status_code=503, detail="Database connection pool not initialized. Check src.database and .env configuration.")

    # Bind before the try: if getconn() raises, the finally clause must not hit
    # an UnboundLocalError that would hide the real error.
    conn = None
    try:
        conn = conn_pool.getconn()
        yield conn
    finally:
        if conn is not None:
            conn_pool.putconn(conn)
|
|
|
def fetch_conversation_history(db_conn, user_id: int, session_id: Optional[str] = None) -> List[BaseMessage]:
    """Load a user's stored chat turns as HumanMessage/AIMessage objects.

    Args:
        db_conn: Open psycopg2 connection (a DictCursor is requested from it).
        user_id: Owner of the ChatbotHistory rows.
        session_id: Session to load. When falsy, rows whose session_id IS NULL
            are loaded instead.

    Returns:
        Messages ordered by interaction_time. For each row, a HumanMessage for a
        non-empty ``message`` comes first, then an AIMessage for a non-empty
        ``response``. The load is best effort: on error the problem is logged,
        the transaction is rolled back, and whatever was collected (usually
        nothing) is returned.
    """
    # `import psycopg2` alone does not load the `extras` submodule; import it
    # explicitly rather than relying on some other module having done so.
    from psycopg2.extras import DictCursor

    if session_id:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id = %s ORDER BY interaction_time ASC"
        params = (user_id, session_id)
    else:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id IS NULL ORDER BY interaction_time ASC"
        params = (user_id,)

    history: List[BaseMessage] = []
    try:
        with db_conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(query, params)
            records = cursor.fetchall()
        for record in records:
            if record["message"]:
                history.append(HumanMessage(content=record["message"]))
            if record["response"]:
                history.append(AIMessage(content=record["response"]))
    except Exception as e:
        print(f"Error fetching conversation history for user_id {user_id}, session_id {session_id}: {e}")
        # A failed statement leaves the transaction aborted. Without a rollback,
        # the caller's later INSERT on this same connection would fail as well.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after history fetch failure also failed: {rollback_error}")
    return history
|
|
|
def save_interaction_to_history(db_conn, user_id: int, user_message: str, chatbot_response: str, session_id: Optional[str] = None):
    """Insert one user/chatbot exchange into ChatbotHistory and commit it.

    The row's interaction_time is the current time in the Asia/Bangkok zone.
    The save is best effort: any failure is logged and the transaction rolled
    back, and nothing is raised to the caller, so a storage problem never fails
    the chat request.
    """
    try:
        with db_conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO ChatbotHistory (user_id, message, response, interaction_time, session_id) VALUES (%s, %s, %s, %s, %s)",
                (user_id, user_message, chatbot_response, datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")), session_id),
            )
        db_conn.commit()
    except Exception as e:
        print(f"Error saving interaction to history for user_id {user_id}, session_id {session_id}: {e}")
        # The rollback itself can fail (e.g. the connection dropped). Swallow that
        # too so the function keeps its best-effort, never-raises contract.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after failed history save also failed: {rollback_error}")
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    # Eagerly load the embedding model at boot when the src imports succeeded.
    # Any failure is only printed; application startup continues regardless.
    try:
        model = embedding_model
        if not model:
            return
        model.load_model()
    except Exception as exc:
        print(f"Failed to load embedding model on startup: {exc}")
|
|
|
@app.get("/", include_in_schema=False, response_class=PlainTextResponse)
async def root():
    # Plain-text "is it up?" endpoint; excluded from the OpenAPI schema.
    status_text = "API is running"
    return status_text
|
|
|
def _extract_response_text(result: Any) -> str:
    """Pick the reply text out of the graph's final state.

    Prefers a non-empty string ``final_response``. Otherwise falls back to the
    content of the last message when it is an AIMessage with string content.
    Returns "" when neither yields text.
    """
    if not isinstance(result, dict):
        return ""
    # The state is seeded with final_response=None, so mere key presence is not
    # enough; only a real string answer should win over the message fallback.
    final_response = result.get("final_response")
    if isinstance(final_response, str) and final_response:
        return final_response
    messages = result.get("messages")
    if messages:
        last_message = messages[-1]
        if isinstance(last_message, AIMessage) and isinstance(last_message.content, str):
            return last_message.content
    return ""


@app.post("/api/chat/", response_model=ChatResponseOutput, tags=["Chat"], summary="Chat with the Travel Chatbot")
async def chat_endpoint(payload: ChatMessageInput, current_user_id: int = Depends(get_current_user), db_conn = Depends(get_db_connection)):
    """Send a message to the travel chatbot and receive its reply.

    The stored conversation of the given session is used as context. Omit
    ``session_id`` to start a new session; the response always carries the
    session id to reuse for follow-up messages.
    """
    # graph_app.invoke and the psycopg2-backed helpers are synchronous. Run them
    # in the threadpool so a slow model or database call doesn't stall the event loop.
    from fastapi.concurrency import run_in_threadpool

    if graph_app is None:
        print("graph_app is None in chat_endpoint. Graph_builder module likely not initialized.")
        raise HTTPException(status_code=503, detail="Chatbot graph not initialized. Check src.graph_builder.")

    user_id = current_user_id
    user_message_content = payload.message
    # A missing or empty session id starts a brand-new session.
    session_id = payload.session_id or str(uuid.uuid4())

    history = await run_in_threadpool(fetch_conversation_history, db_conn, user_id, session_id)
    all_messages = history + [HumanMessage(content=user_message_content)]

    # Initial graph state. Keys presumably mirror the state schema defined in
    # src.graph_builder — TODO confirm when changing either side.
    inputs = {
        "messages": all_messages,
        "user_query": user_message_content,
        "current_date": None,
        "available_locations": None,
        "extracted_entities": None,
        "search_results": None,
        "final_response": None,
        "error": None,
        "routing_decision": None,
    }

    try:
        result = await run_in_threadpool(graph_app.invoke, inputs)
        full_response_content = _extract_response_text(result)
    except Exception as e:
        print(f"Error during graph invocation for user_id {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing message with chatbot: {str(e)}") from e

    if not full_response_content:
        full_response_content = "Sorry, I could not process your request at this moment."

    await run_in_threadpool(
        save_interaction_to_history, db_conn, user_id, user_message_content, full_response_content, session_id
    )

    return ChatResponseOutput(
        user_id=user_id,
        response=full_response_content,
        session_id=session_id,
        timestamp=datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")),
    )
|
|
|
def _empty_tour_page(request: TourSearchRequest) -> PaginatedTourResponse:
    """Return a page with no tours that echoes the requested paging parameters."""
    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=0,
        totalPages=0,
        hasNextPage=False,
        hasPrevPage=False,
        tours=[],
    )


def _optional_cast(value: Any, convert) -> Any:
    """Return ``convert(value)``, or None when *value* is None."""
    return None if value is None else convert(value)


def _normalize_destination(value: Any) -> Optional[List[str]]:
    """Coerce a raw destination field to a list of strings.

    Lists and tuples are converted element-wise. Any other truthy value becomes
    a one-element list. Empty scalars (None, "") become None, which the
    TourSummary model accepts, whereas a bare "" used to fail validation.
    """
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    if not value:
        return None
    return [str(value)]


def _to_tour_summary(tour_data: Dict[str, Any]) -> TourSummary:
    """Build a TourSummary from one raw search-result row, normalising types."""
    title = tour_data.get("title")
    return TourSummary(
        tour_id=tour_data.get("tour_id"),
        # Missing *and* NULL titles both become "" (str(None) used to leak "None").
        title="" if title is None else str(title),
        duration=_optional_cast(tour_data.get("duration"), str),
        departure_location=_optional_cast(tour_data.get("departure_location"), str),
        destination=_normalize_destination(tour_data.get("destination")),
        region=_optional_cast(tour_data.get("region"), str),
        itinerary=_optional_cast(tour_data.get("itinerary"), str),
        max_participants=_optional_cast(tour_data.get("max_participants"), int),
        departure_id=tour_data.get("departure_id"),
        start_date=tour_data.get("start_date"),
        price_adult=_optional_cast(tour_data.get("price_adult"), float),
        price_child_120_140=_optional_cast(tour_data.get("price_child_120_140"), float),
        price_child_100_120=_optional_cast(tour_data.get("price_child_100_120"), float),
        promotion_name=_optional_cast(tour_data.get("promotion_name"), str),
        promotion_type=_optional_cast(tour_data.get("promotion_type"), str),
        promotion_discount=tour_data.get("promotion_discount"),
    )


@app.post("/api/tours/search", response_model=PaginatedTourResponse, tags=["Tours"], summary="Search for tours based on user query")
async def search_tours_api(request: TourSearchRequest, current_user_id: int = Depends(get_current_user)):
    """Search tours from a free-text query and return one page of results.

    An empty page is returned when nothing can be extracted from the query or
    no tour matches.
    """
    # The tools are called synchronously; they presumably do blocking I/O (LLM /
    # database), so run them in the threadpool instead of on the event loop.
    from fastapi.concurrency import run_in_threadpool

    # current_user_id is unused; the dependency is there to require a valid token.
    if extract_entities_tool is None or search_tours_tool is None:
        raise HTTPException(status_code=503, detail="Tour search tools not initialized. Check src.tools.")

    # "Today" in the same zone the rest of the API uses for its timestamps,
    # rather than whatever local zone the server happens to run in.
    current_date_str = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")).strftime('%Y-%m-%d')

    entities = await run_in_threadpool(
        extract_entities_tool, user_query=request.query, current_date_str=current_date_str
    )
    if not entities or (isinstance(entities, dict) and entities.get("error")):
        return _empty_tour_page(request)

    all_found_tours = await run_in_threadpool(search_tours_tool, entities)
    if not all_found_tours:
        return _empty_tour_page(request)

    total_items = len(all_found_tours)
    total_pages = (total_items + request.limit - 1) // request.limit  # ceiling division

    start_index = (request.page - 1) * request.limit
    page_rows = all_found_tours[start_index:start_index + request.limit]

    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=total_items,
        totalPages=total_pages,
        hasNextPage=(request.page < total_pages),
        hasPrevPage=(request.page > 1),
        tours=[_to_tour_summary(row) for row in page_rows],
    )
|
|
|
@app.get("/api/recommendations", response_model=TourRecommendationResponse, tags=["Recommendations"], summary="Get tour recommendations")
async def get_recommendations(
    user_id: Optional[int] = None,
    tour_id: Optional[int] = None,
    limit: int = 3
):
    """Return tour recommendations for the given user_id and/or tour_id.

    Both ids are optional. A ``limit`` outside 1-10 falls back to 3.
    """
    # get_tour_recommendations is called synchronously (its result is returned
    # directly); run it in the threadpool so any blocking work stays off the
    # event loop.
    from fastapi.concurrency import run_in_threadpool

    # Out-of-range limits are replaced by the default rather than rejected.
    if not 1 <= limit <= 10:
        limit = 3
    try:
        return await run_in_threadpool(get_tour_recommendations, user_id, tour_id, limit)
    except HTTPException:
        # Keep deliberate HTTP errors (status and detail) instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}") from e
|
|
|
@app.post("/api/embed", response_model=EmbeddingResponse, tags=["Embeddings"], summary="Generate text embeddings")
async def get_embedding(request: EmbeddingRequest):
    """Generate embedding vectors for a string or a batch of strings."""
    # Model inference is a synchronous call; run it in the threadpool so it does
    # not block the event loop for other requests.
    from fastapi.concurrency import run_in_threadpool

    if embedding_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Embedding service not initialized. Check src.embedding module."
        )

    try:
        embeddings = await run_in_threadpool(embedding_model.get_embedding, request.text)
        return {
            "embeddings": embeddings,
            "model": embedding_model.model_name,
            # An empty batch has no first vector to measure. Report 0 instead of
            # failing with IndexError. len() avoids array truthiness ambiguity.
            "dimensions": len(embeddings[0]) if len(embeddings) else 0,
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}"
        ) from e