from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse, Response
from fastapi.exceptions import HTTPException
from fastapi.background import BackgroundTasks
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
from typing import Dict, List
from prometheus_client import Counter, Histogram, start_http_server
from pydantic import BaseModel, ValidationError
from app.services.message import generate_reply, send_reply
import logging
import httpx
from datetime import datetime
from sentence_transformers import SentenceTransformer
from app.search.rag_pipeline import RAGSystem
from contextlib import asynccontextmanager
# from app.db.database import create_indexes, init_db
from app.services.webhook_handler import verify_webhook
from app.handlers.message_handler import MessageHandler
from app.handlers.webhook_handler import WebhookHandler
from app.handlers.media_handler import WhatsAppMediaHandler
from app.services.cache import MessageCache
from app.services.chat_manager import ChatManager
from app.api.api_prompt import prompt_router
from app.api.api_file import file_router, load_file_with_markdown_function
from app.utils.load_env import ACCESS_TOKEN, WHATSAPP_API_URL, GEMINI_API
from markitdown import MarkItDown

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Handlers are initialized at startup (see lifespan below)
message_handler = None
webhook_handler = None

indexed_links = [
    "https://sswalfa.surabaya.go.id/info/detail/izin-pengumpulan-sumbangan-bencana",
    "https://sswalfa.surabaya.go.id/info/detail/izin-pemakaian-ruang-terbuka-hijau",
    "https://sswalfa.surabaya.go.id/info/detail/pengganti-ipt",
    "https://sswalfa.surabaya.go.id/info/detail/arahan-sistem-drainase",
    "https://sswalfa.surabaya.go.id/info/detail/rangkaian-pelayanan-surat-pernyataan-belum-menikah-lagi-bagi-jandaduda",
]


async def setup_message_handler():
    logger = logging.getLogger(__name__)
    message_cache = MessageCache()
    chat_manager = ChatManager()
    media_handler = WhatsAppMediaHandler()

    return MessageHandler(
        message_cache=message_cache,
        chat_manager=chat_manager,
        media_handler=media_handler,
        logger=logger,
    )


async def setup_rag_system():
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # Replace with your model if different
    rag_system = RAGSystem(embedding_model)
    return rag_system


# Application lifespan: initialize the RAG system and handlers at startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        # await init_db()
        logger.info("Connected to the MongoDB database!")

        rag_system = await setup_rag_system()
        app.state.rag_system = rag_system

        global message_handler, webhook_handler
        message_handler = await setup_message_handler()
        webhook_handler = WebhookHandler(message_handler)

        # collections = app.database.list_collection_names()
        # print(f"Collections in {db_name}: {collections}")

        await load_file_with_markdown_function(rag_system=rag_system, filepaths=indexed_links)
        yield
    except Exception as e:
        logger.error(e)
        # Re-raise so startup failures are not silently swallowed; the lifespan
        # context manager must yield exactly once for the app to start.
        raise


# Initialize Limiter and Prometheus Metrics
limiter = Limiter(key_func=get_remote_address)
app = FastAPI(lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Add SlowAPI Middleware
app.add_middleware(SlowAPIMiddleware)

# app.include_router(users.router, prefix="/users", tags=["Users"])
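# Illustrative sketch, not part of the original service: with the Limiter and
# SlowAPIMiddleware configured above, a per-route limit can be attached with the
# @limiter.limit decorator (compare the commented-out "20/minute" example on the
# /webhook route below). The path, handler name, and limit here are assumptions,
# kept commented out so they do not change the app's behavior:
#
# @app.get("/health")
# @limiter.limit("5/minute")
# async def health_check(request: Request):  # slowapi requires the Request parameter
#     return {"status": "ok"}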
app.include_router(prompt_router, prefix="/prompts", tags=["Prompts"])
app.include_router(file_router, prefix="/file_load", tags=["File Load"])

# Prometheus metrics
webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
webhook_processing_time = Histogram('webhook_processing_seconds', 'Time spent processing webhook')
# NOTE: these metrics are declared but not yet updated anywhere; they can be wired
# into the webhook handler with webhook_requests.inc() and webhook_processing_time.time().

# Start Prometheus metrics server on port 8002
# start_http_server(8002)

# Register webhook routes
# app.post("/webhook")(webhook)


# Define Pydantic schema for request validation
class WebhookPayload(BaseModel):
    entry: List[Dict]


@app.post("/webhook")
# @limiter.limit("20/minute")
async def webhook(request: Request, background_tasks: BackgroundTasks):
    try:
        payload = await request.json()
        rag_system = request.app.state.rag_system
        # validated_payload = WebhookPayload(**payload)  # Validate payload
        # logger.info(f"Validated Payload: {validated_payload}")

        # Process the webhook payload here
        # For example:
        # results = process_webhook_entries(validated_payload.entry)

        # e.g., whatsapp_token, verify_token, llm_api_key, llm_model
        whatsapp_token = request.query_params.get("whatsapp_token")
        whatsapp_url = request.query_params.get("whatsapp_url")
        gemini_api = request.query_params.get("gemini_api")
        llm_model = request.query_params.get("cx_code")

        # Return HTTP 200 immediately
        # response = JSONResponse(
        #     content={"status": "received"},
        #     status_code=200
        # )
        print(f"payload: {payload}")

        # response = await webhook_handler.process_webhook(
        #     payload=payload,
        #     whatsapp_token=ACCESS_TOKEN,
        #     whatsapp_url=WHATSAPP_API_URL,
        #     gemini_api=GEMINI_API,
        #     rag_system=rag_system,
        # )

        # Add the processing to background tasks
        background_tasks.add_task(
            webhook_handler.process_webhook,
            payload=payload,
            whatsapp_token=ACCESS_TOKEN,
            whatsapp_url=WHATSAPP_API_URL,
            gemini_api=GEMINI_API,
            rag_system=rag_system,
        )

        # Return HTTP 200 immediately
        return JSONResponse(
            content={"status": "received"},
            status_code=status.HTTP_200_OK
        )

        # return JSONResponse(
        #     content=response.__dict__,
        #     status_code=status.HTTP_200_OK
        # )

    except ValidationError as ve:
        logger.error(f"Validation error: {ve}")
        return JSONResponse(
            content={"status": "error", "detail": ve.errors()},
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return JSONResponse(
            content={"status": "error", "detail": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )


app.get("/webhook")(verify_webhook)


@app.post("/load_file")
async def load_file_with_markitdown(file_path: str, llm_client: str = None, model: str = None):
    if llm_client and model:
        # Pass the LLM settings by keyword so they map to MarkItDown's
        # llm_client/llm_model parameters.
        markitdown = MarkItDown(llm_client=llm_client, llm_model=model)
    else:
        markitdown = MarkItDown()
    documents = markitdown.convert(file_path)
    print(f"documents: {documents}")
    return documents


# Add a route for Prometheus metrics (optional, if not using a separate Prometheus server)
@app.get("/metrics")
async def metrics():
    from prometheus_client import generate_latest
    return Response(content=generate_latest(), media_type="text/plain")


# In-memory cache with timestamp cleanup
# class MessageCache:
#     def __init__(self, max_age_hours: int = 24):
#         self.messages: Dict[str, float] = {}
#         self.max_age_seconds = max_age_hours * 3600

#     def add(self, message_id: str) -> None:
#         self.cleanup()
#         self.messages[message_id] = time.time()

#     def exists(self, message_id: str) -> bool:
#         self.cleanup()
#         return message_id in self.messages

#     def cleanup(self) -> None:
#         current_time = time.time()
#         self.messages = {
#             msg_id: timestamp
#             for msg_id, timestamp in self.messages.items()
#             if current_time - timestamp < self.max_age_seconds
#         }

# message_cache = MessageCache()
# user_chats = {}


# @app.post("/webhook")
# async def webhook(request: Request):
#     request_id = f"req_{int(time.time()*1000)}"
#     logger.info(f"Processing webhook request {request_id}")

#     payload = await request.json()
#     print("Webhook received:", payload)

#     processed_count = 0
#     error_count = 0
#     results = []

#     entries = payload.get("entry", [])
#     for entry in entries:
#         entry_id = entry.get("id")
#         logger.info(f"Processing entry_id: {entry_id}")

#         changes = entry.get("changes", [])
#         for change in changes:
#             messages = change.get("value", {}).get("messages", [])

#             for message in messages:
#                 message_id = message.get("id")
#                 timestamp = message.get("timestamp")
#                 content = message.get("text", {}).get("body")
#                 sender_id = message.get("from")
#                 msg_type = message.get('type')

#                 # Deduplicate messages based on message_id
#                 if message_cache.exists(message_id):
#                     logger.info(f"Duplicate message detected and skipped: {message_id}")
#                     continue

#                 if sender_id not in user_chats:
#                     user_chats[sender_id] = []

#                 user_chats[sender_id].append({
#                     "role": "user",
#                     "content": content
#                 })

#                 history = "".join([f"{item['role']}: {item['content']}\n" for item in user_chats[sender_id]])
#                 print(f"history: {history}")

#                 try:
#                     # Process message with retry logic
#                     result = await process_message_with_retry(
#                         sender_id, content,
#                         history,
#                         timestamp,
#                     )

#                     user_chats[sender_id].append({
#                         "role": "assistant",
#                         "content": result
#                     })

#                     # Add the message ID to the cache
#                     message_cache.add(message_id)
#                     processed_count += 1
#                     results.append(result)

#                 except Exception as e:
#                     error_count += 1
#                     logger.error(
#                         f"Failed to process message {message_id}: {str(e)}",
#                         exc_info=True
#                     )
#                     results.append({
#                         "status": "error",
#                         "message_id": message_id,
#                         "error": str(e)
#                     })

#     response_data = {
#         "request_id": request_id,
#         "processed": processed_count,
#         "errors": error_count,
#         "results": results
#     }

#     logger.info(
#         f"Webhook processing completed - "
#         f"Processed: {processed_count}, Errors: {error_count}"
#     )

#     return JSONResponse(
#         content=response_data,
#         status_code=status.HTTP_200_OK
#     )


# @app.get("/webhook")
# async def verify_webhook(request: Request):
#     mode = request.query_params.get('hub.mode')
#     token = request.query_params.get('hub.verify_token')
#     challenge = request.query_params.get('hub.challenge')

#     # Replace 'your_verification_token' with the token you set in Facebook Business Manager
#     if mode == 'subscribe' and token == 'test':
#         # Return the challenge as plain text
#         return Response(content=challenge, media_type="text/plain")
#     else:
#         raise HTTPException(status_code=403, detail="Verification failed")
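

# Local development entry point. A minimal sketch, assuming this module lives at
# app/main.py and uvicorn is installed; adjust the import string if the module
# path differs. Running the file as a script serves the FastAPI app defined above.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)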