"""Gradio demo: deepfake image classification with a plain and a Concrete ML compiled model."""

import os
import time

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from concrete.fhe import Configuration
from concrete.ml.torch.compile import compile_torch_model

from custom_resnet import resnet18_custom  # Assuming custom_resnet.py is in the same directory

# Class names indexed by the model's output index. The order was deliberately
# flipped to ['Fake', 'Real'] to fix an earlier incorrect mapping.
class_names = ['Fake', 'Real']


def load_model(model_path, device):
    """Build the custom ResNet classifier and load trained weights into it.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``; its content
            is fed to ``load_state_dict``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    print("load_model")
    net = resnet18_custom(weights=None)
    # Replace the final layer so there is one output per entry in class_names.
    net.fc = nn.Linear(net.fc.in_features, len(class_names))
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True if the torch version supports it).
    net.load_state_dict(torch.load(model_path, map_location=device))
    net = net.to(device)
    net.eval()  # Set model to evaluation mode
    return net


def load_secure_model(model):
    """Compile ``model`` with Concrete ML's ``compile_torch_model``.

    Args:
        model: The PyTorch model to compile; it is moved to the CPU first.

    Returns:
        The object returned by ``compile_torch_model``.
    """
    print("Compiling secure model...")
    return compile_torch_model(
        model.to("cpu"),
        n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 5},
        rounding_threshold_bits={"n_bits": 7, "method": "APPROXIMATE"},
        p_error=0.05,
        configuration=Configuration(
            enable_tlu_fusing=True,
            print_tlu_fusing=False,
            use_gpu=False,
        ),
        # Uniform random tensors in [0, 1) with the same shape as a preprocessed batch.
        torch_inputset=torch.rand(10, 3, 224, 224),
    )


# Load models (runs at import time).
model = load_model('models/deepfake_detection_model.pth', 'cpu')
secure_model = load_secure_model(model)

# Image preprocessing (match with the transforms used during training)
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict(image, mode):
    """Classify an image and report how long inference took.

    Args:
        image: Path of the image file to classify (the Gradio input is configured
            with ``type="filepath"``); ``None`` when nothing was uploaded.
        mode: ``"Fast"`` to run the plain PyTorch model, or ``"Secure"`` to run
            the compiled model with ``fhe="simulate"``.

    Returns:
        A ``(prediction_text, timing_text)`` tuple of strings.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If ``mode`` is neither ``"Fast"`` nor ``"Secure"``.
    """
    if image is None:
        # Surface a readable message in the UI instead of a crash inside Image.open.
        raise gr.Error("Please upload an image or select a sample first.")
    if mode not in ("Fast", "Secure"):
        # Without this check an unknown mode would leave `outputs` unbound.
        raise ValueError(f"Unknown inference mode: {mode!r}")

    device = 'cpu'

    # Apply transformations to the input image; `convert` returns a loaded copy,
    # so the underlying file can be closed as soon as the block exits.
    with Image.open(image) as img:
        batch = data_transform(img.convert('RGB')).unsqueeze(0).to(device)  # Add batch dimension

    # Inference
    with torch.no_grad():
        start_time = time.perf_counter()
        if mode == "Fast":
            # Fast mode: plain PyTorch forward pass.
            outputs = model(batch)
        else:
            # Secure mode: compiled model takes a NumPy array; run with fhe="simulate".
            detached_input = batch.detach().numpy()
            outputs = torch.from_numpy(secure_model.forward(detached_input, fhe="simulate"))
        _, preds = torch.max(outputs, 1)
        elapsed_time = time.perf_counter() - start_time

    predicted_class = class_names[preds[0].item()]
    return f"Predicted: {predicted_class}", f"Time taken: {elapsed_time:.2f} seconds"


# Path to example images for "Fake" and "Real" classes
example_images = [
    ["data/fake/fake_1.jpeg", "Fast"],  # Fake example
    ["data/real/real_1.jpg", "Fast"],  # Real example
]

# Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath", label="Upload an Image"),  # Image input
        gr.Radio(choices=["Fast", "Secure"], label="Inference Mode", value="Fast"),  # Inference mode
    ],
    outputs=[
        gr.Textbox(label="Prediction"),  # Prediction output
        gr.Textbox(label="Time Taken"),  # Time taken output
    ],
    examples=example_images,  # Add default examples for Fake and Real images
    title="Deepfake Detection Model",
    description="Upload an image or select a sample and choose the inference mode (Fast or Secure).",
)

if __name__ == "__main__":
    iface.launch(share=True)