# Provenance (scraped Hugging Face commit header): "Update main.py" by
# aryanwinningson, commit a78e5c8 (verified).
import os
import secrets
from io import BytesIO

import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Header
from fastapi.responses import JSONResponse
app = FastAPI()

# Shared secret that clients must send in the api_key header (checked by get_api_key).
# Prefer supplying it through the API_KEY environment variable. An unset or empty
# variable falls back to the literal below, which is kept only so existing
# deployments keep working.
# NOTE(review): this literal is committed to source control and should be treated
# as compromised. Rotate it and drop the fallback once deployments set API_KEY.
API_KEY = os.environ.get("API_KEY") or "c50dd5ady0uRL0rdnSaVyrArYaN161edb06af8"
def get_api_key(api_key: str = Header(...)):
    """FastAPI dependency that rejects requests with a wrong API key.

    The key is read from the required ``api_key`` request header. By default
    FastAPI exposes it as the ``api-key`` HTTP header.

    Returns:
        The validated key, so dependants that declare ``Depends(get_api_key)``
        receive it instead of ``None``.

    Raises:
        HTTPException: 403 when the key does not match ``API_KEY``.
    """
    # Constant-time comparison, so response timing does not reveal how much of the
    # key matched. Compare bytes, because compare_digest raises TypeError for
    # non-ASCII str arguments.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Face detector. With select_largest=False, MTCNN returns the face with the highest
# detection probability rather than the largest one. post_process=False disables
# MTCNN's own standardisation; predict() rescales the pixels itself (/ 255).
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).eval()

# 3-class classifier (real / fake / ai_generated) on an InceptionResnetV1 backbone
# pretrained on VGGFace2. The 3-way classification head is created fresh, so it only
# becomes meaningful once the fine-tuned checkpoint below is applied.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=3, device=DEVICE)

DESTINATION_FILE_PATH = 'Model/3Class_32_epoch.pth'

# Load the fine-tuned weights onto the CPU first, then copy them into the model.
# Previously the checkpoint was loaded but never applied, leaving the head random.
checkpoint = torch.load(DESTINATION_FILE_PATH, map_location=torch.device('cpu'))
# Accept both a training checkpoint that wraps the weights under 'model_state_dict'
# and a bare state dict.
# NOTE(review): confirm which layout Model/3Class_32_epoch.pth actually uses.
# load_state_dict is strict, so a key mismatch fails loudly at startup instead of
# silently serving an untrained head.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
# Prediction function using the new model
def predict(input_image: Image.Image):
    """Classify the most confident face in *input_image* and build a Grad-CAM overlay.

    The face is found with the module-level ``mtcnn`` detector and resized to
    256x256. It is then scored by ``model`` and explained with Grad-CAM on
    ``model.block8.branch1[-1]`` for the predicted class.

    Args:
        input_image: A PIL image in any mode; non-RGB images are converted to RGB.

    Returns:
        A tuple ``(confidences, prediction, face_with_mask)``:

        - ``confidences`` maps 'real', 'fake' and 'ai_generated' to softmax
          probabilities.
        - ``prediction`` is the label with the highest probability.
        - ``face_with_mask`` is a 256x256x3 uint8 RGB array of the face with the
          Grad-CAM heatmap blended on top.

    Raises:
        Exception: 'No face detected' when MTCNN finds no face.
    """
    # Output index -> label, in the order of the classifier's logits.
    class_names = ('real', 'fake', 'ai_generated')

    # MTCNN works on 3-channel RGB. Convert every other mode (RGBA, L, P, CMYK, ...),
    # not just RGBA.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Detect the face. MTCNN presumably returns a (3, H, W) tensor of 0-255 values
    # here, given post_process=False and the / 255 below.
    face = mtcnn(input_image)
    if face is None:
        raise Exception('No face detected')

    face = face.unsqueeze(0)  # add batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # uint8 HWC copy of the face, used as the base image of the overlay.
    prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy().astype('uint8')

    face = face.to(DEVICE).to(torch.float32) / 255.0

    # Inference first, so Grad-CAM can explain the class that was actually predicted.
    with torch.no_grad():
        output = torch.softmax(model(face).squeeze(0), dim=0)
    predicted_index = int(torch.argmax(output).item())

    # NOTE(review): `face` is already scaled to [0, 1] here, so `.int()` truncates
    # almost every pixel to 0. show_cam_on_image then yields (nearly) the bare
    # heatmap, which addWeighted blends over prev_face below. Kept as-is to
    # preserve the current overlay look; revisit deliberately if changing it.
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()

    # Grad-CAM matches targets to batch items one-to-one. The batch holds a single
    # face, so exactly one target is passed: the predicted class. The old
    # three-target list silently explained class 0 ('real') every time.
    target_layers = [model.block8.branch1[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(predicted_index)]
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]

    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)

    confidences = {name: output[index].item() for index, name in enumerate(class_names)}
    prediction = class_names[predicted_index]
    return confidences, prediction, face_with_mask
# FastAPI prediction endpoint
@app.post("/predict")
async def predict_api(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Run predict() on an uploaded image and return the result as JSON.

    The request must pass the get_api_key dependency. On success the response body
    contains:

    - ``confidences``: per-class probabilities.
    - ``prediction``: the predicted label.
    - ``face_with_mask``: the Grad-CAM overlay encoded as a JPEG and hex-encoded.

    Any failure is returned as ``{"error": <message>}`` with status 400. This covers
    reading or decoding the upload, running the model and encoding the overlay.
    """
    try:
        # Decode inside the try, so a corrupt or non-image upload produces the
        # endpoint's 400 JSON error rather than an unhandled 500.
        image = Image.open(BytesIO(await file.read()))
        confidences, prediction, face_with_mask = predict(image)
        # predict() builds an RGB array (PIL RGB input, show_cam_on_image with
        # use_rgb=True), but OpenCV encodes arrays as BGR. Convert first, or the
        # JPEG comes out with red and blue swapped.
        ok, buffer = cv2.imencode('.jpg', cv2.cvtColor(face_with_mask, cv2.COLOR_RGB2BGR))
        if not ok:
            raise RuntimeError('Failed to encode face_with_mask as JPEG')
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)

    return JSONResponse(content={
        "confidences": confidences,
        "prediction": prediction,
        "face_with_mask": buffer.tobytes().hex(),
    })