| | |
| | from __future__ import annotations |
| |
|
import hmac
import warnings
from typing import List, Literal, Optional, Tuple

import joblib
from fastapi import FastAPI, Header, HTTPException

from config import MODEL_PATH, REAL_LABEL, API_KEY
from helper import _combine
from schemas import PredictOut, PredictBatchIn, PredictIn, PredictBatchOut
| |
|
| | |
# Process-wide warning filters, installed once at import time so that loading
# and calling the pickled sklearn model does not spam the log.
# 1) Drop every UserWarning raised from modules matching "sklearn".
warnings.filterwarnings(action="ignore", category=UserWarning, module="sklearn")
# 2) Drop warnings whose *message text* matches the pattern below (the
#    ``message`` argument is a regex applied to the text, not the class name).
warnings.filterwarnings(action="ignore", message=".*InconsistentVersionWarning.*")

# 3) When the installed sklearn exposes the dedicated warning class, silence it
#    by category as well; otherwise there is nothing further to register.
try:
    from sklearn.exceptions import InconsistentVersionWarning
except ImportError:
    pass
else:
    warnings.filterwarnings(action="ignore", category=InconsistentVersionWarning)
| |
|
| | |
# Load the fitted pipeline once per module namespace and work out which
# predict_proba column belongs to the "real" class and which to the "fake" one.
# The globals() guard skips the (expensive) load when this module body is
# executed again in a namespace that already holds PIPE.
if 'PIPE' not in globals():
    try:
        print("Loading model from:", MODEL_PATH)
        # NOTE(security): joblib.load unpickles arbitrary Python objects, so
        # MODEL_PATH must only ever point at a trusted artifact.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            PIPE = joblib.load(MODEL_PATH)
        print("Model loaded successfully")
    except Exception as e:
        # Startup boundary: report the failure, then let it propagate so the
        # service does not start without a model.
        print(f"Error loading model: {e}")
        raise

    # Prefer the class list of the pipeline step named "clf"; fall back to the
    # top-level estimator's classes_, and finally to [0, 1].
    try:
        classes = list(PIPE.named_steps["clf"].classes_)
    except (AttributeError, KeyError, TypeError):
        classes = list(getattr(PIPE, "classes_", [0, 1]))

    print(f"Model classes: {classes}")
    if REAL_LABEL not in classes:
        raise ValueError(
            f"REAL_LABEL {REAL_LABEL!r} is not one of the model classes {classes}"
        )
    IDX_REAL = classes.index(REAL_LABEL)

    # For a binary model the fake class is simply "the other one". Deriving it
    # from REAL_LABEL (instead of hard-coding label 0) keeps the two indices
    # distinct when REAL_LABEL is 0 and works for non-integer labels.
    fake_candidates = [i for i, cls in enumerate(classes) if cls != REAL_LABEL]
    if len(fake_candidates) == 1:
        IDX_FAKE = fake_candidates[0]
    else:
        # Not a two-class model: keep the legacy convention that label 0 is fake.
        IDX_FAKE = classes.index(0)
    print(f"Real index: {IDX_REAL}, Fake index: {IDX_FAKE}")
else:
    print("Model already loaded, skipping reload...")
| |
|
| | |
| | |
| | |
def infer_one(inp: PredictIn) -> PredictOut:
    """Classify a single news item as "real" or "fake".

    The text scored is ``inp.text_all`` (stripped and lower-cased) when it
    contains anything other than whitespace; otherwise it is whatever
    ``_combine(inp.title, inp.text)`` returns.

    Args:
        inp: Request payload; ``text_all``, ``title`` and ``text`` are read.

    Returns:
        A ``PredictOut`` with the label and both class probabilities. The
        label is "real" when the real-class probability is at least 0.5.
    """
    text_all = (inp.text_all or "").strip().lower()
    if not text_all:
        # A missing *or whitespace-only* text_all must not be scored as an
        # empty document; fall back to the title/body pair instead.
        # NOTE(review): _combine lives in helper.py; presumably it applies the
        # same normalisation as above -- confirm.
        text_all = _combine(inp.title, inp.text)

    # NOTE(review): warnings.catch_warnings mutates global filter state and is
    # documented as not thread-safe; kept for parity with the model-loading
    # code, but worth revisiting if this runs on a thread pool.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        probs = PIPE.predict_proba([text_all])[0]

    prob_real = float(probs[IDX_REAL])
    prob_fake = float(probs[IDX_FAKE])
    label = "real" if prob_real >= 0.5 else "fake"

    return PredictOut(
        label=label,
        prob_real=prob_real,
        prob_fake=prob_fake,
    )
| |
|
| |
|
def infer_batch(items: List[PredictIn]) -> List[PredictOut]:
    """Run ``infer_one`` on every item and return the results in input order.

    An empty ``items`` yields an empty list.
    """
    results = []
    for item in items:
        results.append(infer_one(item))
    return results
| |
|
| |
|
| | |
| | |
| | |
# ASGI application object; the route decorators further down register their
# handlers on it. The metadata below is what FastAPI shows in the generated docs.
app = FastAPI(
    version="1.0.0",
    title="SVM Fake/Real News Classifier",
    description="API for classifying news as real or fake using SVM with TF-IDF features",
)
| |
|
@app.get("/")
def root():
    """Return a static description of the service.

    The payload lists the available endpoints and some model metadata. Apart
    from ``MODEL_PATH`` every value is hard-coded here rather than read from
    the loaded model.
    """
    # NOTE(review): this route performs no API-key check, so MODEL_PATH is
    # visible to any caller -- confirm that is intended.
    endpoints = {
        "predict": "/predict",
        "predict_batch": "/predict_batch",
        "health": "/health",
    }
    model_info = {
        "classes": ["fake", "real"],
        "model_path": MODEL_PATH,
        "calibrated": True,
    }
    return {
        "message": "SVM Fake/Real News Classifier API",
        "endpoints": endpoints,
        "model_info": model_info,
    }
| |
|
@app.get("/health")
def health_check():
    """Report service status and whether ``PIPE`` exists in this module's globals."""
    model_loaded = 'PIPE' in globals()
    return {"status": "healthy", "model_loaded": model_loaded}
| |
|
def _require_api_key(provided: str) -> None:
    """Raise HTTP 401 unless *provided* matches the configured ``API_KEY``.

    The check fails closed: when ``API_KEY`` is unset or empty every request
    is rejected, so a missing configuration value cannot silently turn
    authentication off (the header defaults to "" and would otherwise match).
    The comparison is constant-time to avoid leaking the key through response
    timing.

    Raises:
        HTTPException: status 401 when the key is missing or wrong.
    """
    expected = "" if API_KEY is None else str(API_KEY)
    # Compare bytes: hmac.compare_digest rejects non-ASCII str arguments, and
    # header values are caller-controlled.
    authorized = bool(expected) and hmac.compare_digest(
        provided.encode("utf-8"), expected.encode("utf-8")
    )
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")


@app.post("/predict", response_model=PredictOut)
def predict(payload: PredictIn, x_api_key: str = Header(default="")):
    """Classify one news item; requires a valid ``X-API-Key`` header.

    Raises:
        HTTPException: status 401 when the API key is missing or wrong.
    """
    _require_api_key(x_api_key)
    return infer_one(payload)


@app.post("/predict_batch", response_model=PredictBatchOut)
def predict_batch(payload: PredictBatchIn, x_api_key: str = Header(default="")):
    """Classify every item in ``payload.items``; requires a valid ``X-API-Key`` header.

    Raises:
        HTTPException: status 401 when the API key is missing or wrong.
    """
    _require_api_key(x_api_key)
    return PredictBatchOut(results=infer_batch(payload.items))
| |
|
| |
|
# Script entry point: print a short startup banner, then serve the app with
# uvicorn on all interfaces, port 6778.
if __name__ == "__main__":
    import uvicorn

    banner_lines = (
        "===== Application Ready =====",
        "FastAPI app initialized successfully",
        "API endpoints available at /predict and /predict_batch",
        "API documentation at /docs",
        "================================",
    )
    for banner_line in banner_lines:
        print(banner_line)

    uvicorn.run(app, host="0.0.0.0", port=6778)
| |
|