import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Header
from fastapi.responses import JSONResponse
from io import BytesIO

app = FastAPI()

API_KEY = "c50dd5ady0uRL0rdnSaVyrArYaN161edb06af8"


def get_api_key(api_key: str = Header(...)):
    """Validate the 'api-key' request header against the configured key."""
    if api_key != API_KEY:
        raise HTTPException(status_code=403, detail="Could not validate credentials")


DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load MTCNN for face detection
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).eval()

# Initialize the classifier (InceptionResnetV1 pretrained on VGGFace2 with a 3-class head)
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=3, device=DEVICE)

DESTINATION_FILE_PATH = 'Model/3Class_32_epoch.pth'

# Load the model checkpoint and apply the fine-tuned weights.
# If the checkpoint wraps the weights (e.g. under 'model_state_dict'), unwrap it first.
checkpoint = torch.load(DESTINATION_FILE_PATH, map_location=torch.device('cpu'))
state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()


# Prediction function using the fine-tuned model
def predict(input_image: Image.Image):
    """Predict the label of the input_image and return a Grad-CAM overlay."""
    if input_image.mode == 'RGBA':
        input_image = input_image.convert('RGB')

    # Detect the face in the image (MTCNN returns a 0-255 float tensor, CxHxW)
    face = mtcnn(input_image)
    if face is None:
        raise Exception('No face detected')

    face = face.unsqueeze(0)  # Add batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # Keep an unnormalized uint8 copy of the face crop for the overlay
    prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
    prev_face = prev_face.astype('uint8')

    # Normalize to [0, 1] for the model and for show_cam_on_image
    face = face.to(DEVICE).to(torch.float32) / 255.0
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

    # Grad-CAM setup on the last layer of the final Inception-ResNet block
    target_layers = [model.block8.branch1[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    # The library expects one target per image in the batch, not one per class;
    # targets=None lets Grad-CAM explain the highest-scoring class for this single image.
    grayscale_cam = cam(input_tensor=face, targets=None, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)

    # Inference
    with torch.no_grad():
        output = torch.softmax(model(face).squeeze(0), dim=0)

        class_indices = {0: 'real', 1: 'fake', 2: 'ai_generated'}
        prediction = class_indices[torch.argmax(output).item()]

        confidences = {
            'real': output[0].item(),
            'fake': output[1].item(),
            'ai_generated': output[2].item()
        }

    return confidences, prediction, face_with_mask


# FastAPI prediction endpoint
@app.post("/predict")
async def predict_api(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    try:
        image = Image.open(BytesIO(await file.read()))
        confidences, prediction, face_with_mask = predict(image)

        # Encode the RGB overlay as JPEG (convert to BGR first, since OpenCV expects BGR)
        _, buffer = cv2.imencode('.jpg', cv2.cvtColor(face_with_mask, cv2.COLOR_RGB2BGR))
        face_with_mask_encoded = buffer.tobytes()

        return JSONResponse(content={
            "confidences": confidences,
            "prediction": prediction,
            "face_with_mask": face_with_mask_encoded.hex()
        })
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)
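

# --- Running and calling the service ---
# A minimal sketch of how this app might be served and exercised. The host, port,
# and "test.jpg" path below are assumptions for illustration, not part of the
# service definition above.

if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app locally (assumed host/port).
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call (run from a separate process while the server is up):
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8000/predict",
#       headers={"api-key": API_KEY},             # Header(...) maps api_key -> the "api-key" header
#       files={"file": open("test.jpg", "rb")},
#   )
#   data = resp.json()
#   print(data["prediction"], data["confidences"])
#
#   # The Grad-CAM overlay is returned as hex-encoded JPEG bytes:
#   with open("face_with_mask.jpg", "wb") as f:
#       f.write(bytes.fromhex(data["face_with_mask"]))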