Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI for Pneumonia Detection - Hugging Face Spaces Deployment | |
| CI/CD enabled - auto-deploys from GitHub | |
| """ | |
| import io | |
| import time | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
# =============================================================================
# Configuration
# =============================================================================
# Every input image is resized to a square of this many pixels per side.
IMAGE_SIZE = 224

# Per-channel (R, G, B) mean / std handed to transforms.Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Output labels: index 1 is chosen when the sigmoid score exceeds 0.5,
# index 0 otherwise.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]

# Checkpoint location, relative to the process working directory.
MODEL_PATH = Path("models") / "best_model.pt"
# =============================================================================
# Model Definition
# =============================================================================
class PneumoniaClassifier(nn.Module):
    """EfficientNet-B0 backbone whose classifier head emits a single logit.

    The backbone is created without pretrained weights; the real weights are
    expected to come from a checkpoint. The attribute name ``backbone`` and the
    two-element ``classifier`` Sequential (Dropout, Linear) define the
    state_dict keys, so they must stay as they are for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b0(weights=None)
        # Width of the stock head's Linear layer input, reused for the new head.
        head_width = net.classifier[1].in_features
        # Replace the stock head with dropout + one output unit (binary logit).
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(head_width, 1),
        )
        self.backbone = net

    def forward(self, x):
        """Return the raw (pre-sigmoid) logit(s) produced by the backbone for ``x``."""
        return self.backbone(x)
# =============================================================================
# Response Models
# =============================================================================
# Body returned by health(). Documented with comments rather than a docstring
# on purpose: pydantic copies class docstrings into the generated JSON schema.
class HealthResponse(BaseModel):
    # "healthy" when the classifier is loaded, "model_not_loaded" otherwise.
    status: str
    # Whether the module-level classifier has been loaded.
    model_loaded: bool
# Body returned by predict_endpoint(). Field order is kept as-is because it
# defines the JSON / schema layout; comments instead of a docstring so the
# generated schema stays unchanged.
class PredictionResponse(BaseModel):
    # Predicted label, one of CLASS_NAMES.
    prediction: str
    # Probability of the predicted label (always >= 0.5 given the 0.5 cut-off).
    confidence: float
    # Sigmoid score of the model's single output (the "PNEUMONIA" score).
    probability: float
    # Time spent in preprocessing + inference, in milliseconds (2 decimals).
    processing_time_ms: float
# =============================================================================
# App Setup
# =============================================================================
app = FastAPI(
    version="1.0.0",
    title="Pneumonia Detection API",
    description="Deep learning API for detecting pneumonia from chest X-rays",
)

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests instead of
# sending "*", i.e. any site may make credentialed cross-origin calls. Nothing
# visible in this file uses cookies or auth headers, so allow_credentials=False
# (or an explicit origin list) is likely safer -- confirm with the front-end
# before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# =============================================================================
# Model Loading
# =============================================================================
# Populated by load_model(); None means "not loaded" to the endpoints.
model = None
device = None


async def load_model():
    """Load the trained classifier from MODEL_PATH into the module globals.

    Sets ``device`` to CUDA when available, otherwise CPU. If MODEL_PATH does
    not exist, prints a warning and returns, leaving ``model`` unchanged.

    The global ``model`` is assigned only after the checkpoint has been read,
    its weights applied, and the network moved to ``device`` and switched to
    eval mode. If any of those steps raise, the exception propagates and
    ``model`` keeps its previous value (None on first load). A half-built,
    randomly initialised network is therefore never exposed to the endpoints.

    NOTE(review): nothing in this file visibly calls this coroutine (no
    startup hook / decorator appears here) -- confirm it is registered, e.g.
    as an application startup handler, or the model will never be loaded.
    """
    global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if not MODEL_PATH.exists():
        print(f"Warning: Model not found at {MODEL_PATH}")
        return

    # Build into a local first. Previously the global was set before loading,
    # so a failing torch.load / load_state_dict left an untrained model (still
    # in training mode) that health() reported as loaded and predict used.
    net = PneumoniaClassifier()
    # NOTE: torch.load unpickles the file; only load checkpoints whose
    # provenance you control.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.to(device)
    net.eval()

    model = net
    print("Model loaded successfully")
# =============================================================================
# Helper Functions
# =============================================================================
def get_transforms():
    """Build the preprocessing pipeline applied to each input image.

    Steps, in order: resize to exactly IMAGE_SIZE x IMAGE_SIZE, convert to a
    tensor, normalise with IMAGENET_MEAN / IMAGENET_STD. A new Compose is
    returned on every call.

    NOTE(review): this must match the preprocessing used at training time --
    confirm against the training code.
    """
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
async def read_image(file: UploadFile) -> Image.Image:
    """Read an upload fully into memory and decode it as an RGB PIL image.

    No size limit is applied here. PIL decoding errors propagate to the caller.
    """
    raw_bytes = await file.read()
    decoded = Image.open(io.BytesIO(raw_bytes))
    # Normalise any mode (greyscale, RGBA, palette, ...) to 3-channel RGB.
    return decoded.convert("RGB")
def predict(image: Image.Image):
    """Run the loaded classifier on a single image.

    Returns a ``(label, confidence, probability)`` tuple:

    * probability -- sigmoid of the model's single output, i.e. the score for
      CLASS_NAMES[1] ("PNEUMONIA");
    * label -- CLASS_NAMES[1] when probability > 0.5, else CLASS_NAMES[0]
      (an exact 0.5 therefore maps to CLASS_NAMES[0]);
    * confidence -- probability of the returned label, i.e.
      ``probability`` or ``1 - probability``.

    Uses the module-level ``model`` and ``device``. It does not check that the
    model is loaded; predict_endpoint does that before calling.
    """
    pipeline = get_transforms()
    # Prepend a batch axis of size 1, then move to the model's device.
    batch = pipeline(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logit = model(batch)
        score = torch.sigmoid(logit).item()

    if score > 0.5:
        return CLASS_NAMES[1], score, score
    return CLASS_NAMES[0], 1 - score, score
# =============================================================================
# Endpoints
# =============================================================================
async def root():
    """Return a short service banner pointing clients at the interactive docs.

    NOTE(review): no route decorator is visible on this function in this file;
    confirm it is registered with ``app``.
    """
    banner = {"message": "Pneumonia Detection API"}
    banner["docs"] = "/docs"
    return banner
async def health():
    """Report liveness and whether the classifier has been loaded.

    ``status`` is "healthy" when the module-level ``model`` is set, otherwise
    "model_not_loaded"; ``model_loaded`` mirrors the same check.

    NOTE(review): no route decorator is visible on this function in this file;
    confirm it is registered with ``app``.
    """
    # Explicit None check instead of truthiness: the previous ``if model``
    # depended on nn.Module's bool() (containers such as nn.Sequential define
    # __len__ and can be falsy), and disagreed in form with model_loaded.
    loaded = model is not None
    return HealthResponse(
        status="healthy" if loaded else "model_not_loaded",
        model_loaded=loaded,
    )
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray image.

    Returns a PredictionResponse. ``processing_time_ms`` covers only
    preprocessing + inference (the predict() call), not reading or decoding
    the upload.

    Raises:
        HTTPException(503): the model has not been loaded.
        HTTPException(400): the upload is not declared as ``image/*``, or its
            bytes cannot be decoded as an image.

    NOTE(review): no route decorator is visible on this function in this file;
    confirm it is registered with ``app``. predict() is synchronous and
    CPU/GPU-bound, so it blocks the event loop while it runs.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # UploadFile.content_type is optional; a request without a Content-Type
    # for the part used to crash here with AttributeError (HTTP 500).
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image = await read_image(file)
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # undecodable data, OSError for truncated files, and
        # DecompressionBombError for oversized images. These are client
        # errors, not server faults.
        raise HTTPException(status_code=400, detail="Could not decode image") from exc

    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start = time.perf_counter()
    pred_class, confidence, prob = predict(image)
    processing_time = (time.perf_counter() - start) * 1000

    return PredictionResponse(
        prediction=pred_class,
        confidence=confidence,
        probability=prob,
        processing_time_ms=round(processing_time, 2),
    )