Spaces:
Runtime error
Runtime error
import io
import os
from typing import List

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, RedirectResponse
from PIL import Image
from torchvision import transforms
# Create the FastAPI application instance (the `app` object handed to uvicorn).
app = FastAPI()
# Architecture of the SafetyCNN binary classifier.
class SafetyCNN(nn.Module):
    """Small two-stage convolutional network that emits one raw logit per image.

    Each stage is a 5x5 convolution (stride 1, no padding) followed by ReLU and
    2x2 max pooling. The size of the first fully connected layer is hard-wired
    to 24 * 157 * 157 features, which is what a 3 x 640 x 640 input produces
    (640 -> 636 -> 318 -> 314 -> 157), so other input resolutions will fail at
    ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below are also the state_dict key prefixes, so
        # they must stay unchanged for saved weights to load.
        self.conv1 = nn.Conv2d(3, 12, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(12, 24, kernel_size=5, stride=1)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(157 * 157 * 24, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits, one column per input image."""
        # Both convolutional stages share the same conv -> ReLU -> pool recipe.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Collapse everything except the batch dimension before the dense head.
        flat = x.flatten(start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
# Initialize the model and restore its pre-trained weights for CPU inference.
safety_model = SafetyCNN()
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# The checkpoint is consumed by load_state_dict, i.e. it must be a state_dict.
safety_model.load_state_dict(
    torch.load(
        "safety_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# Switch to inference mode so the Dropout layer is disabled during prediction.
safety_model.eval()
# Test-time augmentation (TTA) pipelines, one Compose per augmented view.
tta_transforms = [
    transforms.Compose(steps)
    for steps in (
        # No-op pipeline: the image passes through unchanged.
        [],
        # p=1.0 makes the horizontal flip unconditional.
        [transforms.RandomHorizontalFlip(p=1.0)],
        # Random rotation configured with degrees=30.
        [transforms.RandomRotation(degrees=30)],
        # Random crop (scale range 0.8-1.0) resized to 640x640.
        [transforms.RandomResizedCrop(size=(640, 640), scale=(0.8, 1.0))],
    )
]
# Utility function to classify image with TTA
def classify_image_with_tta(image: Image.Image, model, tta_transforms: List[transforms.Compose], num_tta=4):
    """Classify an image by averaging sigmoid probabilities over augmented views.

    Args:
        image: PIL image handed directly to each augmentation pipeline.
        model: Callable that maps a (1, 3, 640, 640) tensor to a single logit;
            the result is squeezed on dim 1 and read with ``.item()``.
        tta_transforms: Augmentation pipelines. View ``i`` uses
            ``tta_transforms[i % len(tta_transforms)]``, so the list is cycled
            when ``num_tta`` exceeds its length.
        num_tta: Number of augmented views to evaluate.

    Returns:
        ``(prediction, avg_prob)`` where ``avg_prob`` is the mean probability
        as a plain ``float`` and ``prediction`` is 1 if it exceeds 0.5, else 0.

    Raises:
        ValueError: If ``num_tta`` is below 1 or ``tta_transforms`` is empty.
    """
    # Fail loudly instead of averaging an empty list (NaN) or taking `i % 0`.
    if num_tta < 1:
        raise ValueError("num_tta must be at least 1")
    if not tta_transforms:
        raise ValueError("tta_transforms must not be empty")

    # Model preprocessing is identical for every view, so build it only once.
    preprocess = transforms.Compose([
        transforms.Resize((640, 640)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    probabilities = []
    with torch.no_grad():
        for i in range(num_tta):
            augment = tta_transforms[i % len(tta_transforms)]
            # unsqueeze(0) adds the batch dimension: (3, 640, 640) -> (1, 3, 640, 640).
            input_tensor = preprocess(augment(image)).unsqueeze(0)
            logit = model(input_tensor).squeeze(1)
            # Sigmoid turns the raw logit into a probability.
            probabilities.append(torch.sigmoid(logit).item())

    # float() yields a builtin float rather than a NumPy scalar.
    avg_prob = float(np.mean(probabilities))
    # Threshold of 0.5 for the binary decision.
    prediction = 1 if avg_prob > 0.5 else 0
    return prediction, avg_prob
# Register the handler on the app; without a route decorator it is never served.
# include_in_schema=False keeps this convenience redirect out of the OpenAPI docs.
@app.get("/", include_in_schema=False)
async def docs():
    """Redirect the root URL to the interactive API documentation at ``/docs``."""
    return RedirectResponse(url="/docs")
# FastAPI endpoint to upload an image and classify it
# NOTE(review): no route decorator was present, so the endpoint was unreachable;
# the "/classify" path is a choice made here - confirm it against the clients.
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Safe" or "Unsafe".

    Args:
        file: Uploaded image; it is decoded with PIL and converted to RGB.

    Returns:
        JSON object with ``prediction`` ("Unsafe" when the predicted label is
        1, otherwise "Safe") and ``probability`` (the averaged probability).

    Raises:
        HTTPException: 400 if the upload cannot be read or decoded as an image.
    """
    try:
        # Read the uploaded file and decode it.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        # Chain the cause so the underlying decode error stays visible in logs.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # Inference is CPU-bound; running it in the thread pool keeps the event
    # loop free to serve other requests in the meantime.
    prediction, avg_prob = await run_in_threadpool(
        classify_image_with_tta, image, safety_model, tta_transforms, num_tta=4
    )

    # Create a response message.
    result = {
        "prediction": "Unsafe" if prediction == 1 else "Safe",
        "probability": float(avg_prob),
    }
    return JSONResponse(content=result)
# Optional: Run with uvicorn if needed
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# To run the server:
# uvicorn safety_classifier:app --reload