# NOTE: the original capture began with page residue ("Spaces: Sleeping Sleeping");
# it is not part of the program.
import os
import secrets
from io import BytesIO

import cv2
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from fastapi import Depends, FastAPI, File, Header, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
app = FastAPI()

# SECURITY: a credential committed to source control should be treated as
# compromised. The key can now be supplied through the API_KEY environment
# variable; the literal is kept only as a fallback so existing deployments and
# clients keep working.
# TODO(review): rotate this key and remove the hard-coded fallback.
API_KEY = os.environ.get("API_KEY", "c50dd5ady0uRL0rdnSaVyrArYaN161edb06af8")
def get_api_key(api_key: str = Header(...)):
    """Validate the API key sent in the request headers.

    Intended for use as a FastAPI dependency (``Depends(get_api_key)``).

    Args:
        api_key: Key supplied by the client; required (``Header(...)``).

    Returns:
        The validated key, so the dependent endpoint receives the actual
        value instead of ``None``.

    Raises:
        HTTPException: 403 when the supplied key does not match ``API_KEY``.
    """
    # Constant-time comparison: a plain `!=` can leak, through timing, how many
    # leading characters matched. Encoding to bytes lets compare_digest accept
    # non-ASCII header values instead of raising TypeError.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load MTCNN for face detection
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).eval()

# Initialize the new model (InceptionResnetV1 with vggface2 and 3-class classification)
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=3, device=DEVICE)

DESTINATION_FILE_PATH = 'Model/3Class_32_epoch.pth'

# Load the model checkpoint (onto the CPU first; the model is moved to DEVICE below).
checkpoint = torch.load(DESTINATION_FILE_PATH, map_location=torch.device('cpu'))

# The checkpoint used to be loaded and then discarded, so inference ran on the
# stock pretrained backbone with an untrained 3-class head. Apply the weights.
# NOTE(review): the checkpoint layout is not visible here. A bare state dict and
# the common {'model_state_dict': ...} / {'state_dict': ...} wrappers are
# handled; confirm against the training script.
state_dict = checkpoint
if isinstance(checkpoint, dict):
    state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
model.load_state_dict(state_dict)

model.to(DEVICE)
model.eval()
# Prediction function using the new model
def predict(input_image: Image.Image):
    """Classify the face found in ``input_image`` and build a Grad-CAM overlay.

    Args:
        input_image: PIL image to analyse. Any non-RGB mode is converted to RGB.

    Returns:
        A tuple ``(confidences, prediction, face_with_mask)``:
        ``confidences`` maps 'real', 'fake' and 'ai_generated' to their softmax
        scores, ``prediction`` is the label with the highest score, and
        ``face_with_mask`` is a 256x256 uint8 RGB array of the face crop blended
        with the Grad-CAM heat map.

    Raises:
        ValueError: If MTCNN finds no face in the image.
    """
    # Normalise every non-RGB mode (RGBA, L, P, CMYK, ...), not just RGBA, so
    # the detector always receives a three-channel image.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Detect face in the image
    face = mtcnn(input_image)
    if face is None:
        # ValueError is a subclass of Exception, so existing
        # `except Exception` callers are unaffected.
        raise ValueError('No face detected')

    face = face.unsqueeze(0)  # Add batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # uint8 copy of the crop (still on the 0-255 scale) used as the base layer
    # of the returned overlay.
    prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
    prev_face = prev_face.astype('uint8')

    face = face.to(DEVICE).to(torch.float32) / 255.0

    # show_cam_on_image needs a float image in [0, 1]. The previous `.int()`
    # cast truncated the already-normalised pixels to 0, so the heat map was
    # drawn over a black image.
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

    # Grad-CAM setup
    target_layers = [model.block8.branch1[-1]]
    # NOTE(review): a single image is passed with three targets; confirm how
    # pytorch_grad_cam pairs this list with a batch of one.
    targets = [ClassifierOutputTarget(0), ClassifierOutputTarget(1), ClassifierOutputTarget(2)]
    # The context manager releases the hooks GradCAM registers on the model, so
    # they do not pile up across requests or stay active during inference below.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)

    # Inference
    with torch.no_grad():
        output = torch.softmax(model(face).squeeze(0), dim=0)

    class_indices = {0: 'real', 1: 'fake', 2: 'ai_generated'}
    prediction = class_indices[torch.argmax(output).item()]
    confidences = {label: output[index].item() for index, label in class_indices.items()}
    return confidences, prediction, face_with_mask
# FastAPI prediction endpoint
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function; confirm it is registered on `app` somewhere, otherwise the endpoint
# is unreachable.
async def predict_api(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Run the classifier on an uploaded image and return the result as JSON.

    Args:
        file: Uploaded image file.
        api_key: Resolved by the ``get_api_key`` dependency; unused in the body.

    Returns:
        A JSONResponse with ``confidences``, ``prediction`` and
        ``face_with_mask`` (hex-encoded JPEG bytes of the Grad-CAM overlay) on
        success, or ``{"error": <message>}`` with status 400 on failure.
    """
    try:
        # Reading and decoding happen inside the try block so that a corrupt or
        # non-image upload yields the same 400 response as other failures
        # instead of an unhandled 500.
        image = Image.open(BytesIO(await file.read()))
        confidences, prediction, face_with_mask = predict(image)

        # predict() builds the overlay in RGB order, while OpenCV's encoders
        # expect BGR; convert so red and blue are not swapped in the JPEG.
        bgr_overlay = cv2.cvtColor(face_with_mask, cv2.COLOR_RGB2BGR)
        encoded_ok, buffer = cv2.imencode('.jpg', bgr_overlay)
        if not encoded_ok:
            raise RuntimeError('Failed to encode the face overlay as JPEG')

        return JSONResponse(content={
            "confidences": confidences,
            "prediction": prediction,
            "face_with_mask": buffer.tobytes().hex()
        })
    except Exception as e:
        # Top-level request boundary: report the failure to the client.
        return JSONResponse(content={"error": str(e)}, status_code=400)