File size: 14,682 Bytes
a126f1c a9c15a1 a126f1c 2ea1d2a dd3dd80 04dc4d1 a126f1c 2ea1d2a 510be07 035fa04 a126f1c dd3dd80 a126f1c 2ea1d2a a126f1c dd3dd80 a126f1c dd3dd80 a126f1c 2ea1d2a a126f1c 5698bd2 a126f1c 2ea1d2a dd3dd80 a126f1c 510be07 a126f1c 510be07 a126f1c 510be07 a126f1c 510be07 a126f1c 510be07 04dc4d1 a126f1c 510be07 a126f1c 5698bd2 a126f1c 510be07 a126f1c 510be07 a126f1c 510be07 a126f1c 510be07 04dc4d1 2ea1d2a 5698bd2 dd3dd80 5698bd2 2ea1d2a 5698bd2 2ea1d2a 5698bd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 |
import os
import sys
import time
import uuid
import zoneinfo
from datetime import datetime, timezone, timedelta, date
from typing import Optional, List, Dict, Any, Union

import psycopg2
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import PlainTextResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from psycopg2 import pool as psycopg2_pool
from pydantic import BaseModel, Field

from src.recommendation_api import (
    TourRecommendationRequest,
    TourRecommendationResponse,
    get_tour_recommendations,
)
# Load .env first; presumably src.config reads its settings from the environment.
load_dotenv()

# Project modules are imported defensively: if any import fails, the module still
# loads with placeholder values and the endpoints that check for them (None) report
# the problem per request instead of the whole app failing to start.
try:
    from src.config import DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM
    from src.database import conn_pool
    from src.graph_builder import graph_app
    from src.embedding import embedding_model
    from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
    from src.tools import extract_entities_tool, search_tours_tool
except ImportError as e:
    print(f"Error importing from src: {e}. Using placeholders. API will likely fail at runtime until this is fixed.")
    DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM = [None] * 9
    conn_pool = None
    graph_app = None
    embedding_model = None
    # The tour-search tools also need placeholders; without them the search
    # endpoint would fail with a NameError instead of a controlled error.
    extract_entities_tool = None
    search_tours_tool = None

    class HumanMessage:
        """Minimal stand-in for langchain_core.messages.HumanMessage (content only)."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in for langchain_core.messages.AIMessage (content only)."""

        def __init__(self, content):
            self.content = content

    # Only used in type annotations when langchain_core is unavailable.
    BaseMessage = Dict
# Application object; title and version are what FastAPI shows in the generated API docs.
app = FastAPI(title="Travel Chatbot, Tour Search & Recommendation API", version="1.0.0")

# Extracts the "Authorization: Bearer <token>" credentials consumed by get_current_user.
reusable_oauth2 = HTTPBearer(scheme_name="Bearer")
class EmbeddingRequest(BaseModel):
    """Request body for POST /api/embed: one text or a list of texts to embed."""
    text: Union[str, List[str]]
class EmbeddingResponse(BaseModel):
    """Response body for POST /api/embed."""
    # Vectors returned by embedding_model.get_embedding (presumably one per input text).
    embeddings: List[List[float]]
    # Value of embedding_model.model_name.
    model: str
    # Length of the first returned vector.
    dimensions: int
class TourSearchRequest(BaseModel):
    """Request body for POST /api/tours/search: free-text query plus 1-based pagination."""
    query: str = Field(..., description="The search query for tours, e.g., 'tôi muốn đi đà nẵng'")
    # 1-based page index (validated >= 1).
    page: int = Field(1, ge=1, description="Current page number for pagination")
    # Page size, validated to the range 1..100.
    limit: int = Field(10, ge=1, le=100, description="Number of items per page")
class TourSummary(BaseModel):
    """One entry in a tour search result page, built by search_tours_api from a search_tours_tool row."""
    # Identifier type depends on the tour table; not constrained here.
    tour_id: Any
    title: str
    duration: Optional[str] = None
    departure_location: Optional[str] = None
    # search_tours_api normalises this to a list of strings.
    destination: Optional[List[str]] = None
    region: Optional[str] = None
    itinerary: Optional[str] = None
    max_participants: Optional[int] = None
    # NOTE(review): presumably identifies a specific scheduled departure of the tour — confirm.
    departure_id: Optional[Any] = None
    start_date: Optional[Union[datetime, date]] = None
    # Prices are converted with float() by search_tours_api; currency/units not visible here.
    price_adult: Optional[float] = None
    # NOTE(review): names suggest child prices by height band (120-140 / 100-120) — confirm.
    price_child_120_140: Optional[float] = None
    price_child_100_120: Optional[float] = None
    promotion_name: Optional[str] = None
    promotion_type: Optional[str] = None
    # Passed through unchanged; its meaning presumably depends on promotion_type.
    promotion_discount: Optional[Any] = None
class PaginatedTourResponse(BaseModel):
    """Response body for POST /api/tours/search: one page of results plus pagination metadata.

    Field names are camelCase because they are the JSON keys returned to clients.
    """
    # 1-based page number echoed from the request.
    currentPage: int
    # Page size echoed from the request.
    itemsPerPage: int
    # Total number of matching tours across all pages.
    totalItems: int
    # ceil(totalItems / itemsPerPage); 0 when nothing matched.
    totalPages: int
    hasNextPage: bool
    hasPrevPage: bool
    tours: List[TourSummary]
class TokenData(BaseModel):
    """Decoded JWT data holding the caller's user id."""
    # NOTE(review): not referenced by the code in this module; get_current_user returns a bare int.
    user_id: Optional[int] = None
async def get_current_user(token: HTTPAuthorizationCredentials = Depends(reusable_oauth2)) -> int:
    """Validate the request's bearer JWT and return the caller's user id.

    The token is decoded with ``JWT_SECRET_KEY`` and ``ALGORITHM``. The user id is
    read from the ``id`` claim, falling back to ``userId``.

    Args:
        token: Bearer credentials extracted from the Authorization header by
            ``reusable_oauth2``.

    Returns:
        The user id claim converted to ``int``.

    Raises:
        HTTPException: 401 if the token cannot be decoded, has no ``id``/``userId``
            claim, or that claim is not an integer.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        print(f"JWTError: {e}")
        raise credentials_exception from e
    except Exception as e:
        # e.g. presumably a missing key/algorithm when src.config could not be imported.
        print(f"An unexpected error occurred during JWT decoding: {e}")
        raise credentials_exception from e

    # Claim extraction stays outside the try above so that our own 401 is not
    # re-caught by the generic handler and mislogged as "unexpected".
    raw_user_id = payload.get("id")
    if raw_user_id is None:
        raw_user_id = payload.get("userId")
    if raw_user_id is None:
        # Log only the claim names, not their values, to keep token contents out of logs.
        print(f"JWT payload does not contain 'id' or 'userId' field. Claims present: {sorted(payload)}")
        raise credentials_exception
    try:
        # Enforce the declared int contract (ChatResponseOutput.user_id is an int too).
        return int(raw_user_id)
    except (TypeError, ValueError):
        print(f"JWT user id claim is not an integer: {raw_user_id!r}")
        raise credentials_exception from None
class ChatMessageInput(BaseModel):
    """Request body for POST /api/chat/."""
    message: str = Field(..., description="The text message sent by the user to the chatbot.")
    # chat_endpoint generates a new UUID session id when this is missing or empty.
    session_id: Optional[str] = Field(None, description="Identifier of an existing chat session; omit it to start a new session.")


class ChatResponseOutput(BaseModel):
    """Response body for POST /api/chat/."""
    user_id: int = Field(..., description="The ID of the user (from JWT token).")
    response: str = Field(..., description="The chatbot's generated textual response.")
    # chat_endpoint always fills this in: the client's id, or the one it generated.
    session_id: Optional[str] = Field(None, description="The session identifier: mirrored from the input, or newly generated when none was provided.")
    # chat_endpoint builds this with ZoneInfo("Asia/Bangkok"), so it carries a +07:00 offset.
    timestamp: datetime = Field(..., description="Timezone-aware timestamp (Asia/Bangkok, UTC+07:00) of when the response was generated.")
def get_db_connection():
    """Yield a pooled database connection for one request, then return it to the pool.

    Intended for use as a FastAPI dependency (``Depends(get_db_connection)``). The
    connection is handed back with ``conn_pool.putconn`` after the request's
    handler finishes, even if the handler raised.

    Raises:
        HTTPException: 503 if the pool was never initialised, or if no connection
            could be obtained from it (pool exhausted / database unreachable).
    """
    if conn_pool is None:
        print("conn_pool is None in get_db_connection. Database module likely not initialized.")
        raise HTTPException(status_code=503, detail="Database connection pool not initialized. Check src.database and .env configuration.")
    try:
        # Acquire outside the yield/finally: if this fails there is nothing to put back
        # (the original referenced an unbound `conn` here, masking the real error).
        conn = conn_pool.getconn()
    except (psycopg2_pool.PoolError, psycopg2.OperationalError) as e:
        # Assumes conn_pool is a psycopg2 pool: PoolError when exhausted,
        # OperationalError when a new connection cannot be opened.
        print(f"Could not obtain a database connection from the pool: {e}")
        raise HTTPException(status_code=503, detail="Database connection unavailable. Please retry later.") from e
    try:
        yield conn
    finally:
        conn_pool.putconn(conn)
def fetch_conversation_history(db_conn, user_id: int, session_id: Optional[str] = None) -> List[BaseMessage]:
    """Load a user's stored chat turns as an ordered list of Human/AI messages.

    Args:
        db_conn: Open DB-API connection (a psycopg2 connection in this app).
        user_id: Owner of the history rows.
        session_id: Session to load; when falsy, rows with a NULL session_id are loaded.

    Returns:
        For each ChatbotHistory row, ordered by interaction_time, a HumanMessage
        for a non-empty ``message`` followed by an AIMessage for a non-empty
        ``response``. Best effort: an empty list is returned if the query fails.
    """
    if session_id:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id = %s ORDER BY interaction_time ASC"
        params = (user_id, session_id)
    else:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id IS NULL ORDER BY interaction_time ASC"
        params = (user_id,)
    try:
        # A plain cursor yields tuples in SELECT column order. The previous
        # psycopg2.extras.DictCursor relied on a submodule this file never imports;
        # the resulting AttributeError was swallowed below, silently dropping history.
        with db_conn.cursor() as cursor:
            cursor.execute(query, params)
            records = cursor.fetchall()
    except Exception as e:
        print(f"Error fetching conversation history for user_id {user_id}, session_id {session_id}: {e}")
        # A failed statement leaves the transaction aborted; reset it so later
        # statements on this connection (e.g. saving the new turn) can still run.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after failed history fetch also failed: {rollback_error}")
        return []
    history: List[BaseMessage] = []
    for message, response in records:
        if message:
            history.append(HumanMessage(content=message))
        if response:
            history.append(AIMessage(content=response))
    return history
def save_interaction_to_history(db_conn, user_id: int, user_message: str, chatbot_response: str, session_id: Optional[str] = None):
    """Insert one chat turn (user message and bot reply) into ChatbotHistory and commit.

    The row's interaction_time is the current time in the Asia/Bangkok timezone.

    Best effort: any failure is printed and the transaction rolled back; nothing is
    raised, so a history-write problem never fails the caller's request.

    Args:
        db_conn: Open DB-API connection (a psycopg2 connection in this app).
        user_id: Owner of the row.
        user_message: Text the user sent.
        chatbot_response: Reply returned to the user.
        session_id: Chat session the turn belongs to (may be None).
    """
    try:
        interaction_time = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok"))
        with db_conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO ChatbotHistory (user_id, message, response, interaction_time, session_id) VALUES (%s, %s, %s, %s, %s)",
                (user_id, user_message, chatbot_response, interaction_time, session_id),
            )
        db_conn.commit()
    except Exception as e:
        print(f"Error saving interaction to history for user_id {user_id}, session_id {session_id}: {e}")
        # rollback() itself raises if the connection is already closed/broken;
        # don't let that escape and turn this best-effort write into a request failure.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after failed history save also failed: {rollback_error}")
@app.on_event("startup")
async def startup_event():
    """Load the embedding model eagerly at application startup, when one was imported.

    Any error is printed rather than raised, so the API still starts without it.
    """
    try:
        model = embedding_model
        if not model:
            return
        model.load_model()
    except Exception as exc:
        print(f"Failed to load embedding model on startup: {str(exc)}")
@app.get("/", include_in_schema=False, response_class=PlainTextResponse)
async def root():
    """Root route (hidden from the OpenAPI schema): reply with a fixed plain-text status."""
    status_message = "API is running"
    return status_message
def _extract_response_text(result: Any) -> str:
    """Return the chatbot reply from the graph's final state, or "" if there is none.

    Prefers a non-empty ``final_response``; otherwise uses the content of the last
    entry in ``messages`` when it is an AIMessage.
    """
    if not isinstance(result, dict):
        return ""
    final_response = result.get("final_response")
    # Checked by value, not key presence: the graph is seeded with
    # final_response=None, so a key-presence test would skip the fallback below.
    if final_response:
        return final_response
    messages = result.get("messages")
    if messages:
        last_message = messages[-1]
        if isinstance(last_message, AIMessage):
            return last_message.content
    return ""


@app.post("/api/chat/", response_model=ChatResponseOutput, tags=["Chat"], summary="Chat with the Travel Chatbot")
async def chat_endpoint(payload: ChatMessageInput, current_user_id: int = Depends(get_current_user), db_conn = Depends(get_db_connection)):
    """Run one chat turn for the authenticated user.

    Loads the session's previous turns, invokes the chatbot graph with them plus
    the new message, stores the turn, and returns the reply. A new UUID session
    id is generated when the request carries none.

    Raises:
        HTTPException: 503 if the chatbot graph is not initialised; 500 if the
            graph invocation fails.
    """
    if graph_app is None:
        print("graph_app is None in chat_endpoint. Graph_builder module likely not initialized.")
        raise HTTPException(status_code=503, detail="Chatbot graph not initialized. Check src.graph_builder.")
    user_id = current_user_id
    user_message_content = payload.message
    session_id = payload.session_id
    is_new_session = not session_id
    if is_new_session:
        session_id = str(uuid.uuid4())

    # psycopg2 and graph_app.invoke are synchronous; calling them directly in this
    # async handler would block the event loop (and every other request) while
    # they run, so each blocking call is pushed to the threadpool.
    if is_new_session:
        # A freshly generated session id cannot have stored turns; skip the query.
        history = []
    else:
        history = await run_in_threadpool(fetch_conversation_history, db_conn, user_id, session_id)

    current_message = HumanMessage(content=user_message_content)
    inputs = {
        "messages": history + [current_message],
        "user_query": user_message_content,
        "current_date": None,
        "available_locations": None,
        "extracted_entities": None,
        "search_results": None,
        "final_response": None,
        "error": None,
        "routing_decision": None,
    }
    try:
        result = await run_in_threadpool(graph_app.invoke, inputs)
        full_response_content = _extract_response_text(result) or "Sorry, I could not process your request at this moment."
    except Exception as e:
        print(f"Error during graph invocation for user_id {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing message with chatbot: {str(e)}") from e

    await run_in_threadpool(save_interaction_to_history, db_conn, user_id, user_message_content, full_response_content, session_id)
    return ChatResponseOutput(
        user_id=user_id,
        response=full_response_content,
        session_id=session_id,
        timestamp=datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")),
    )
def _empty_tour_page(request: TourSearchRequest) -> PaginatedTourResponse:
    """Build the zero-result page returned when the query matches nothing."""
    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=0,
        totalPages=0,
        hasNextPage=False,
        hasPrevPage=False,
        tours=[],
    )


def _cast_or_none(value: Any, cast):
    """Return ``cast(value)``, or None when ``value`` is None."""
    return None if value is None else cast(value)


def _as_str_list(value: Any) -> Optional[List[str]]:
    """Normalise a destination field to a list of strings.

    None and falsy scalars (e.g. "") become None, since a bare "" would fail
    TourSummary validation. Lists/tuples are converted element-wise, and any
    other scalar is wrapped in a one-element list.
    """
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    if not value:
        return None
    return [str(value)]


def _to_tour_summary(tour_data: Dict[str, Any]) -> TourSummary:
    """Convert one search_tours_tool result row into a TourSummary."""
    get = tour_data.get
    title = get("title")
    return TourSummary(
        tour_id=get("tour_id"),
        # A NULL title must not be rendered as the literal string "None".
        title="" if title is None else str(title),
        duration=_cast_or_none(get("duration"), str),
        departure_location=_cast_or_none(get("departure_location"), str),
        destination=_as_str_list(get("destination")),
        region=_cast_or_none(get("region"), str),
        itinerary=_cast_or_none(get("itinerary"), str),
        max_participants=_cast_or_none(get("max_participants"), int),
        departure_id=get("departure_id"),
        start_date=get("start_date"),
        price_adult=_cast_or_none(get("price_adult"), float),
        price_child_120_140=_cast_or_none(get("price_child_120_140"), float),
        price_child_100_120=_cast_or_none(get("price_child_100_120"), float),
        promotion_name=_cast_or_none(get("promotion_name"), str),
        promotion_type=_cast_or_none(get("promotion_type"), str),
        promotion_discount=get("promotion_discount"),
    )


@app.post("/api/tours/search", response_model=PaginatedTourResponse, tags=["Tours"], summary="Search for tours based on user query")
async def search_tours_api(request: TourSearchRequest, current_user_id: int = Depends(get_current_user)):
    """Search tours from a free-text query and return one page of results.

    Entities are extracted from the query with extract_entities_tool, matching
    tours are fetched with search_tours_tool, and the full result list is
    paginated in memory. Requires a valid bearer token (get_current_user).

    Raises:
        HTTPException: 503 if the search tools are unavailable.
    """
    # Report missing tooling as 503, like chat_endpoint/get_embedding do for their services.
    if extract_entities_tool is None or search_tours_tool is None:
        raise HTTPException(status_code=503, detail="Tour search tools not initialized. Check src.tools.")
    # Use the same Asia/Bangkok clock as the rest of this module rather than the
    # server's local timezone, so "today" is the users' today.
    current_date_str = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")).date().strftime('%Y-%m-%d')

    # Both tools are synchronous; run them off the event loop.
    entities = await run_in_threadpool(extract_entities_tool, user_query=request.query, current_date_str=current_date_str)
    if not entities or (isinstance(entities, dict) and entities.get("error")):
        return _empty_tour_page(request)
    all_found_tours = await run_in_threadpool(search_tours_tool, entities)
    if not all_found_tours:
        return _empty_tour_page(request)

    total_items = len(all_found_tours)
    total_pages = (total_items + request.limit - 1) // request.limit  # ceiling division
    start_index = (request.page - 1) * request.limit
    page_rows = all_found_tours[start_index:start_index + request.limit]
    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=total_items,
        totalPages=total_pages,
        hasNextPage=request.page < total_pages,
        hasPrevPage=request.page > 1,
        tours=[_to_tour_summary(tour_data) for tour_data in page_rows],
    )
@app.get("/api/recommendations", response_model=TourRecommendationResponse, tags=["Recommendations"], summary="Get tour recommendations")
async def get_recommendations(
    user_id: Optional[int] = None,
    tour_id: Optional[int] = None,
    limit: int = 3
):
    """Return up to ``limit`` tour recommendations for a user and/or a tour.

    Args:
        user_id: Optional user to personalise for.
        tour_id: Optional tour to find similar tours for.
        limit: Number of recommendations; values outside 1..10 fall back to 3.

    Raises:
        HTTPException: propagated unchanged if raised by the recommender;
            500 for any other failure.
    """
    # Out-of-range limits silently fall back to the default instead of a 422.
    if not 1 <= limit <= 10:
        limit = 3
    try:
        # get_tour_recommendations is synchronous; keep it off the event loop.
        return await run_in_threadpool(get_tour_recommendations, user_id, tour_id, limit)
    except HTTPException:
        # Preserve intentional status codes instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}") from e
@app.post("/api/embed", response_model=EmbeddingResponse, tags=["Embeddings"], summary="Generate text embeddings")
async def get_embedding(request: EmbeddingRequest):
    """Embed the request's text(s) with the loaded embedding model.

    Returns:
        A dict matching EmbeddingResponse: the vectors, the model name, and the
        vector length (0 if no vectors were produced).

    Raises:
        HTTPException: 503 if the embedding model is not initialised; 500 if
            embedding fails.
    """
    if embedding_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Embedding service not initialized. Check src.embedding module."
        )
    try:
        # get_embedding is synchronous (presumably model inference); run it off the event loop.
        embeddings = await run_in_threadpool(embedding_model.get_embedding, request.text)
        # len() rather than truthiness so an array-like result also works; with no
        # vectors (presumably an empty input list) report 0 instead of an IndexError.
        dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0
        return {
            "embeddings": embeddings,
            "model": embedding_model.model_name,
            "dimensions": dimensions,
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}"
        ) from e