| """ | |
| FastAPI application with /api prefix for all endpoints. | |
| Endpoints: | |
| - GET /api/analysis: Current analysis report | |
| - GET /api/history: Historical price and sentiment data | |
| - GET /api/health: System health check | |
| """ | |
| import logging | |
| # Suppress httpx request logging to prevent API keys in URLs from appearing in logs | |
| logging.getLogger("httpx").setLevel(logging.WARNING) | |
| logging.getLogger("httpcore").setLevel(logging.WARNING) | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, timedelta, timezone | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect, Depends, Header, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from sqlalchemy import func | |
| from app.db import init_db, SessionLocal, get_db_type | |
| from app.models import NewsArticle, PriceBar, DailySentiment, DailySentimentV2, AnalysisSnapshot, NewsSentimentV2, NewsProcessed, NewsRaw | |
| from app.settings import get_settings | |
| from app.lock import is_pipeline_locked | |
| # NOTE: Faz 1 - API is snapshot-only, no report generation | |
| # generate_analysis_report and save_analysis_snapshot are now worker-only | |
| from app.schemas import ( | |
| AnalysisReport, | |
| HistoryResponse, | |
| HistoryDataPoint, | |
| HealthResponse, | |
| ErrorResponse, | |
| ConsensusSignal, | |
| TFTModelSummaryResponse, | |
| BacktestReportResponse, | |
| NewsItem, | |
| NewsListResponse, | |
| NewsStatsResponse, | |
| NewsFinbertProbs, | |
| NewsSentimentBlock, | |
| ) | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s" | |
| ) | |
| logger = logging.getLogger(__name__) | |
# =============================================================================
# Lifespan Management
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup and shutdown events."""
    # Startup
    logger.info("Starting CopperMind API...")
    init_db()
    logger.info("Database initialized")
    # NOTE: Scheduler is NO LONGER started here.
    # Pipeline scheduling is now external (GitHub Actions cron).
    # This API only reads data and enqueues jobs.
    yield
    # Shutdown
    logger.info("Shutting down CopperMind API...")
    # Close the Redis pool if it was initialized. Await it directly so the pool
    # is actually closed before the event loop shuts down.
    try:
        from adapters.queue.redis import close_redis_pool
        await close_redis_pool()
    except ImportError:
        pass
    except Exception as exc:
        logger.warning(f"Error closing Redis pool on shutdown: {exc}")
# =============================================================================
# FastAPI Application
# =============================================================================
app = FastAPI(
    title="CopperMind API",
    description="Copper market sentiment analysis and price prediction API",
    version="1.0.0",
    docs_url="/api/docs",
    redoc_url="/api/redoc",
    openapi_url="/api/openapi.json",
    lifespan=lifespan,
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
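# Hedged sketch (not active): one way to restrict origins in production is to read
# an allow-list from configuration instead of "*". The CORS_ALLOW_ORIGINS env var
# name below is an assumption for illustration, not an existing setting.
#
#   import os
#   allowed = [o.strip() for o in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",") if o.strip()]
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=allowed or ["*"],
#       allow_credentials=True,
#       allow_methods=["*"],
#       allow_headers=["*"],
#   )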
# =============================================================================
# API Endpoints
# =============================================================================
@app.get("/api/analysis")
async def get_analysis(
    symbol: str = Query(default="HG=F", description="Trading symbol")
):
    """
    Get current analysis report.

    SNAPSHOT-ONLY MODE (Phase 1):
    - Reads the latest snapshot from the database
    - NO yfinance calls
    - NO model loading
    - NO feature building
    - All heavy computation is done by the worker pipeline

    Response includes quality_state:
    - "ok": Fresh snapshot available
    - "stale": Snapshot older than 36 hours
    - "missing": No snapshot found
    """
| STALE_THRESHOLD_HOURS = 36 | |
| with SessionLocal() as session: | |
| # Get latest snapshot - any age | |
| snapshot = session.query(AnalysisSnapshot).filter( | |
| AnalysisSnapshot.symbol == symbol | |
| ).order_by(AnalysisSnapshot.generated_at.desc()).first() | |
| if snapshot is None: | |
| # No snapshot at all - return minimal response for UI compatibility | |
| logger.warning(f"No snapshot found for {symbol}") | |
| return { | |
| "symbol": symbol, | |
| "quality_state": "missing", | |
| "model_state": "offline", | |
| "current_price": 0.0, | |
| "predicted_return": 0.0, | |
| "predicted_price": 0.0, | |
| "confidence_lower": 0.0, | |
| "confidence_upper": 0.0, | |
| "sentiment_index": 0.0, | |
| "sentiment_label": "Neutral", | |
| "top_influencers": [], | |
| "data_quality": { | |
| "news_count_7d": 0, | |
| "missing_days": 0, | |
| "coverage_pct": 0, | |
| }, | |
| "generated_at": None, | |
| "message": "No analysis available. Pipeline may not have run yet.", | |
| } | |
| # Calculate snapshot age | |
| now = datetime.now(timezone.utc) | |
| generated_at = snapshot.generated_at | |
| if generated_at.tzinfo is None: | |
| generated_at = generated_at.replace(tzinfo=timezone.utc) | |
| age_hours = (now - generated_at).total_seconds() / 3600 | |
| # Determine quality state | |
| if age_hours > STALE_THRESHOLD_HOURS: | |
| quality_state = "stale" | |
| else: | |
| quality_state = "ok" | |
| # Build response from snapshot | |
| report = snapshot.report_json.copy() if snapshot.report_json else {} | |
| # Add/override metadata | |
| report["quality_state"] = quality_state | |
| report["model_state"] = "ok" if quality_state == "ok" else "degraded" | |
| report["snapshot_age_hours"] = round(age_hours, 1) | |
| report["generated_at"] = generated_at.isoformat() | |
| # Ensure required fields exist (backward compatibility) | |
| if "symbol" not in report: | |
| report["symbol"] = symbol | |
| if "data_quality" not in report: | |
| report["data_quality"] = { | |
| "news_count_7d": 0, | |
| "missing_days": 0, | |
| "coverage_pct": 0, | |
| } | |
| if "top_influencers" not in report: | |
| report["top_influencers"] = [] | |
| # Re-label cached influencers so snapshots written before the | |
| # describe_feature() rollout also render human-readable names in the | |
| # UI. Non-destructive: pre-existing rich fields (label/description/ | |
| # category/time_horizon) are preserved; missing ones are back-filled. | |
| try: | |
| from app.features import describe_feature | |
| rebuilt: list[dict] = [] | |
| for infl in report.get("top_influencers", []) or []: | |
| if not isinstance(infl, dict): | |
| continue | |
| feature_key = infl.get("feature") or infl.get("name") or "" | |
| if not feature_key: | |
| rebuilt.append(infl) | |
| continue | |
| meta = describe_feature(str(feature_key)) | |
| enriched = { | |
| **infl, | |
| "feature": feature_key, | |
| "label": infl.get("label") or meta.get("label") or feature_key, | |
| "description": infl.get("description") or meta.get("description") or "", | |
| "category": infl.get("category") or meta.get("category") or "technical", | |
| "time_horizon": ( | |
| infl.get("time_horizon") | |
| or meta.get("time_horizon") | |
| or "intraday" | |
| ), | |
| } | |
| rebuilt.append(enriched) | |
| report["top_influencers"] = rebuilt | |
| except Exception as label_err: | |
| logger.warning(f"Influencer re-label skipped: {label_err}") | |
| logger.info(f"Returning snapshot for {symbol}: age={age_hours:.1f}h, state={quality_state}") | |
| return report | |
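# Example (abbreviated) response from GET /api/analysis when a fresh snapshot
# exists. Values are illustrative; the shape mirrors what is built above:
#
#   {
#     "symbol": "HG=F",
#     "quality_state": "ok",
#     "model_state": "ok",
#     "snapshot_age_hours": 3.2,
#     "generated_at": "2024-01-01T06:00:00+00:00",
#     "predicted_return": ...,
#     "top_influencers": [{"feature": "...", "label": "...", "category": "...", ...}],
#     "data_quality": {"news_count_7d": ..., "missing_days": ..., "coverage_pct": ...}
#   }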
@app.get("/api/history")
async def get_history(
    symbol: str = Query(default="HG=F", description="Trading symbol"),
    days: int = Query(default=180, ge=7, le=730, description="Number of days of history")
):
    """
    Get historical price and sentiment data.

    IMPORTANT: a sentiment_index of 0.0 is a valid value (neutral sentiment),
    not the same as missing data. We return explicit 0.0 values.
    """
| settings = get_settings() | |
| source = str(getattr(settings, "scoring_source", "news_articles")).strip().lower() | |
| with SessionLocal() as session: | |
| # Calculate date range | |
| end_date = datetime.now(timezone.utc) | |
| start_date = end_date - timedelta(days=days) | |
| # Query prices | |
| prices = session.query( | |
| PriceBar.date, | |
| PriceBar.close | |
| ).filter( | |
| PriceBar.symbol == symbol, | |
| PriceBar.date >= start_date | |
| ).order_by(PriceBar.date.asc()).all() | |
| if not prices: | |
| raise HTTPException( | |
| status_code=404, | |
| detail=f"No price data found for {symbol}" | |
| ) | |
| # Query sentiment (prefer V2 when scoring source is news_processed) | |
| sentiments = [] | |
| if source == "news_processed": | |
| sentiments = session.query( | |
| DailySentimentV2.date, | |
| DailySentimentV2.sentiment_index, | |
| DailySentimentV2.news_count | |
| ).filter( | |
| DailySentimentV2.date >= start_date | |
| ).order_by(DailySentimentV2.date.asc()).all() | |
| if not sentiments: | |
| logger.warning("No rows in daily_sentiments_v2 for history; falling back to daily_sentiments") | |
| if not sentiments: | |
| sentiments = session.query( | |
| DailySentiment.date, | |
| DailySentiment.sentiment_index, | |
| DailySentiment.news_count | |
| ).filter( | |
| DailySentiment.date >= start_date | |
| ).order_by(DailySentiment.date.asc()).all() | |
| # Create sentiment lookup (by date string for easy matching) | |
| sentiment_lookup = {} | |
| for s in sentiments: | |
| date_str = s.date.strftime("%Y-%m-%d") if hasattr(s.date, 'strftime') else str(s.date)[:10] | |
| sentiment_lookup[date_str] = { | |
| "sentiment_index": s.sentiment_index, | |
| "news_count": s.news_count | |
| } | |
| # Build response data | |
| data_points = [] | |
| for price in prices: | |
| date_str = price.date.strftime("%Y-%m-%d") if hasattr(price.date, 'strftime') else str(price.date)[:10] | |
| sent = sentiment_lookup.get(date_str) | |
| # IMPORTANT: Use explicit values, don't convert 0.0 to None | |
| sentiment_idx = sent["sentiment_index"] if sent is not None else None | |
| news_count = sent["news_count"] if sent is not None else None | |
| data_points.append(HistoryDataPoint( | |
| date=date_str, | |
| price=round(price.close, 4), | |
| sentiment_index=sentiment_idx, | |
| sentiment_news_count=news_count | |
| )) | |
| return HistoryResponse( | |
| symbol=symbol, | |
| data=data_points | |
| ) | |
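# Example of the distinction documented above: a day with a sentiment row returns
# explicit values, a day without one returns nulls (values illustrative):
#
#   {"date": "2024-01-02", "price": 3.85, "sentiment_index": 0.0, "sentiment_news_count": 12}
#   {"date": "2024-01-03", "price": 3.87, "sentiment_index": null, "sentiment_news_count": null}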
@app.get("/api/health")
async def health_check():
    """
    Perform system health check.

    Returns status information useful for monitoring and debugging.
    Includes Redis queue status and snapshot age for Phase 1 observability.
    """
| settings = get_settings() | |
| model_dir = Path(settings.model_dir) | |
| # Count models | |
| models_found = 0 | |
| if model_dir.exists(): | |
| models_found = len(list(model_dir.glob("xgb_*_latest.json"))) | |
| # Get counts and snapshot age | |
| news_count = None | |
| price_count = None | |
| last_snapshot_age = None | |
| # Freshness metadata for the System page. Each field answers a distinct | |
| # question — see HealthResponse for the exact definitions. | |
| last_pipeline_run_at: Optional[str] = None | |
| last_pipeline_status: Optional[str] = None | |
| last_snapshot_generated_at: Optional[str] = None | |
| last_tft_prediction_at: Optional[str] = None | |
| tft_model_trained_at: Optional[str] = None | |
| tft_reference_price_date: Optional[str] = None | |
| price_bar_latest_date: Optional[str] = None | |
| price_bar_staleness_days: Optional[int] = None | |
| def _iso(dt): | |
| if dt is None: | |
| return None | |
| if dt.tzinfo is None: | |
| dt = dt.replace(tzinfo=timezone.utc) | |
| return dt.isoformat() | |
| try: | |
| with SessionLocal() as session: | |
| news_count = session.query(func.count(NewsArticle.id)).scalar() | |
| price_count = session.query(func.count(PriceBar.id)).scalar() | |
| from app.models import ( | |
| AnalysisSnapshot, | |
| PipelineRunMetrics, | |
| TFTModelMetadata, | |
| TFTPredictionSnapshot, | |
| ) | |
| # --- Authoritative pipeline run timestamp ------------------------ | |
| # Read from the actual worker metrics table, not the snapshot | |
| # table. Snapshots are only ONE artifact of a pipeline run; a | |
| # failed run still records a row here. | |
| latest_run = ( | |
| session.query(PipelineRunMetrics) | |
| .order_by(PipelineRunMetrics.run_started_at.desc()) | |
| .first() | |
| ) | |
| if latest_run is not None: | |
| ended = latest_run.run_completed_at or latest_run.run_started_at | |
| last_pipeline_run_at = _iso(ended) | |
| # Map internal run.status → external pipeline_status. | |
| # running → running | |
| # success → ok | |
| # failed → failed | |
| raw_status = (latest_run.status or "").lower() | |
| if raw_status == "success": | |
| last_pipeline_status = "ok" | |
| elif raw_status in {"running", "failed"}: | |
| last_pipeline_status = raw_status | |
| else: | |
| last_pipeline_status = raw_status or None | |
| # --- XGBoost snapshot age --------------------------------------- | |
| latest_snapshot = ( | |
| session.query(AnalysisSnapshot) | |
| .order_by(AnalysisSnapshot.generated_at.desc()) | |
| .first() | |
| ) | |
| if latest_snapshot and latest_snapshot.generated_at: | |
| snap_at = latest_snapshot.generated_at | |
| if snap_at.tzinfo is None: | |
| snap_at = snap_at.replace(tzinfo=timezone.utc) | |
| age = datetime.now(timezone.utc) - snap_at | |
| last_snapshot_age = int(age.total_seconds()) | |
| last_snapshot_generated_at = snap_at.isoformat() | |
| # If PipelineRunMetrics has no rows yet (fresh DB) fall back | |
| # to snapshot-derived status so older deployments don't go | |
| # blank. | |
| if last_pipeline_run_at is None: | |
| last_pipeline_run_at = last_snapshot_generated_at | |
| if last_pipeline_status is None: | |
| last_pipeline_status = ( | |
| "ok" if last_snapshot_age < 36 * 3600 else "stale" | |
| ) | |
| # --- Latest persisted TFT snapshot ------------------------------ | |
| latest_tft = ( | |
| session.query(TFTPredictionSnapshot) | |
| .filter(TFTPredictionSnapshot.symbol == "HG=F") | |
| .order_by(TFTPredictionSnapshot.generated_at.desc()) | |
| .first() | |
| ) | |
| if latest_tft is not None: | |
| last_tft_prediction_at = _iso(latest_tft.generated_at) | |
| tft_reference_price_date = latest_tft.reference_price_date | |
| # --- Latest TFT training timestamp ------------------------------ | |
| latest_tft_model = ( | |
| session.query(TFTModelMetadata) | |
| .filter(TFTModelMetadata.symbol == "HG=F") | |
| .order_by(TFTModelMetadata.trained_at.desc()) | |
| .first() | |
| ) | |
| if latest_tft_model is not None: | |
| tft_model_trained_at = _iso(latest_tft_model.trained_at) | |
| # --- PriceBar freshness ----------------------------------------- | |
| target = "HG=F" | |
| latest_bar = ( | |
| session.query(PriceBar.date) | |
| .filter(PriceBar.symbol == target) | |
| .order_by(PriceBar.date.desc()) | |
| .first() | |
| ) | |
| if latest_bar and latest_bar[0]: | |
| bar_date = latest_bar[0] | |
| if bar_date.tzinfo is None: | |
| bar_date = bar_date.replace(tzinfo=timezone.utc) | |
| price_bar_latest_date = bar_date.strftime("%Y-%m-%d") | |
| price_bar_staleness_days = max( | |
| int((datetime.now(timezone.utc) - bar_date).days), 0 | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error getting counts: {e}") | |
| # Check Redis connectivity | |
| redis_ok = None | |
| try: | |
| from adapters.queue.redis import redis_healthcheck | |
| redis_result = await redis_healthcheck() | |
| redis_ok = redis_result.get("ok", False) | |
| except ImportError: | |
| # Redis adapter not available yet | |
| redis_ok = None | |
| except Exception as e: | |
| logger.warning(f"Redis healthcheck failed: {e}") | |
| redis_ok = False | |
| # Determine status | |
| pipeline_locked = is_pipeline_locked() | |
| if models_found == 0: | |
| status = "degraded" | |
| elif pipeline_locked: | |
| status = "degraded" | |
| elif redis_ok is False: | |
| status = "degraded" | |
| else: | |
| status = "healthy" | |
| return HealthResponse( | |
| status=status, | |
| db_type=get_db_type(), | |
| models_found=models_found, | |
| pipeline_locked=pipeline_locked, | |
| timestamp=datetime.now(timezone.utc).isoformat(), | |
| news_count=news_count, | |
| price_bars_count=price_count, | |
| redis_ok=redis_ok, | |
| last_snapshot_age_seconds=last_snapshot_age, | |
| last_pipeline_run_at=last_pipeline_run_at, | |
| last_pipeline_status=last_pipeline_status, | |
| last_snapshot_generated_at=last_snapshot_generated_at, | |
| last_tft_prediction_at=last_tft_prediction_at, | |
| tft_model_trained_at=tft_model_trained_at, | |
| tft_reference_price_date=tft_reference_price_date, | |
| price_bar_latest_date=price_bar_latest_date, | |
| price_bar_staleness_days=price_bar_staleness_days, | |
| ) | |
| async def get_market_prices(): | |
| """ | |
| Get live prices and daily changes for all tracked symbols. | |
| Uses yfinance for real-time data (15-minute delayed). | |
| Used by the Market Intelligence Map component. | |
| """ | |
| import yfinance as yf | |
| settings = get_settings() | |
| symbols = settings.symbols_list | |
| result = {} | |
| try: | |
| # Fetch all tickers at once for efficiency | |
| tickers = yf.Tickers(' '.join(symbols)) | |
| for symbol in symbols: | |
| try: | |
| ticker = tickers.tickers.get(symbol) | |
| if not ticker: | |
| result[symbol] = {"price": None, "change": None} | |
| continue | |
| info = ticker.info | |
| # Get current price and change | |
| current_price = info.get('regularMarketPrice') or info.get('currentPrice') | |
| change_pct = info.get('regularMarketChangePercent') | |
| if current_price is not None: | |
| result[symbol] = { | |
| "price": round(current_price, 4), | |
| "change": round(change_pct, 2) if change_pct else 0, | |
| } | |
| else: | |
| result[symbol] = {"price": None, "change": None} | |
| except Exception as e: | |
| logger.debug(f"Error fetching {symbol}: {e}") | |
| result[symbol] = {"price": None, "change": None} | |
| except Exception as e: | |
| logger.error(f"Error fetching market prices: {e}") | |
| return {"error": str(e), "symbols": {}} | |
| return {"symbols": result} | |
| # ============================================================================= | |
| # Market Heatmap Endpoint | |
| # ============================================================================= | |
| async def get_market_heatmap(background_tasks: BackgroundTasks): | |
| from app.models import HeatmapCache | |
| from app.heatmap import refresh_market_heatmap | |
| # Stuck refresh safety: if a refresh has been "in progress" for longer than | |
| # this, assume the worker crashed and allow a fresh background refresh to | |
| # be kicked off. yfinance batch fetch for the full universe finishes in | |
| # under ~2 minutes under normal conditions. | |
| STUCK_REFRESH_SECONDS = 180 | |
| with SessionLocal() as session: | |
| cache = session.query(HeatmapCache).first() | |
| now = datetime.now(timezone.utc) | |
| def _payload_count(payload) -> int: | |
| if not isinstance(payload, dict): | |
| return 0 | |
| total = 0 | |
| for grp in payload.get("children", []) or []: | |
| for sub in grp.get("children", []) or []: | |
| total += len(sub.get("children", []) or []) | |
| return total | |
| # If no cache or completely empty payload — trigger background refresh | |
| if not cache or not cache.payload_json or _payload_count(cache.payload_json) == 0: | |
| # Clear any stale "in progress" flag so we don't deadlock. | |
| if cache and cache.refresh_started_at is not None: | |
| started = cache.refresh_started_at | |
| if started.tzinfo is None: | |
| started = started.replace(tzinfo=timezone.utc) | |
| age = (now - started).total_seconds() | |
| if age > STUCK_REFRESH_SECONDS: | |
| cache.refresh_started_at = None | |
| session.commit() | |
| background_tasks.add_task(refresh_market_heatmap) | |
| return { | |
| "name": "CopperMind Universe", | |
| "children": [], | |
| "_meta": { | |
| "is_stale": True, | |
| "refresh_in_progress": True, | |
| "last_updated_at": None, | |
| "next_refresh_at": None, | |
| "source_delay_minutes": 15, | |
| "payload_count": 0, | |
| "refresh_error": cache.refresh_error if cache else None, | |
| "cache_state": "empty", | |
| }, | |
| } | |
| # Check if stale | |
| expires_at = cache.expires_at | |
| if expires_at and expires_at.tzinfo is None: | |
| expires_at = expires_at.replace(tzinfo=timezone.utc) | |
| is_stale = now > expires_at if expires_at else True | |
| refresh_in_progress = cache.refresh_started_at is not None | |
| # Recover from stuck "in progress" flags | |
| if refresh_in_progress: | |
| started = cache.refresh_started_at | |
| if started.tzinfo is None: | |
| started = started.replace(tzinfo=timezone.utc) | |
| if (now - started).total_seconds() > STUCK_REFRESH_SECONDS: | |
| logger.warning( | |
| "Heatmap refresh appears stuck (started %.0fs ago) — clearing flag", | |
| (now - started).total_seconds(), | |
| ) | |
| cache.refresh_started_at = None | |
| session.commit() | |
| refresh_in_progress = False | |
| if is_stale and not refresh_in_progress: | |
| background_tasks.add_task(refresh_market_heatmap) | |
| cache.refresh_started_at = now | |
| session.commit() | |
| refresh_in_progress = True | |
| payload = cache.payload_json | |
| payload_count = _payload_count(payload) | |
| cache_state = "fresh" | |
| if is_stale: | |
| cache_state = "stale" | |
| if refresh_in_progress: | |
| cache_state = "refreshing" | |
| if isinstance(payload, dict): | |
| payload["_meta"] = { | |
| "is_stale": is_stale, | |
| "refresh_in_progress": refresh_in_progress, | |
| "last_updated_at": cache.cached_at.isoformat() if cache.cached_at else None, | |
| "next_refresh_at": cache.expires_at.isoformat() if cache.expires_at else None, | |
| "source_delay_minutes": 15, | |
| "payload_count": payload_count, | |
| "refresh_error": cache.refresh_error, | |
| "cache_state": cache_state, | |
| } | |
| return payload | |
| # ============================================================================= | |
| # Live Price Endpoint (Twelve Data - Real-time) | |
| # ============================================================================= | |
| async def get_live_price(): | |
| """ | |
| Get real-time copper price from Twelve Data. | |
| Used for the header price display. Separate from yfinance to avoid rate limits. | |
| """ | |
| import httpx | |
| settings = get_settings() | |
| if not settings.twelvedata_api_key: | |
| logger.warning("Twelve Data API key not configured") | |
| return {"price": None, "error": "API key not configured"} | |
| try: | |
| async with httpx.AsyncClient(timeout=10.0) as client: | |
| response = await client.get( | |
| "https://api.twelvedata.com/price", | |
| params={ | |
| "symbol": "XCU/USD", | |
| "apikey": settings.twelvedata_api_key, | |
| } | |
| ) | |
| if response.status_code == 200: | |
| data = response.json() | |
| price = data.get("price") | |
| if price: | |
| return { | |
| "symbol": "XCU/USD", | |
| "price": round(float(price), 4), | |
| "error": None, | |
| } | |
| else: | |
| return {"price": None, "error": data.get("message", "No price data")} | |
| else: | |
| return {"price": None, "error": f"API error: {response.status_code}"} | |
| except Exception as e: | |
| from app.settings import mask_api_key | |
| logger.error(f"Twelve Data API error: {mask_api_key(str(e))}") | |
| return {"price": None, "error": "API error"} | |
| # ============================================================================= | |
| # WebSocket Live Price Streaming (Twelve Data) | |
| # ============================================================================= | |
| async def websocket_live_price(websocket: WebSocket): | |
| """ | |
| WebSocket endpoint for real-time copper price streaming. | |
| Connects to Twelve Data WebSocket and relays price events to the client. | |
| """ | |
| import websockets | |
| import asyncio | |
| import json | |
| await websocket.accept() | |
| settings = get_settings() | |
| if not settings.twelvedata_api_key: | |
| await websocket.send_json({"error": "API key not configured"}) | |
| await websocket.close() | |
| return | |
| td_ws_url = f"wss://ws.twelvedata.com/v1/quotes?apikey={settings.twelvedata_api_key}" | |
| try: | |
| async with websockets.connect(td_ws_url) as td_ws: | |
| # Subscribe to BTC/USD first (for testing Basic plan support) | |
| # If BTC works but XCU doesn't, it means commodities need Pro plan | |
| subscribe_msg = json.dumps({ | |
| "action": "subscribe", | |
| "params": {"symbols": "BTC/USD"} | |
| }) | |
| await td_ws.send(subscribe_msg) | |
| logger.info("Subscribed to BTC/USD via Twelve Data WebSocket (testing)") | |
| # Heartbeat task to keep connection alive | |
| async def send_heartbeat(): | |
| while True: | |
| await asyncio.sleep(10) | |
| try: | |
| await td_ws.send(json.dumps({"action": "heartbeat"})) | |
| except Exception: | |
| break | |
| heartbeat_task = asyncio.create_task(send_heartbeat()) | |
| try: | |
| # Relay messages from Twelve Data to client | |
| async for message in td_ws: | |
| data = json.loads(message) | |
| if data.get("event") == "price": | |
| await websocket.send_json({ | |
| "symbol": data.get("symbol"), | |
| "price": data.get("price"), | |
| "timestamp": data.get("timestamp"), | |
| }) | |
| elif data.get("event") == "subscribe-status": | |
| logger.info(f"Subscription status: {data.get('status')}") | |
| if data.get("fails"): | |
| logger.warning(f"Subscription failures: {data.get('fails')}") | |
| except WebSocketDisconnect: | |
| logger.info("Client disconnected from live-price WebSocket") | |
| finally: | |
| heartbeat_task.cancel() | |
| except Exception as e: | |
| # Mask potential API keys in error messages | |
| from app.settings import mask_api_key | |
| safe_error = mask_api_key(str(e)) | |
| logger.error(f"WebSocket error: {safe_error}") | |
| try: | |
| await websocket.send_json({"error": "Connection error"}) # Don't expose details | |
| except Exception: | |
| pass | |
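# Illustrative client sketch (not wired into the app): connect to this relay and
# print a few relayed price events. The localhost URL and the "/api/ws/live-price"
# path are assumptions for the example; use whatever path this handler is mounted on.
async def _example_live_price_client(url: str = "ws://localhost:8000/api/ws/live-price") -> None:
    import json
    import websockets

    async with websockets.connect(url) as ws:
        for _ in range(3):
            event = json.loads(await ws.recv())
            print(event.get("symbol"), event.get("price"), event.get("timestamp"))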
| # ============================================================================= | |
| # AI Commentary Endpoint | |
| # ============================================================================= | |
| async def get_commentary( | |
| symbol: str = Query(default="HG=F", description="Symbol to get commentary for") | |
| ): | |
| """ | |
| Get AI commentary for the specified symbol. | |
| Commentary is generated once after each pipeline run and stored in the database. | |
| This endpoint simply returns the stored commentary without making new API calls. | |
| """ | |
| from app.commentary import get_commentary_from_db | |
| with SessionLocal() as session: | |
| result = get_commentary_from_db(session, symbol) | |
| if result: | |
| return { | |
| "symbol": symbol, | |
| "commentary": result["commentary"], | |
| "error": None, | |
| "generated_at": result["generated_at"], | |
| "ai_stance": result.get("ai_stance", "NEUTRAL"), | |
| } | |
| else: | |
| return { | |
| "symbol": symbol, | |
| "commentary": None, | |
| "error": "No commentary available. Commentary is generated after pipeline runs.", | |
| "generated_at": None, | |
| "ai_stance": "NEUTRAL", | |
| } | |
| # ============================================================================= | |
| # Root redirect (optional convenience) | |
| # ============================================================================= | |
| _tft_cache: dict = {} | |
| _TFT_CACHE_TTL_S = 300 # 5 minutes | |
| async def get_tft_analysis( | |
| symbol: str = "HG=F", | |
| source: str = "snapshot", | |
| ): | |
| """ | |
| Get TFT-ASRO analysis for the given symbol. | |
| `source` semantics: | |
| * `snapshot` (default) — serve the latest persisted TFTPredictionSnapshot | |
| written by the worker. If none exists, transparently fall back to live. | |
| * `live` — always run a fresh inference. In-memory cached | |
| for 5 minutes to protect the worker against the 60s polling loop. | |
| """ | |
| source = (source or "snapshot").strip().lower() | |
| if source not in {"snapshot", "live"}: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="source must be one of: snapshot, live", | |
| ) | |
| # --- 1. Try persisted snapshot ------------------------------------------ | |
| if source == "snapshot": | |
| try: | |
| from app.models import TFTPredictionSnapshot | |
| with SessionLocal() as session: | |
| latest = ( | |
| session.query(TFTPredictionSnapshot) | |
| .filter(TFTPredictionSnapshot.symbol == symbol) | |
| .order_by(TFTPredictionSnapshot.generated_at.desc()) | |
| .first() | |
| ) | |
| if latest is not None and isinstance(latest.payload_json, dict): | |
| payload = dict(latest.payload_json) | |
| gen_at = latest.generated_at | |
| if gen_at and gen_at.tzinfo is None: | |
| gen_at = gen_at.replace(tzinfo=timezone.utc) | |
| payload["source"] = "snapshot" | |
| payload["snapshot_generated_at"] = ( | |
| gen_at.isoformat() if gen_at else None | |
| ) | |
| return payload | |
| except Exception as exc: | |
| logger.warning( | |
| "TFT snapshot read failed, falling back to live inference: %s", | |
| exc, | |
| ) | |
| # No snapshot yet — silently fall through to live inference so the | |
| # UI can still show something on first deployment. | |
| # --- 2. Live inference (explicit request or snapshot miss) -------------- | |
| now = datetime.now(timezone.utc) | |
| cache_key = f"{symbol}:live" | |
| cached = _tft_cache.get(cache_key) | |
| if cached: | |
| age = (now - cached["ts"]).total_seconds() | |
| if age < _TFT_CACHE_TTL_S: | |
| return cached["data"] | |
| try: | |
| from deep_learning.inference.predictor import generate_tft_analysis | |
| with SessionLocal() as session: | |
| result = generate_tft_analysis(session, symbol) | |
| if "error" in result: | |
| raise HTTPException(status_code=500, detail=result["error"]) | |
| result = dict(result) | |
| result["source"] = "live" | |
| _tft_cache[cache_key] = {"data": result, "ts": now} | |
| return result | |
| except FileNotFoundError: | |
| raise HTTPException( | |
| status_code=404, | |
| detail="TFT-ASRO model not trained yet. Run training pipeline first.", | |
| ) | |
| except ImportError as exc: | |
| raise HTTPException( | |
| status_code=404, | |
| detail=f"TFT-ASRO module not available: {exc}", | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as exc: | |
| logger.error("TFT analysis failed: %s", exc, exc_info=True) | |
| raise HTTPException(status_code=500, detail=str(exc)) | |
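# Usage sketch for the two `source` modes of the handler above. The mount path is
# not shown in this file, so "/api/tft" is only a placeholder:
#
#   GET /api/tft?symbol=HG=F&source=snapshot
#     -> payload["source"] == "snapshot", plus "snapshot_generated_at"
#   GET /api/tft?symbol=HG=F&source=live
#     -> payload["source"] == "live"; repeated calls within 5 minutes are served
#        from the in-memory cache keyed by "<symbol>:live"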
| async def root_redirect(): | |
| """Redirect root to API docs.""" | |
| from fastapi.responses import RedirectResponse | |
| return RedirectResponse(url="/api/docs") | |
| async def api_root(): | |
| """API root information.""" | |
| return { | |
| "name": "CopperMind API", | |
| "version": "1.0.0", | |
| "docs": "/api/docs", | |
| "health": "/api/health" | |
| } | |
| # ============================================================================= | |
| # Pipeline Management Endpoints | |
| # ============================================================================= | |
| def verify_pipeline_secret(authorization: Optional[str] = Header(None)) -> None: | |
| """ | |
| Verify the pipeline trigger secret from Authorization header. | |
| Expected format: Authorization: Bearer <PIPELINE_TRIGGER_SECRET> | |
| """ | |
| settings = get_settings() | |
| # If no secret is configured, reject all requests (fail secure) | |
| if not settings.pipeline_trigger_secret: | |
| logger.warning("Pipeline trigger attempted but PIPELINE_TRIGGER_SECRET not configured") | |
| raise HTTPException( | |
| status_code=401, | |
| detail="Pipeline trigger authentication not configured. Set PIPELINE_TRIGGER_SECRET." | |
| ) | |
| # Check Authorization header | |
| if not authorization: | |
| raise HTTPException( | |
| status_code=401, | |
| detail="Missing Authorization header. Expected: Bearer <token>" | |
| ) | |
| # Parse Bearer token | |
| parts = authorization.split(" ", 1) | |
| if len(parts) != 2 or parts[0].lower() != "bearer": | |
| raise HTTPException( | |
| status_code=401, | |
| detail="Invalid Authorization format. Expected: Bearer <token>" | |
| ) | |
| token = parts[1] | |
| # Constant-time comparison to prevent timing attacks | |
| import secrets | |
| if not secrets.compare_digest(token, settings.pipeline_trigger_secret): | |
| logger.warning("Pipeline trigger attempted with invalid token") | |
| raise HTTPException( | |
| status_code=401, | |
| detail="Invalid pipeline trigger token" | |
| ) | |
| logger.info("Pipeline trigger authorized successfully") | |
| async def trigger_pipeline( | |
| train_model: bool = Query(default=False, description="Train/retrain XGBoost model"), | |
| trigger_source: str = Query(default="api", description="Source of trigger (api, cron, manual)"), | |
| _auth: None = Depends(verify_pipeline_secret), | |
| ): | |
| """ | |
| Enqueue a pipeline job to Redis queue. | |
| This endpoint does NOT run the pipeline - it only enqueues a job. | |
| The worker service consumes and executes the job. | |
| Returns: | |
| run_id: UUID for tracking this pipeline run | |
| enqueued: True if job was enqueued successfully | |
| """ | |
| # Check if pipeline is already running (advisory lock check) | |
| # Note: This is a weak check - the worker will do the authoritative lock check | |
| if is_pipeline_locked(): | |
| raise HTTPException( | |
| status_code=409, | |
| detail="Pipeline is already running. Please wait for it to complete." | |
| ) | |
| try: | |
| from adapters.queue.jobs import enqueue_pipeline_job | |
| result = await enqueue_pipeline_job( | |
| train_model=train_model, | |
| trigger_source=trigger_source, | |
| ) | |
| logger.info(f"Pipeline job enqueued: run_id={result['run_id']}, trigger={trigger_source}") | |
| return { | |
| "status": "enqueued", | |
| "message": "Pipeline job enqueued. Worker will execute. Check /api/health for status.", | |
| "run_id": result["run_id"], | |
| "job_id": result["job_id"], | |
| "train_model": train_model, | |
| "trigger_source": trigger_source, | |
| } | |
| except Exception as e: | |
| logger.error(f"Failed to enqueue pipeline job: {e}") | |
| raise HTTPException( | |
| status_code=503, | |
| detail=f"Failed to enqueue job. Redis may be unavailable: {str(e)}" | |
| ) | |
| # ============================================================================= | |
| # New User-Facing Endpoints | |
| # ============================================================================= | |
| async def get_consensus( | |
| symbol: str = Query(default="HG=F", description="Trading symbol") | |
| ): | |
| from deep_learning.inference.predictor import ensemble_directional_vote, generate_tft_analysis | |
| # 1. Get TFT analysis | |
| try: | |
| with SessionLocal() as session: | |
| tft_result = generate_tft_analysis(session, symbol) | |
| if "error" in tft_result: | |
| raise HTTPException(status_code=500, detail=tft_result["error"]) | |
| tft_return = tft_result.get("prediction", {}).get("predicted_return_median", 0.0) | |
| except Exception as e: | |
| logger.error(f"Failed to get TFT analysis for consensus: {e}") | |
| tft_return = 0.0 | |
| # 2. Get XGBoost analysis (latest snapshot) | |
| xgb_return = 0.0 | |
| try: | |
| with SessionLocal() as session: | |
| snapshot = session.query(AnalysisSnapshot).filter( | |
| AnalysisSnapshot.symbol == symbol | |
| ).order_by(AnalysisSnapshot.generated_at.desc()).first() | |
| if snapshot and snapshot.report_json: | |
| xgb_return = snapshot.report_json.get("predicted_return", 0.0) | |
| except Exception as e: | |
| logger.error(f"Failed to get XGBoost analysis for consensus: {e}") | |
| # 3. Calculate consensus | |
| xgb_bias_correction = 0.001 # Hardcoded small bias correction for now | |
| result = ensemble_directional_vote(xgb_return, tft_return, xgb_bias_correction) | |
| return result | |
| async def get_tft_summary( | |
| symbol: str = Query(default="HG=F", description="Target symbol") | |
| ): | |
| from app.models import TFTModelMetadata | |
| from app.quality_gate import evaluate_quality_gate | |
| import json | |
| with SessionLocal() as session: | |
| meta = session.query(TFTModelMetadata).filter( | |
| TFTModelMetadata.symbol == symbol | |
| ).order_by(TFTModelMetadata.trained_at.desc()).first() | |
| if not meta: | |
| raise HTTPException(status_code=404, detail=f"No TFT model metadata found for {symbol}") | |
| config = json.loads(meta.config_json) if meta.config_json else {} | |
| metrics = json.loads(meta.metrics_json) if meta.metrics_json else {} | |
        # Variable importance is not stored in TFTModelMetadata yet, so try to
        # load it from the latest feature-store artifacts if they are available.
| variable_importance = [] | |
| try: | |
| import pathlib | |
| from .features import describe_feature | |
| artifact_dir = pathlib.Path(config.get("feature_store", {}).get("artifact_dir", "artifacts/feature_store")) | |
| mrmr_path = artifact_dir / "latest" / "mrmr_results.json" | |
| if mrmr_path.exists(): | |
| mrmr_data = json.loads(mrmr_path.read_text(encoding="utf-8")) | |
| for feat, imp in mrmr_data.get("scores", {}).items(): | |
| meta_desc = describe_feature(feat) | |
| variable_importance.append({ | |
| "feature": feat, | |
| "importance": float(imp), | |
| "label": meta_desc["label"], | |
| "description": meta_desc["description"], | |
| "category": meta_desc["category"], | |
| "time_horizon": meta_desc.get("time_horizon", ""), | |
| }) | |
| variable_importance.sort(key=lambda x: x["importance"], reverse=True) | |
| variable_importance = variable_importance[:20] | |
| except Exception as e: | |
| logger.warning(f"Could not load variable importance: {e}") | |
| da = metrics.get("directional_accuracy", 0.5) | |
| sharpe = metrics.get("sharpe_ratio", 0.0) | |
| vr = metrics.get("variance_ratio", 1.0) | |
| passed, reasons = evaluate_quality_gate(da, sharpe, vr) | |
| return { | |
| "symbol": symbol, | |
| "trained_at": meta.trained_at.isoformat() if meta.trained_at else None, | |
| "checkpoint_path": meta.checkpoint_path, | |
| "config": config, | |
| "metrics": metrics, | |
| "variable_importance": variable_importance, | |
| "quality_gate": { | |
| "passed": passed, | |
| "reasons": reasons, | |
| "metrics": {"da": da, "sharpe": sharpe, "vr": vr} | |
| } | |
| } | |
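# For reference, evaluate_quality_gate is expected to return (passed, reasons) from
# the three metrics unpacked above. A purely illustrative stand-in (NOT the real
# gate; these thresholds are made-up placeholders) could look like:
#
#   def _example_quality_gate(da: float, sharpe: float, vr: float) -> tuple[bool, list[str]]:
#       reasons: list[str] = []
#       if da < 0.52:
#           reasons.append(f"directional_accuracy {da:.2f} below threshold")
#       if sharpe < 0.0:
#           reasons.append(f"sharpe_ratio {sharpe:.2f} below threshold")
#       if not (0.5 <= vr <= 1.5):
#           reasons.append(f"variance_ratio {vr:.2f} outside acceptable band")
#       return (len(reasons) == 0, reasons)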
| async def get_latest_backtest(symbol: str = Query(default="HG=F", description="Target symbol")): | |
| import pathlib | |
| import json as _json | |
| from app.models import BacktestReport | |
| empty_payload = { | |
| "available": False, | |
| "message": "No backtest runs yet. Run `python -m backend.backtest.runner` to generate one.", | |
| "report_date": None, | |
| "summary_metrics": {}, | |
| "window_metrics": [], | |
| "theta_comparison": {}, | |
| "verdict": None, | |
| } | |
| # 1. Prefer DB-persisted row (production-friendly across container restarts) | |
| try: | |
| with SessionLocal() as session: | |
| row = ( | |
| session.query(BacktestReport) | |
| .filter(BacktestReport.symbol == symbol) | |
| .order_by(BacktestReport.generated_at.desc()) | |
| .first() | |
| ) | |
| if row is not None: | |
| return { | |
| "available": True, | |
| "report_date": row.generated_at.isoformat() if row.generated_at else None, | |
| "summary_metrics": row.summary_json or {}, | |
| "window_metrics": row.windows_json or [], | |
| "theta_comparison": row.theta_comparison_json or {}, | |
| "verdict": row.verdict, | |
| } | |
| except Exception as e: | |
| logger.warning(f"BacktestReport table read failed, falling back to FS: {e}") | |
| # 2. Fallback: legacy filesystem artifact (local dev) | |
| backtest_dir = pathlib.Path("artifacts/backtest") | |
| if backtest_dir.exists(): | |
| reports = list(backtest_dir.glob("backtest_*.json")) | |
| if reports: | |
| latest_report_path = max(reports, key=lambda p: p.stat().st_mtime) | |
| try: | |
| data = _json.loads(latest_report_path.read_text(encoding="utf-8")) | |
| tft_bt = data.get("tft_backtest", {}) | |
| comp = data.get("baseline_comparison", {}) | |
| return { | |
| "available": True, | |
| "report_date": data.get("timestamp") or data.get("generated_at"), | |
| "summary_metrics": tft_bt.get("summary", {}), | |
| "window_metrics": tft_bt.get("windows", []), | |
| "theta_comparison": comp, | |
| "verdict": comp.get("verdict"), | |
| } | |
| except Exception as e: | |
| logger.error(f"Error reading backtest report: {e}") | |
| # 3. Empty state (no 404, no error) | |
| return empty_payload | |
| # ============================================================================= | |
| # Sentiment Summary — Stable, DB-backed, NO LLM on the hot path. | |
| # ============================================================================= | |
| # | |
| # Architecture contract (frontend should depend on this shape forever): | |
| # - `index`: blended daily sentiment in [-1, +1] | |
| # - `label`: Bullish / Neutral / Bearish (derived from `index`) | |
| # - `source`: which aggregate layer produced the value | |
| # ("daily_v2" | "rolling_v2" | "legacy_v1" | "none") | |
| # - `components`: breakdown of LLM vs FinBERT vs rule_sign contributions | |
# - `trend`: list of {date, index, news_count} for the sparkline (window set by `days`, default 7)
| # - `recent_articles`: a small sample of latest processed headlines | |
| # - `data_freshness`: {oldest, newest, age_hours, article_count_24h} | |
| # | |
| # This endpoint NEVER calls an LLM. Commentary generation (which does use | |
| # OpenRouter) is pipeline-driven and cached in `AICommentary`. | |
| # ============================================================================= | |
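# Example (abbreviated, values illustrative) of the contract described above:
#
#   {
#     "index": 0.18,
#     "label": "Bullish",
#     "source": "daily_v2",
#     "components": {"llm_impact_avg": ..., "finbert_pn_avg": ..., "rule_sign_avg": ...,
#                    "avg_confidence": ..., "avg_relevance": ..., "sample_size": ...},
#     "trend": [{"date": "...", "index": ..., "news_count": ...}],
#     "recent_articles": [{"title": "...", "source": "...", "sentiment": {...}}],
#     "data_freshness": {"newest": "...", "oldest": "...", "age_hours": ..., "article_count_24h": ...},
#     "generated_at": "..."
#   }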
| async def get_sentiment_summary( | |
| days: int = Query(default=7, ge=1, le=30, description="Trend window in days"), | |
| recent_limit: int = Query(default=6, ge=1, le=20, description="Recent headlines to include"), | |
| ): | |
| from sqlalchemy import func, desc | |
| def _label(idx: float) -> str: | |
| if idx > 0.10: | |
| return "Bullish" | |
| if idx < -0.10: | |
| return "Bearish" | |
| return "Neutral" | |
| with SessionLocal() as session: | |
| now = datetime.now(timezone.utc) | |
| window_start = now - timedelta(days=days) | |
| # ---- 1) Preferred source: DailySentimentV2 (commodity-aware) ---- | |
| v2_rows = ( | |
| session.query(DailySentimentV2) | |
| .filter(DailySentimentV2.date >= window_start) | |
| .order_by(DailySentimentV2.date.asc()) | |
| .all() | |
| ) | |
| # ---- 2) Component breakdown from NewsSentimentV2 (same window) ---- | |
| # Published date lives on NewsRaw, so we join processed → raw. | |
| component_rows = ( | |
| session.query( | |
| func.avg(NewsSentimentV2.impact_score_llm).label("avg_llm"), | |
| func.avg(NewsSentimentV2.finbert_pos - NewsSentimentV2.finbert_neg).label("avg_finbert"), | |
| func.avg(NewsSentimentV2.rule_sign).label("avg_rule"), | |
| func.avg(NewsSentimentV2.confidence_calibrated).label("avg_conf"), | |
| func.avg(NewsSentimentV2.relevance_score).label("avg_rel"), | |
| func.count(NewsSentimentV2.id).label("n"), | |
| ) | |
| .join(NewsProcessed, NewsProcessed.id == NewsSentimentV2.news_processed_id) | |
| .join(NewsRaw, NewsRaw.id == NewsProcessed.raw_id) | |
| .filter(NewsRaw.published_at >= window_start) | |
| .one() | |
| ) | |
| # ---- 3) Pick the freshest possible index ---- | |
| index_val: float = 0.0 | |
| source = "none" | |
| avg_confidence: Optional[float] = None | |
| if v2_rows: | |
| latest_v2 = v2_rows[-1] | |
| index_val = float(latest_v2.sentiment_index or 0.0) | |
| avg_confidence = float(latest_v2.avg_confidence or 0.0) if latest_v2.avg_confidence is not None else None | |
| source = "daily_v2" | |
| elif component_rows and component_rows.n and component_rows.n > 0: | |
| # No daily aggregate yet — fall back to rolling per-article avg | |
| llm = float(component_rows.avg_llm or 0.0) | |
| fb = float(component_rows.avg_finbert or 0.0) | |
| rule = float(component_rows.avg_rule or 0.0) | |
| index_val = 0.5 * llm + 0.3 * fb + 0.2 * rule | |
| avg_confidence = float(component_rows.avg_conf or 0.0) | |
| source = "rolling_v2" | |
| else: | |
| # Last-ditch fallback: legacy DailySentiment | |
| legacy = ( | |
| session.query(DailySentiment) | |
| .order_by(DailySentiment.date.desc()) | |
| .first() | |
| ) | |
| if legacy is not None: | |
| index_val = float(legacy.sentiment_index or 0.0) | |
| source = "legacy_v1" | |
| # ---- 4) Build trend series for sparkline ---- | |
| trend_7d = [ | |
| { | |
| "date": r.date.isoformat() if r.date else None, | |
| "index": float(r.sentiment_index or 0.0), | |
| "news_count": int(r.news_count or 0), | |
| } | |
| for r in v2_rows | |
| ] | |
| # ---- 5) Recent articles (hybrid: raw news + processed + V2 score) ---- | |
| recent_q = ( | |
| session.query(NewsRaw, NewsProcessed, NewsSentimentV2) | |
| .join(NewsProcessed, NewsProcessed.raw_id == NewsRaw.id) | |
| .outerjoin( | |
| NewsSentimentV2, | |
| NewsSentimentV2.news_processed_id == NewsProcessed.id, | |
| ) | |
| .order_by(desc(NewsRaw.published_at)) | |
| .limit(recent_limit) | |
| .all() | |
| ) | |
| recent_articles = [] | |
| for raw, proc, score in recent_q: | |
| recent_articles.append({ | |
| "title": getattr(raw, "title", None) or getattr(proc, "canonical_title", None) or "", | |
| "source": getattr(raw, "source", None), | |
| "url": getattr(raw, "url", None), | |
| "published_at": raw.published_at.isoformat() if getattr(raw, "published_at", None) else None, | |
| "sentiment": { | |
| "label": score.label if score else None, | |
| "final_score": float(score.final_score) if score else None, | |
| "relevance": float(score.relevance_score) if score else None, | |
| "confidence": float(score.confidence_calibrated) if score else None, | |
| "event_type": score.event_type if score else None, | |
| } if score else None, | |
| }) | |
| # ---- 6) Data freshness (lives on NewsRaw, not Processed) ---- | |
| freshness_q = session.query( | |
| func.min(NewsRaw.published_at).label("oldest"), | |
| func.max(NewsRaw.published_at).label("newest"), | |
| func.count(NewsRaw.id).label("n_total"), | |
| ).filter(NewsRaw.published_at >= (now - timedelta(hours=24))).one() | |
        newest = freshness_q.newest
        # Guard against naive datetimes, same as the other timestamp handling in this module.
        if newest is not None and newest.tzinfo is None:
            newest = newest.replace(tzinfo=timezone.utc)
        age_hours = ((now - newest).total_seconds() / 3600.0) if newest is not None else None
| return { | |
| "index": round(float(index_val), 4), | |
| "label": _label(index_val), | |
| "source": source, | |
| "components": { | |
| "llm_impact_avg": float(component_rows.avg_llm) if component_rows.avg_llm is not None else None, | |
| "finbert_pn_avg": float(component_rows.avg_finbert) if component_rows.avg_finbert is not None else None, | |
| "rule_sign_avg": float(component_rows.avg_rule) if component_rows.avg_rule is not None else None, | |
| "avg_confidence": avg_confidence, | |
| "avg_relevance": float(component_rows.avg_rel) if component_rows.avg_rel is not None else None, | |
| "sample_size": int(component_rows.n or 0), | |
| }, | |
| "trend": trend_7d, | |
| "recent_articles": recent_articles, | |
| "data_freshness": { | |
| "newest": newest.isoformat() if newest else None, | |
| "oldest": freshness_q.oldest.isoformat() if freshness_q.oldest else None, | |
| "age_hours": round(age_hours, 2) if age_hours is not None else None, | |
| "article_count_24h": int(freshness_q.n_total or 0), | |
| }, | |
| "generated_at": now.isoformat(), | |
| } | |
| # ============================================================================= | |
| # News intelligence endpoints | |
| # ============================================================================= | |
| # | |
| # Serves the Overview right-sidebar news feed. Reads from the news_raw/ | |
| # news_processed/news_sentiments_v2 pipeline the daily worker already fills — | |
| # no LLM is invoked on the hot path. | |
| # | |
| # Source taxonomy: | |
| # * channel = ingestion channel (NewsRaw.source): "google_news" | "newsapi" | |
| # * publisher = original publisher (raw_payload.source): Reuters, Mining.com… | |
| # ============================================================================= | |
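# Example of the taxonomy above for a single Google-News-ingested Mining.com story:
#   channel   = "google_news"   (NewsRaw.source, the ingestion pipe)
#   publisher = "Mining.com"    (extracted from raw_payload, the original outlet)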
| _news_list_cache: dict[tuple, tuple[float, dict]] = {} | |
| _news_stats_cache: dict[int, tuple[float, dict]] = {} | |
| _NEWS_LIST_TTL_S = 60.0 | |
| _NEWS_STATS_TTL_S = 120.0 | |
| _VALID_LABELS = {"BULLISH", "BEARISH", "NEUTRAL"} | |
| def _extract_publisher(raw_payload) -> Optional[str]: | |
| """Pull the original publisher name out of a NewsRaw.raw_payload blob.""" | |
| if not raw_payload: | |
| return None | |
| if isinstance(raw_payload, str): | |
| try: | |
| import json as _json | |
| raw_payload = _json.loads(raw_payload) | |
| except (ValueError, TypeError): | |
| return None | |
| if not isinstance(raw_payload, dict): | |
| return None | |
| src = raw_payload.get("source") | |
| if isinstance(src, dict): | |
| name = src.get("name") or src.get("title") | |
| return str(name) if name else None | |
| if isinstance(src, str) and src.strip(): | |
| return src.strip() | |
| name = raw_payload.get("publisher") or raw_payload.get("author") | |
| return str(name) if name else None | |
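# Examples of the payload shapes this helper handles (values illustrative):
#   {"source": {"name": "Reuters"}}          -> "Reuters"
#   {"source": "Mining.com"}                 -> "Mining.com"
#   {"publisher": "Fastmarkets"}             -> "Fastmarkets"
#   '{"source": {"name": "Reuters"}}' (str)  -> "Reuters"   (JSON-decoded first)
#   None / unparseable string                -> None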
| def _build_news_sentiment_block(sent: Optional[NewsSentimentV2]) -> Optional[NewsSentimentBlock]: | |
| if sent is None: | |
| return None | |
| return NewsSentimentBlock( | |
| label=sent.label, | |
| final_score=float(sent.final_score) if sent.final_score is not None else None, | |
| impact_score_llm=float(sent.impact_score_llm) if sent.impact_score_llm is not None else None, | |
| confidence=float(sent.confidence_calibrated) if sent.confidence_calibrated is not None else None, | |
| relevance=float(sent.relevance_score) if sent.relevance_score is not None else None, | |
| event_type=sent.event_type, | |
| finbert=NewsFinbertProbs( | |
| pos=float(sent.finbert_pos or 0.0), | |
| neu=float(sent.finbert_neu or 0.0), | |
| neg=float(sent.finbert_neg or 0.0), | |
| ), | |
| reasoning=_extract_reasoning_text(sent.reasoning_json), | |
| scored_at=sent.scored_at.isoformat() if sent.scored_at else None, | |
| ) | |
| def _extract_reasoning_text(reasoning_json: Optional[str]) -> Optional[str]: | |
| """Pull a short human-readable rationale out of the cached JSON blob.""" | |
| if not reasoning_json: | |
| return None | |
| try: | |
| import json as _json | |
| blob = _json.loads(reasoning_json) | |
| except (ValueError, TypeError): | |
| return str(reasoning_json)[:500] if reasoning_json else None | |
| if isinstance(blob, dict): | |
| for key in ("reasoning", "rationale", "summary", "explanation"): | |
| val = blob.get(key) | |
| if isinstance(val, str) and val.strip(): | |
| return val.strip()[:500] | |
| return None | |
| if isinstance(blob, str): | |
| return blob[:500] | |
| return None | |
| async def get_news_feed( | |
| limit: int = Query(default=20, ge=1, le=50), | |
| offset: int = Query(default=0, ge=0), | |
| since_hours: int = Query(default=48, ge=1, le=168), | |
| label: str = Query(default="all"), | |
| event_type: str = Query(default="all"), | |
| min_relevance: float = Query(default=0.0, ge=0.0, le=1.0), | |
| channel: str = Query(default="all"), | |
| publisher: Optional[str] = Query(default=None, max_length=200), | |
| search: Optional[str] = Query(default=None, max_length=200), | |
| ): | |
| from sqlalchemy import desc as _desc | |
| filters_echo = { | |
| "limit": limit, | |
| "offset": offset, | |
| "since_hours": since_hours, | |
| "label": label, | |
| "event_type": event_type, | |
| "min_relevance": min_relevance, | |
| "channel": channel, | |
| "publisher": publisher, | |
| "search": search, | |
| } | |
| cache_key = tuple(sorted(filters_echo.items())) | |
| now_ts = datetime.now(timezone.utc).timestamp() | |
| cached = _news_list_cache.get(cache_key) | |
| if cached and (now_ts - cached[0]) < _NEWS_LIST_TTL_S: | |
| return cached[1] | |
| label_upper = label.upper() | |
| if label_upper != "ALL" and label_upper not in _VALID_LABELS: | |
| raise HTTPException(status_code=400, detail=f"Invalid label '{label}'") | |
| with SessionLocal() as session: | |
| now = datetime.now(timezone.utc) | |
| cutoff = now - timedelta(hours=since_hours) | |
| q = ( | |
| session.query(NewsRaw, NewsProcessed, NewsSentimentV2) | |
| .join(NewsProcessed, NewsProcessed.raw_id == NewsRaw.id) | |
| .outerjoin( | |
| NewsSentimentV2, | |
| NewsSentimentV2.news_processed_id == NewsProcessed.id, | |
| ) | |
| .filter(NewsRaw.published_at >= cutoff) | |
| ) | |
| if channel.lower() != "all": | |
| q = q.filter(NewsRaw.source == channel) | |
| if event_type.lower() != "all": | |
| q = q.filter(NewsSentimentV2.event_type == event_type) | |
| if label_upper != "ALL": | |
| q = q.filter(NewsSentimentV2.label == label_upper) | |
| if min_relevance > 0: | |
| q = q.filter(NewsSentimentV2.relevance_score >= min_relevance) | |
| if search: | |
| q = q.filter(NewsRaw.title.ilike(f"%{search}%")) | |
| q = q.order_by(_desc(NewsRaw.published_at)) | |
| publisher_needle = publisher.strip().lower() if publisher and publisher.strip() else None | |
| if publisher_needle: | |
| # Publisher filter requires JSON extraction; do it in Python to | |
| # remain backend-agnostic (sqlite/postgres) and keep the endpoint | |
| # simple. Scope is bounded by the time window filter above. | |
| rows = q.limit(500).all() | |
| filtered = [ | |
| triple for triple in rows | |
| if ( | |
| _extract_publisher(triple[0].raw_payload) or "" | |
| ).lower().find(publisher_needle) >= 0 | |
| ] | |
| total = len(filtered) | |
| page_rows = filtered[offset: offset + limit] | |
| else: | |
| total = q.count() | |
| page_rows = q.offset(offset).limit(limit).all() | |
| items: list[NewsItem] = [] | |
| for raw, processed, sentiment in page_rows: | |
| items.append( | |
| NewsItem( | |
| id=int(processed.id), | |
| raw_id=int(raw.id), | |
| title=str(raw.title or ""), | |
| description=str(raw.description or "") or None, | |
| url=str(raw.url or "") or None, | |
| channel=str(raw.source or "unknown"), | |
| publisher=_extract_publisher(raw.raw_payload), | |
| source_feed=str(raw.source_feed or "") or None, | |
| published_at=raw.published_at.isoformat() if raw.published_at else None, | |
| fetched_at=raw.fetched_at.isoformat() if raw.fetched_at else None, | |
| language=str(processed.language or "") or None, | |
| sentiment=_build_news_sentiment_block(sentiment), | |
| ) | |
| ) | |
| response = NewsListResponse( | |
| items=items, | |
| total=int(total), | |
| limit=limit, | |
| offset=offset, | |
| has_more=(offset + limit) < int(total), | |
| generated_at=now.isoformat(), | |
| filters=filters_echo, | |
| ) | |
| payload = response.model_dump() | |
| _news_list_cache[cache_key] = (now_ts, payload) | |
| # Trim cache to avoid unbounded growth. | |
| if len(_news_list_cache) > 128: | |
| oldest = sorted(_news_list_cache.items(), key=lambda kv: kv[1][0])[: len(_news_list_cache) - 128] | |
| for k, _ in oldest: | |
| _news_list_cache.pop(k, None) | |
| return payload | |
| async def get_news_stats( | |
| since_hours: int = Query(default=24, ge=1, le=168), | |
| ): | |
| now_ts = datetime.now(timezone.utc).timestamp() | |
| cached = _news_stats_cache.get(since_hours) | |
| if cached and (now_ts - cached[0]) < _NEWS_STATS_TTL_S: | |
| return cached[1] | |
| with SessionLocal() as session: | |
| now = datetime.now(timezone.utc) | |
| cutoff = now - timedelta(hours=since_hours) | |
| rows = ( | |
| session.query(NewsRaw, NewsProcessed, NewsSentimentV2) | |
| .join(NewsProcessed, NewsProcessed.raw_id == NewsRaw.id) | |
| .outerjoin( | |
| NewsSentimentV2, | |
| NewsSentimentV2.news_processed_id == NewsProcessed.id, | |
| ) | |
| .filter(NewsRaw.published_at >= cutoff) | |
| .all() | |
| ) | |
| label_dist: dict[str, int] = {"BULLISH": 0, "BEARISH": 0, "NEUTRAL": 0} | |
| event_dist: dict[str, int] = {} | |
| channel_dist: dict[str, int] = {} | |
| publisher_acc: dict[str, dict[str, float]] = {} | |
| score_sum = 0.0 | |
| conf_sum = 0.0 | |
| rel_sum = 0.0 | |
| scored_count = 0 | |
| total = len(rows) | |
| for raw, _processed, sent in rows: | |
| ch = str(raw.source or "unknown") | |
| channel_dist[ch] = channel_dist.get(ch, 0) + 1 | |
| pub = _extract_publisher(raw.raw_payload) | |
| if pub: | |
| acc = publisher_acc.setdefault(pub, {"count": 0, "score_sum": 0.0}) | |
| acc["count"] += 1 | |
| if sent is not None and sent.final_score is not None: | |
| acc["score_sum"] += float(sent.final_score) | |
| if sent is None: | |
| continue | |
| scored_count += 1 | |
| if sent.label in label_dist: | |
| label_dist[sent.label] += 1 | |
| else: | |
| label_dist[sent.label] = label_dist.get(sent.label, 0) + 1 | |
| etype = sent.event_type or "unknown" | |
| event_dist[etype] = event_dist.get(etype, 0) + 1 | |
| if sent.final_score is not None: | |
| score_sum += float(sent.final_score) | |
| if sent.confidence_calibrated is not None: | |
| conf_sum += float(sent.confidence_calibrated) | |
| if sent.relevance_score is not None: | |
| rel_sum += float(sent.relevance_score) | |
| top_publishers = sorted( | |
| ( | |
| { | |
| "publisher": name, | |
| "count": int(data["count"]), | |
| "avg_final_score": ( | |
| round(float(data["score_sum"]) / float(data["count"]), 4) | |
| if data["count"] > 0 | |
| else 0.0 | |
| ), | |
| } | |
| for name, data in publisher_acc.items() | |
| ), | |
| key=lambda item: item["count"], | |
| reverse=True, | |
| )[:5] | |
| response = NewsStatsResponse( | |
| window_hours=since_hours, | |
| total_articles=total, | |
| scored_articles=scored_count, | |
| label_distribution=label_dist, | |
| event_type_distribution=event_dist, | |
| channel_distribution=channel_dist, | |
| top_publishers=top_publishers, | |
| avg_final_score=(score_sum / scored_count) if scored_count else None, | |
| avg_confidence=(conf_sum / scored_count) if scored_count else None, | |
| avg_relevance=(rel_sum / scored_count) if scored_count else None, | |
| generated_at=now.isoformat(), | |
| ) | |
| payload = response.model_dump() | |
| _news_stats_cache[since_hours] = (now_ts, payload) | |
| return payload | |
| async def get_news_item(processed_id: int): | |
| with SessionLocal() as session: | |
| row = ( | |
| session.query(NewsRaw, NewsProcessed, NewsSentimentV2) | |
| .join(NewsProcessed, NewsProcessed.raw_id == NewsRaw.id) | |
| .outerjoin( | |
| NewsSentimentV2, | |
| NewsSentimentV2.news_processed_id == NewsProcessed.id, | |
| ) | |
| .filter(NewsProcessed.id == processed_id) | |
| .first() | |
| ) | |
| if row is None: | |
| raise HTTPException(status_code=404, detail="Article not found") | |
| raw, processed, sentiment = row | |
| return NewsItem( | |
| id=int(processed.id), | |
| raw_id=int(raw.id), | |
| title=str(raw.title or ""), | |
| description=str(raw.description or "") or None, | |
| url=str(raw.url or "") or None, | |
| channel=str(raw.source or "unknown"), | |
| publisher=_extract_publisher(raw.raw_payload), | |
| source_feed=str(raw.source_feed or "") or None, | |
| published_at=raw.published_at.isoformat() if raw.published_at else None, | |
| fetched_at=raw.fetched_at.isoformat() if raw.fetched_at else None, | |
| language=str(processed.language or "") or None, | |
| sentiment=_build_news_sentiment_block(sentiment), | |
| ) | |