from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import json
import logging
from datetime import datetime
from email.utils import parsedate_to_datetime

# Import our modules
from scraper import fetch_hazard_tweets, fetch_custom_tweets, get_available_hazards, get_available_locations
from classifier import classify_tweets
from pg_db import init_db, upsert_hazardous_tweet

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Ocean Hazard Detection API",
    description="API for detecting ocean hazards from social media posts",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this properly for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize database
try:
    init_db()
    logger.info("Database initialized successfully")
except Exception as e:
    logger.warning(f"Database initialization failed: {e}. API will work without database persistence.")

# Pydantic models
class TweetAnalysisRequest(BaseModel):
    limit: int = 20
    query: Optional[str] = None
    hazard_type: Optional[str] = None
    location: Optional[str] = None
    days_back: int = 1


class TweetAnalysisResponse(BaseModel):
    total_tweets: int
    hazardous_tweets: int
    results: List[dict]
    processing_time: float


class HealthResponse(BaseModel):
    status: str
    message: str
    timestamp: str

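# Illustrative request/response payloads for the models above (example values,
# not taken from the original source):
#
#   TweetAnalysisRequest body:
#     {"limit": 20, "query": null, "hazard_type": "tsunami",
#      "location": "Chennai", "days_back": 1}
#
#   TweetAnalysisResponse returned by the analysis endpoint:
#     {"total_tweets": 20, "hazardous_tweets": 3, "results": [...],
#      "processing_time": 4.2}
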
# Health check endpoint (route paths below are assumed)
@app.get("/")
def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        message="Ocean Hazard Detection API is running",
        timestamp=datetime.utcnow().isoformat()
    )


@app.get("/health")
def health():
    """Alternative health check endpoint"""
    return health_check()

@app.post("/warmup")  # route path assumed
async def warmup_models():
    """Pre-load all models to reduce first request time"""
    try:
        logger.info("Starting model warmup...")

        # Pre-load all models
        from classifier import get_classifier
        from ner import get_ner_pipeline
        from sentiment import get_emotion_classifier
        from translate import get_translator

        classifier = get_classifier()
        ner = get_ner_pipeline()
        emotion_clf = get_emotion_classifier()
        translator = get_translator()

        # Test with sample data
        test_text = "Test tweet for model warmup"
        classifier(test_text, ["test", "not test"])
        if ner:
            ner(test_text)
        emotion_clf(test_text)
        translator(test_text)

        logger.info("Model warmup completed successfully")
        return {"status": "success", "message": "All models loaded and ready"}
    except Exception as e:
        logger.error(f"Model warmup failed: {str(e)}")
        return {"status": "error", "message": str(e)}

# Main analysis endpoint
@app.post("/analyze")  # route path assumed
async def analyze_tweets(request: TweetAnalysisRequest):
    """
    Analyze tweets for ocean hazards

    - **limit**: Number of tweets to analyze (1-50)
    - **query**: Custom search query (optional)
    - **hazard_type**: Hazard keyword to search for (optional)
    - **location**: Location keyword to search for (optional)
    - **days_back**: How many days back to search (default: 1)
    """
    start_time = datetime.utcnow()
    try:
        logger.info(f"Starting analysis with limit: {request.limit}")

        # Fetch tweets based on search type
        if request.query:
            # Use custom query if provided
            from scraper import search_tweets, extract_tweets
            result = search_tweets(request.query, limit=request.limit)
            tweets = extract_tweets(result)
        elif request.hazard_type or request.location:
            # Use keyword-based search
            tweets = fetch_custom_tweets(
                hazard_type=request.hazard_type,
                location=request.location,
                limit=request.limit,
                days_back=request.days_back
            )
        else:
            # Use default hazard query
            tweets = fetch_hazard_tweets(limit=request.limit)

        logger.info(f"Fetched {len(tweets)} tweets")

        # Classify tweets
        results = classify_tweets(tweets)
        logger.info(f"Classified {len(results)} tweets")

        # Store hazardous tweets in database
        hazardous_count = 0
        try:
            for r in results:
                if r.get('hazardous') == 1:
                    hazardous_count += 1
                    hazards = (r.get('ner') or {}).get('hazards') or []
                    hazard_type = ", ".join(hazards) if hazards else "unknown"
                    locs = (r.get('ner') or {}).get('locations') or []
                    if not locs and r.get('location'):
                        locs = [r['location']]
                    location = ", ".join(locs) if locs else "unknown"
                    sentiment = r.get('sentiment') or {"label": "unknown", "score": 0.0}

                    # Parse the tweet timestamp: try RFC 2822 first, then ISO 8601
                    created_at = r.get('created_at') or ""
                    tweet_date = ""
                    tweet_time = ""
                    if created_at:
                        dt = None
                        try:
                            dt = parsedate_to_datetime(created_at)
                        except Exception:
                            dt = None
                        if dt is None and 'T' in created_at:
                            try:
                                iso = created_at.replace('Z', '+00:00')
                                dt = datetime.fromisoformat(iso)
                            except Exception:
                                dt = None
                        if dt is not None:
                            tweet_date = dt.date().isoformat()
                            tweet_time = dt.time().strftime('%H:%M:%S')

                    upsert_hazardous_tweet(
                        tweet_url=r.get('tweet_url') or "",
                        hazard_type=hazard_type,
                        location=location,
                        sentiment_label=sentiment.get('label', 'unknown'),
                        sentiment_score=float(sentiment.get('score', 0.0)),
                        tweet_date=tweet_date,
                        tweet_time=tweet_time,
                    )
            logger.info(f"Stored {hazardous_count} hazardous tweets in database")
        except Exception as db_error:
            logger.warning(f"Database storage failed: {db_error}. Results will not be persisted.")

        # Calculate processing time
        processing_time = (datetime.utcnow() - start_time).total_seconds()

        return TweetAnalysisResponse(
            total_tweets=len(results),
            hazardous_tweets=hazardous_count,
            results=results,
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Analysis failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

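# Example client call for the analysis endpoint above (the "/analyze" path is the
# assumed route and the payload values are illustrative):
#
#   curl -X POST http://localhost:8000/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"limit": 10, "hazard_type": "cyclone", "location": "Mumbai"}'
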
# Get stored hazardous tweets
@app.get("/hazardous-tweets")  # route path assumed
async def get_hazardous_tweets(limit: int = 100, offset: int = 0):
    """
    Get stored hazardous tweets from the database

    - **limit**: Maximum number of tweets to return (default: 100)
    - **offset**: Number of tweets to skip (default: 0)
    """
    try:
        from pg_db import get_conn
        with get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT tweet_url, hazard_type, location, sentiment_label,
                           sentiment_score, tweet_date, tweet_time, inserted_at
                    FROM hazardous_tweets
                    ORDER BY inserted_at DESC
                    LIMIT %s OFFSET %s
                """, (limit, offset))
                columns = [desc[0] for desc in cur.description]
                results = [dict(zip(columns, row)) for row in cur.fetchall()]

        return {
            "tweets": results,
            "count": len(results),
            "limit": limit,
            "offset": offset
        }
    except Exception as e:
        logger.error(f"Failed to fetch hazardous tweets: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

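# Example: page through stored hazardous tweets 50 at a time (path assumed as above):
#
#   curl "http://localhost:8000/hazardous-tweets?limit=50&offset=0"
#   curl "http://localhost:8000/hazardous-tweets?limit=50&offset=50"
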
# Get available keywords
@app.get("/keywords/hazards")  # route path assumed
async def get_hazard_keywords():
    """Get available hazard types for keyword search"""
    hazards = get_available_hazards()
    return {
        "hazards": hazards,
        "count": len(hazards)
    }


@app.get("/keywords/locations")  # route path assumed
async def get_location_keywords():
    """Get available locations for keyword search"""
    locations = get_available_locations()
    return {
        "locations": locations,
        "count": len(locations)
    }

# Get statistics
@app.get("/stats")  # route path assumed
async def get_stats():
    """Get analysis statistics"""
    try:
        from pg_db import get_conn
        with get_conn() as conn:
            with conn.cursor() as cur:
                # Total hazardous tweets
                cur.execute("SELECT COUNT(*) FROM hazardous_tweets")
                total_hazardous = cur.fetchone()[0]

                # By hazard type
                cur.execute("""
                    SELECT hazard_type, COUNT(*) as count
                    FROM hazardous_tweets
                    GROUP BY hazard_type
                    ORDER BY count DESC
                """)
                hazard_types = [{"type": row[0], "count": row[1]} for row in cur.fetchall()]

                # By location
                cur.execute("""
                    SELECT location, COUNT(*) as count
                    FROM hazardous_tweets
                    WHERE location != 'unknown'
                    GROUP BY location
                    ORDER BY count DESC
                    LIMIT 10
                """)
                locations = [{"location": row[0], "count": row[1]} for row in cur.fetchall()]

                # By sentiment
                cur.execute("""
                    SELECT sentiment_label, COUNT(*) as count
                    FROM hazardous_tweets
                    GROUP BY sentiment_label
                    ORDER BY count DESC
                """)
                sentiments = [{"sentiment": row[0], "count": row[1]} for row in cur.fetchall()]

        return {
            "total_hazardous_tweets": total_hazardous,
            "hazard_types": hazard_types,
            "top_locations": locations,
            "sentiment_distribution": sentiments
        }
    except Exception as e:
        logger.error(f"Failed to fetch statistics: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
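
# A minimal client sketch (a separate script, not part of this app) exercising the
# assumed routes end to end; requires the `requests` package:
#
#   import requests
#
#   BASE = "http://localhost:8000"
#   requests.post(f"{BASE}/warmup")                      # optional: pre-load models
#   resp = requests.post(f"{BASE}/analyze", json={"limit": 10})
#   print(resp.json()["hazardous_tweets"])
#   print(requests.get(f"{BASE}/stats").json())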