File size: 14,682 Bytes
a126f1c
 
 
a9c15a1
a126f1c
 
2ea1d2a
dd3dd80
04dc4d1
a126f1c
 
 
 
 
2ea1d2a
510be07
035fa04
 
 
 
 
a126f1c
 
 
 
dd3dd80
a126f1c
 
2ea1d2a
a126f1c
dd3dd80
a126f1c
 
dd3dd80
a126f1c
 
2ea1d2a
a126f1c
 
 
 
 
 
 
 
 
5698bd2
a126f1c
 
 
 
 
 
 
2ea1d2a
 
 
 
 
 
 
 
dd3dd80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a126f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510be07
a126f1c
 
 
510be07
 
 
 
 
 
 
 
 
 
a126f1c
 
 
 
 
 
 
510be07
a126f1c
 
510be07
a126f1c
 
 
510be07
04dc4d1
a126f1c
 
 
510be07
a126f1c
 
5698bd2
 
 
 
 
 
 
 
 
 
 
 
 
 
a126f1c
 
 
 
 
 
 
510be07
 
 
 
a126f1c
510be07
a126f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510be07
a126f1c
 
 
 
510be07
04dc4d1
2ea1d2a
 
5698bd2
dd3dd80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5698bd2
 
 
 
 
 
 
 
 
 
 
 
 
2ea1d2a
5698bd2
2ea1d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5698bd2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
import os
import sys
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.responses import PlainTextResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, Union
from datetime import datetime, timezone, timedelta, date
import zoneinfo
import psycopg2
from psycopg2 import pool as psycopg2_pool
from jose import JWTError, jwt
import uvicorn
from dotenv import load_dotenv
import time
import uuid
from src.recommendation_api import (
    TourRecommendationRequest,
    TourRecommendationResponse,
    get_tour_recommendations
)

# Load variables from a .env file into the process environment before the
# src.config / src.database imports in the try block below run.
# NOTE(review): src.recommendation_api is imported above this call, so it is
# loaded before the .env values are applied - confirm that is intended.
load_dotenv()

# Import the project modules this API depends on. If any of them is missing,
# the module still loads with inert placeholders: the failure is reported once
# here and the affected endpoints fail at request time instead of at import.
try:
    from src.config import DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM
    from src.database import conn_pool
    from src.graph_builder import graph_app
    from src.embedding import embedding_model
    from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
    from src.tools import extract_entities_tool, search_tours_tool
except ImportError as e:
    print(f"Error importing from src: {e}. Using placeholders. API will likely fail at runtime until this is fixed.")
    DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM = [None]*9
    conn_pool = None
    graph_app = None
    embedding_model = None

    class HumanMessage:
        """Minimal stand-in that only stores the message ``content``."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in that only stores the message ``content``."""

        def __init__(self, content):
            self.content = content

    BaseMessage = Dict

    def _tool_unavailable(*_args, **_kwargs):
        """Placeholder for the src.tools callables; always answers 503."""
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Tour search tools not initialized. Check src.tools.",
        )

    # These two names used to stay undefined in this branch, so the tour
    # search endpoint crashed with NameError instead of reporting the outage.
    extract_entities_tool = _tool_unavailable
    search_tours_tool = _tool_unavailable

# ASGI application object; every route in this module is registered on it.
app = FastAPI(title="Travel Chatbot, Tour Search & Recommendation API", version="1.0.0")

# Bearer-token security scheme; get_current_user receives its credentials
# through Depends.
reusable_oauth2 = HTTPBearer(scheme_name="Bearer")

class EmbeddingRequest(BaseModel):
    """Request body for ``POST /api/embed``."""

    # Either one string or a list of strings; the endpoint hands it unchanged
    # to embedding_model.get_embedding.
    text: Union[str, List[str]]

class EmbeddingResponse(BaseModel):
    """Response body for ``POST /api/embed``."""

    # Value returned by embedding_model.get_embedding for the request text.
    embeddings: List[List[float]]
    # Taken from embedding_model.model_name.
    model: str
    # Length of the first vector in ``embeddings``.
    dimensions: int

class TourSearchRequest(BaseModel):
    """Request body for ``POST /api/tours/search``: a free-text query plus paging."""

    # Required free-text query; the endpoint feeds it to extract_entities_tool.
    query: str = Field(..., description="The search query for tours, e.g., 'tôi muốn đi đà nẵng'")
    # 1-based page index (validated >= 1), defaults to the first page.
    page: int = Field(1, ge=1, description="Current page number for pagination")
    # Page size, validated to the range 1..100, default 10.
    limit: int = Field(10, ge=1, le=100, description="Number of items per page")

class TourSummary(BaseModel):
    """One tour entry inside a ``PaginatedTourResponse``.

    search_tours_api builds these from the dicts returned by
    search_tours_tool; every field declared with ``= None`` may be absent.
    """

    # Identifier passed through from the search result without conversion.
    tour_id: Any
    title: str
    duration: Optional[str] = None
    departure_location: Optional[str] = None
    # Always a list of strings once normalized by search_tours_api.
    destination: Optional[List[str]] = None
    region: Optional[str] = None
    itinerary: Optional[str] = None
    max_participants: Optional[int] = None
    # Passed through from the search result without conversion.
    departure_id: Optional[Any] = None
    start_date: Optional[Union[datetime, date]] = None
    # Prices are converted with float() by search_tours_api.
    # NOTE(review): currency/units and the meaning of the 120_140 / 100_120
    # suffixes are not stated anywhere in this file.
    price_adult: Optional[float] = None
    price_child_120_140: Optional[float] = None
    price_child_100_120: Optional[float] = None
    promotion_name: Optional[str] = None
    promotion_type: Optional[str] = None
    # Passed through from the search result without conversion.
    promotion_discount: Optional[Any] = None

class PaginatedTourResponse(BaseModel):
    """Response body for ``POST /api/tours/search``: one page of tours plus paging metadata."""

    # Echo of the requested page number.
    currentPage: int
    # Echo of the requested page size.
    itemsPerPage: int
    # Number of tours found before pagination was applied.
    totalItems: int
    # totalItems divided by itemsPerPage, rounded up (0 when nothing matched).
    totalPages: int
    # True when currentPage < totalPages.
    hasNextPage: bool
    # True when currentPage > 1.
    hasPrevPage: bool
    # The slice of results belonging to the current page.
    tours: List[TourSummary]

class TokenData(BaseModel):
    """Container for an optional integer user id.

    NOTE(review): nothing in the visible part of this file references this
    model - confirm whether it is still needed.
    """

    user_id: Optional[int] = None

async def get_current_user(token: HTTPAuthorizationCredentials = Depends(reusable_oauth2)) -> int:
    """Resolve the caller's user id from the bearer JWT.

    The token is verified with ``JWT_SECRET_KEY`` / ``ALGORITHM`` and the id is
    read from the ``id`` claim, falling back to ``userId``.

    Args:
        token: Bearer credentials supplied by the ``reusable_oauth2`` scheme.

    Returns:
        The user id claim exactly as stored in the token (it is not coerced).

    Raises:
        HTTPException: 401 when the token cannot be decoded or carries
            neither claim.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode step is guarded. The missing-claim 401 below used to be
    # raised inside this try block and was then caught by the generic handler,
    # which logged it as an "unexpected error".
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        print(f"JWTError: {e}")
        raise credentials_exception
    except Exception as e:
        print(f"An unexpected error occurred during JWT decoding: {e}")
        raise credentials_exception

    user_id: Optional[int] = payload.get("id")
    if user_id is None:
        user_id = payload.get("userId")

    if user_id is None:
        # Log claim names only; claim values may contain personal data.
        print(f"JWT payload does not contain 'id' or 'userId' field. Claims present: {list(payload.keys())}")
        raise credentials_exception

    return user_id

class ChatMessageInput(BaseModel):
    """Request body for ``POST /api/chat/``."""

    message: str = Field(..., description="The text message sent by the user to the chatbot.")
    # When omitted, chat_endpoint generates a fresh UUID4 session id.
    session_id: Optional[str] = Field(None, description="An optional identifier for a specific chat session.")

class ChatResponseOutput(BaseModel):
    """Response body for ``POST /api/chat/``."""

    user_id: int = Field(..., description="The ID of the user (from JWT token).")
    response: str = Field(..., description="The chatbot's generated textual response.")
    # NOTE(review): chat_endpoint always fills this in (it generates an id when
    # the request has none), so it is not only "mirrored" input.
    session_id: Optional[str] = Field(None, description="The session identifier, mirrored if provided in input.")
    # NOTE(review): chat_endpoint sets this with an Asia/Bangkok-zoned
    # datetime, not a UTC one as the description states.
    timestamp: datetime = Field(..., description="UTC timestamp of when the response was generated.")

def get_db_connection():
    """Yield a pooled database connection and return it to the pool afterwards.

    Written as a generator so it can be used as a ``Depends`` dependency: the
    code after ``yield`` runs once the consumer is done with the connection.

    Yields:
        A connection checked out from ``conn_pool``.

    Raises:
        HTTPException: 503 when ``conn_pool`` was never initialized.
        Exception: whatever ``conn_pool.getconn()`` raises is propagated
            unchanged.
    """
    if conn_pool is None:
        print("conn_pool is None in get_db_connection. Database module likely not initialized.")
        raise HTTPException(status_code=503, detail="Database connection pool not initialized. Check src.database and .env configuration.")

    # Check out the connection before entering the try block. Previously a
    # failing getconn() left `conn` unbound, so the finally clause raised
    # UnboundLocalError and hid the real pool error.
    conn = conn_pool.getconn()
    try:
        yield conn
    finally:
        if conn is not None:
            conn_pool.putconn(conn)

def fetch_conversation_history(db_conn, user_id: int, session_id: Optional[str] = None) -> List[BaseMessage]:
    """Load a user's earlier chat turns, oldest first, as message objects.

    Each ``ChatbotHistory`` row contributes a ``HumanMessage`` for a non-empty
    ``message`` and an ``AIMessage`` for a non-empty ``response``.

    Args:
        db_conn: Open database connection.
        user_id: Owner of the history rows.
        session_id: Restrict to this session; when falsy, only rows whose
            ``session_id`` is NULL are read.

    Returns:
        The messages collected so far. Failures are logged rather than raised,
        so the list may be empty or partial.
    """
    if session_id:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id = %s ORDER BY interaction_time ASC"
        params = (user_id, session_id)
    else:
        query = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id IS NULL ORDER BY interaction_time ASC"
        params = (user_id,)

    history: List[BaseMessage] = []
    try:
        # `import psycopg2` alone does not load the `extras` submodule. The old
        # `psycopg2.extras.DictCursor` lookup raised AttributeError unless some
        # other module had imported it, and the handler below turned that into
        # a silently empty history.
        from psycopg2.extras import DictCursor

        with db_conn.cursor(cursor_factory=DictCursor) as cursor:
            cursor.execute(query, params)
            for record in cursor.fetchall():
                if record["message"]:
                    history.append(HumanMessage(content=record["message"]))
                if record["response"]:
                    history.append(AIMessage(content=record["response"]))
    except Exception as e:
        print(f"Error fetching conversation history for user_id {user_id}, session_id {session_id}: {e}")
        # A failed statement leaves the transaction aborted; without a rollback
        # every later statement on this connection (such as saving the new
        # interaction) would fail as well.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after failed history fetch also failed: {rollback_error}")
    return history

def save_interaction_to_history(db_conn, user_id: int, user_message: str, chatbot_response: str, session_id: Optional[str] = None):
    """Insert one question/answer pair into ``ChatbotHistory`` and commit.

    The row is stamped with the current time in the Asia/Bangkok zone. Any
    failure is printed and the transaction is rolled back; nothing is raised
    to the caller unless the rollback itself fails.

    Args:
        db_conn: Open database connection.
        user_id: Owner of the interaction.
        user_message: Text the user sent.
        chatbot_response: Text the chatbot answered with.
        session_id: Session the interaction belongs to, or None.
    """
    insert_sql = "INSERT INTO ChatbotHistory (user_id, message, response, interaction_time, session_id) VALUES (%s, %s, %s, %s, %s)"
    try:
        with db_conn.cursor() as cur:
            recorded_at = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok"))
            row = (user_id, user_message, chatbot_response, recorded_at, session_id)
            cur.execute(insert_sql, row)
        db_conn.commit()
    except Exception as e:
        print(f"Error saving interaction to history for user_id {user_id}, session_id {session_id}: {e}")
        db_conn.rollback()


@app.on_event("startup")
async def startup_event():
    """Load the embedding model at application startup, if one is configured.

    A loading failure is printed and otherwise ignored, so the application
    still starts.
    """
    try:
        if not embedding_model:
            return
        embedding_model.load_model()
    except Exception as load_error:
        print(f"Failed to load embedding model on startup: {str(load_error)}")

@app.get("/", include_in_schema=False, response_class=PlainTextResponse)
async def root():
    """Return a fixed plain-text message; the route is excluded from the OpenAPI schema."""
    return "API is running"

@app.post("/api/chat/", response_model=ChatResponseOutput, tags=["Chat"], summary="Chat with the Travel Chatbot")
async def chat_endpoint(payload: ChatMessageInput, current_user_id: int = Depends(get_current_user), db_conn = Depends(get_db_connection)):
    """Send a message to the chatbot and return its reply.

    The stored history of the session is loaded, the chatbot graph is run on
    the history plus the new message, and the exchange is saved before the
    reply is returned. A new session id is generated when none is supplied.

    Raises:
        HTTPException: 503 when the chatbot graph is not initialized, 500 when
            running the graph fails.
    """
    # Local import: asyncio is not among this module's top-level imports.
    import asyncio

    if graph_app is None:
        print("graph_app is None in chat_endpoint. Graph_builder module likely not initialized.")
        raise HTTPException(status_code=503, detail="Chatbot graph not initialized. Check src.graph_builder.")

    user_id = current_user_id
    user_message_content = payload.message
    session_id = payload.session_id or str(uuid.uuid4())

    history = fetch_conversation_history(db_conn, user_id, session_id)
    all_messages = history + [HumanMessage(content=user_message_content)]

    inputs = {
        "messages": all_messages,
        "user_query": user_message_content,
        "current_date": None,
        "available_locations": None,
        "extracted_entities": None,
        "search_results": None,
        "final_response": None,
        "error": None,
        "routing_decision": None
    }

    full_response_content = ""
    try:
        # graph_app.invoke is a blocking call; running it in a worker thread
        # keeps the event loop free to serve other requests meanwhile.
        result = await asyncio.to_thread(graph_app.invoke, inputs)

        if isinstance(result, dict):
            # "final_response" is seeded as None in `inputs`, so the key can be
            # present with an empty value. Test the value, not the key, so
            # that the last AI message is still used as a fallback.
            full_response_content = result.get("final_response") or ""
            if not full_response_content and result.get("messages"):
                last_message = result["messages"][-1]
                if isinstance(last_message, AIMessage):
                    full_response_content = last_message.content

        if not full_response_content:
            full_response_content = "Sorry, I could not process your request at this moment."

    except Exception as e:
        print(f"Error during graph invocation for user_id {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing message with chatbot: {str(e)}") from e

    save_interaction_to_history(db_conn, user_id, user_message_content, full_response_content, session_id)

    # NOTE(review): ChatResponseOutput describes this timestamp as UTC, but an
    # Asia/Bangkok-zoned value is produced here - confirm which one is wanted.
    return ChatResponseOutput(
        user_id=user_id,
        response=full_response_content,
        session_id=session_id,
        timestamp=datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok"))
    )

def _cast_or_none(value, caster):
    """Return ``caster(value)``, or None when *value* is None."""
    return None if value is None else caster(value)


def _normalize_destination(raw):
    """Coerce a raw destination value into a list of strings (or None)."""
    if raw is None:
        return None
    if isinstance(raw, (list, tuple, set)):
        return [str(item) for item in raw]
    # A bare scalar becomes a one-element list; an empty scalar means "unknown".
    return [str(raw)] if raw else None


def _empty_tour_page(page: int, limit: int) -> PaginatedTourResponse:
    """Build the response used whenever a search yields no tours."""
    return PaginatedTourResponse(
        currentPage=page,
        itemsPerPage=limit,
        totalItems=0,
        totalPages=0,
        hasNextPage=False,
        hasPrevPage=False,
        tours=[]
    )


@app.post("/api/tours/search", response_model=PaginatedTourResponse, tags=["Tours"], summary="Search for tours based on user query")
async def search_tours_api(request: TourSearchRequest, current_user_id: int = Depends(get_current_user)):
    """Search tours from a free-text query and return one page of results.

    Entities are extracted from the query, handed to the tour search tool, and
    the resulting list is paginated in memory. An empty page is returned when
    extraction fails or nothing matches.
    """
    # Local import: asyncio is not among this module's top-level imports.
    import asyncio

    # NOTE(review): this is the server-local date, while the chat code stamps
    # times in Asia/Bangkok - confirm which calendar day the tools expect.
    current_date_str = date.today().strftime('%Y-%m-%d')

    # Both tools are blocking calls; run them off the event loop.
    entities = await asyncio.to_thread(
        extract_entities_tool, user_query=request.query, current_date_str=current_date_str
    )
    if not entities or (isinstance(entities, dict) and entities.get("error")):
        return _empty_tour_page(request.page, request.limit)

    all_found_tours = await asyncio.to_thread(search_tours_tool, entities)
    if not all_found_tours:
        return _empty_tour_page(request.page, request.limit)

    total_items = len(all_found_tours)
    total_pages = (total_items + request.limit - 1) // request.limit  # ceiling division

    start_index = (request.page - 1) * request.limit
    page_rows = all_found_tours[start_index:start_index + request.limit]

    tour_summaries = [
        TourSummary(
            tour_id=row.get("tour_id"),
            # A None title used to be rendered as the literal string "None".
            title=_cast_or_none(row.get("title"), str) or "",
            duration=_cast_or_none(row.get("duration"), str),
            departure_location=_cast_or_none(row.get("departure_location"), str),
            destination=_normalize_destination(row.get("destination")),
            region=_cast_or_none(row.get("region"), str),
            itinerary=_cast_or_none(row.get("itinerary"), str),
            max_participants=_cast_or_none(row.get("max_participants"), int),
            departure_id=row.get("departure_id"),
            start_date=row.get("start_date"),
            price_adult=_cast_or_none(row.get("price_adult"), float),
            price_child_120_140=_cast_or_none(row.get("price_child_120_140"), float),
            price_child_100_120=_cast_or_none(row.get("price_child_100_120"), float),
            promotion_name=_cast_or_none(row.get("promotion_name"), str),
            promotion_type=_cast_or_none(row.get("promotion_type"), str),
            promotion_discount=row.get("promotion_discount")
        )
        for row in page_rows
    ]

    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=total_items,
        totalPages=total_pages,
        hasNextPage=(request.page < total_pages),
        hasPrevPage=(request.page > 1),
        tours=tour_summaries
    )

@app.get("/api/recommendations", response_model=TourRecommendationResponse, tags=["Recommendations"], summary="Get tour recommendations")
async def get_recommendations(
    user_id: Optional[int] = None,
    tour_id: Optional[int] = None,
    limit: int = 3
):
    """Return tour recommendations for a user and/or a tour.

    ``limit`` values outside 1..10 are replaced by the default of 3. The work
    is delegated to ``get_tour_recommendations``.

    Raises:
        HTTPException: passed through unchanged when the delegate raises one;
            any other failure is reported as 500.
    """
    if not 1 <= limit <= 10:
        limit = 3

    try:
        return get_tour_recommendations(user_id, tour_id, limit)
    except HTTPException:
        # Keep a deliberate status code from the delegate rather than letting
        # the generic handler below rewrite it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}") from e

@app.post("/api/embed", response_model=EmbeddingResponse, tags=["Embeddings"], summary="Generate text embeddings")
async def get_embedding(request: EmbeddingRequest):
    """Generate embeddings for the supplied text.

    Returns the vectors produced by ``embedding_model.get_embedding``, the
    model name, and the length of the first vector (0 when none came back).

    Raises:
        HTTPException: 503 when the embedding model is not initialized, 500
            when embedding generation fails.
    """
    if embedding_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Embedding service not initialized. Check src.embedding module."
        )

    try:
        embeddings = embedding_model.get_embedding(request.text)

        # An empty result (for example from an empty input list) has no first
        # vector; report 0 dimensions rather than failing with IndexError.
        dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0

        return {
            "embeddings": embeddings,
            "model": embedding_model.model_name,
            "dimensions": dimensions
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}"
        ) from e