# VerifAI GradCAM API — FastAPI service for AI-image detection with GradCAM overlays.
import base64
import json
import logging
from io import BytesIO

import numpy as np
import timm
import torch
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from transformers import AutoFeatureExtractor
# FastAPI application instance.
app = FastAPI(
    title="VerifAI GradCAM API",
    description="API pour la détection d'images IA avec GradCAM",
)

# CORS configuration: wide open (any origin, method, header, with credentials).
# NOTE(review): "*" origins combined with credentials is very permissive;
# tighten the origin list before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class AIDetectionGradCAM:
    """Swin-Transformer-based detector of AI-generated images with GradCAM explanations.

    Loads a timm Swin model, classifies an image as Real vs AI-Generated,
    and produces a GradCAM heatmap overlay encoded as a base64 PNG data URL.
    """

    # Logger for load/inference diagnostics.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        # Use the GPU when available; models and input tensors are moved here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}              # name -> nn.Module
        self.feature_extractors = {}  # name -> HF feature extractor
        self.target_layers = {}       # name -> layers watched by GradCAM
        self._load_models()

    def _load_models(self):
        """Load the detection model(s); failures are logged, not raised (best effort)."""
        try:
            model_name = "microsoft/swin-base-patch4-window7-224-in22k"
            # NOTE(review): num_classes=2 replaces the pretrained classifier head
            # with a randomly initialized one — without a fine-tuned checkpoint
            # the Real/AI prediction is not meaningful. Confirm trained weights
            # are loaded somewhere, or load them here.
            self.models['swin'] = timm.create_model(
                'swin_base_patch4_window7_224', pretrained=True, num_classes=2)
            self.feature_extractors['swin'] = AutoFeatureExtractor.from_pretrained(model_name)
            # GradCAM target: first norm layer of the last Swin block
            # (the layer used by the pytorch-grad-cam Swin tutorial).
            self.target_layers['swin'] = [self.models['swin'].layers[-1].blocks[-1].norm1]
            for model in self.models.values():
                model.eval()
                model.to(self.device)
        except Exception:
            # Keep the process alive so the API can still start and report errors.
            self._logger.exception("Erreur lors du chargement des modèles")

    @staticmethod
    def _reshape_transform(tensor, height=7, width=7):
        """Convert Swin activations to the (B, C, H, W) layout GradCAM expects.

        Swin norm layers emit either token form (B, L, C) or spatial form
        (B, H, W, C) depending on the timm version; both are handled. For a
        224px swin-base the last stage is 7x7 (the defaults).
        """
        if tensor.dim() == 3:
            # (B, L, C) -> (B, H, W, C)
            tensor = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        # (B, H, W, C) -> (B, C, H, W)
        return tensor.permute(0, 3, 1, 2)

    def _preprocess_image(self, image, model_type='swin'):
        """Return (normalized tensor on self.device, RGB float array in [0, 1]).

        Accepts a PIL image, a filesystem path, or a base64 data-URL string.
        """
        if isinstance(image, str):
            if image.startswith('data:image'):
                # Data URL: strip the "data:image/...;base64," header and decode.
                header, data = image.split(',', 1)
                image = Image.open(BytesIO(base64.b64decode(data)))
            else:
                image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Model input resolution is fixed at 224x224.
        image = image.resize((224, 224))
        # Standard ImageNet normalization (matches the pretrained backbone).
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        tensor = transform(image).unsqueeze(0).to(self.device)
        return tensor, np.array(image) / 255.0

    def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
        """Return the GradCAM heatmap blended onto rgb_img (uint8-compatible RGB array)."""
        try:
            model = self.models[model_type]
            target_layers = self.target_layers[model_type]
            # reshape_transform is required for transformer backbones: their
            # activations are not in CNN (B, C, H, W) layout, and GradCAM
            # cannot build a spatial map without it.
            cam = GradCAM(model=model, target_layers=target_layers,
                          reshape_transform=self._reshape_transform)
            # targets=None -> explain the highest-scoring class.
            grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
            return show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        except Exception:
            self._logger.exception("Erreur GradCAM")
            # Fall back to the plain image, rescaled to 0-255.
            return rgb_img * 255

    def predict_and_explain(self, image):
        """Classify `image` and return a JSON-serializable result with a GradCAM overlay.

        On success: dict with prediction index, confidence, per-class
        probabilities and a base64 PNG data URL. On failure:
        {'status': 'error', 'message': ...}.
        """
        try:
            image_tensor, rgb_img = self._preprocess_image(image)
            # Plain forward pass; GradCAM re-runs with gradients internally.
            with torch.no_grad():
                outputs = self.models['swin'](image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence = probabilities.max().item()
                prediction = probabilities.argmax().item()
            cam_image = self._generate_gradcam(image_tensor, rgb_img)
            # Encode the overlay as a PNG data URL for the frontend.
            pil_image = Image.fromarray(cam_image.astype(np.uint8))
            buffer = BytesIO()
            pil_image.save(buffer, format='PNG')
            cam_base64 = base64.b64encode(buffer.getvalue()).decode()
            return {
                'prediction': prediction,
                'confidence': confidence,
                'class_probabilities': {
                    'Real': probabilities[0][0].item(),
                    'AI-Generated': probabilities[0][1].item()
                },
                'cam_image': f"data:image/png;base64,{cam_base64}",
                'status': 'success'
            }
        except Exception as e:
            return {'status': 'error', 'message': str(e)}
# Module-level detector instance shared by all request handlers.
detector = AIDetectionGradCAM()
# NOTE(review): the route decorator appears to have been lost in formatting —
# without it this endpoint is never registered. Restoring the conventional
# root path; confirm against the frontend's expected URL.
@app.get("/")
async def root():
    """Health-check / landing endpoint."""
    return {"message": "VerifAI GradCAM API", "status": "running"}
# NOTE(review): route decorator restored — it appears to have been lost in
# formatting, leaving the endpoint unregistered. Confirm the path the
# frontend calls.
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Analyze an uploaded image file.

    Returns the detector's result dict (prediction, confidence,
    probabilities, GradCAM data URL); any failure surfaces as HTTP 500.
    """
    try:
        # Read the upload into memory and decode it as an image.
        image_data = await file.read()
        image = Image.open(BytesIO(image_data))
        return detector.predict_and_explain(image)
    except Exception as e:
        # Chain the original cause for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): route decorator restored — it appears to have been lost in
# formatting, leaving the endpoint unregistered. Confirm the path the
# frontend calls.
@app.post("/predict-base64")
async def predict_base64(data: dict):
    """Analyze a base64-encoded image supplied as {'image': <data-URL or b64>}.

    Raises 400 when the 'image' field is missing; other failures surface
    as HTTP 500.
    """
    # Validate OUTSIDE the try block: in the original, the 400 raised inside
    # `try` was caught by `except Exception` and re-raised as a 500.
    if 'image' not in data:
        raise HTTPException(status_code=400, detail="Champ 'image' requis")
    try:
        return detector.predict_and_explain(data['image'])
    except Exception as e:
        # Chain the original cause for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Serve on all interfaces; 7860 is the Hugging Face Spaces default port.
    uvicorn.run(app, host="0.0.0.0", port=7860)