# Spaces: Running
import gradio as gr | |
import torchvision.transforms as transforms | |
from CNN_model_classifier import predict_cnn | |
from diffusion_model_classifier import ( | |
ImageClassifier, | |
predict_single_image, | |
) | |
# Register the sample-image directory as a Gradio static path.
gr.set_static_paths(paths=["samples/"])

# Checkpoint files for the two detectors: the diffusion-image classifier and
# the CNN classifier.
diffusion_model = "Diffusion/model_checkpoints/image-classifier-step=7007-val_loss=0.09.ckpt"
cnn_model = "CNN/model_checkpoints/blur_jpg_prob0.5.pth"
# Loaded diffusion classifiers keyed by checkpoint path, so the checkpoint is
# read from disk once instead of on every request.
_DIFFUSION_MODELS = {}


def get_prediction_diffusion(image):
    """Score an image with the diffusion-model classifier.

    Args:
        image: Preprocessed image, passed unchanged to
            ``predict_single_image``.

    Returns:
        A ``(is_match, score)`` tuple, where ``score`` is the value returned
        by ``predict_single_image`` and ``is_match`` is ``score >= 0.001``.
    """
    model = _DIFFUSION_MODELS.get(diffusion_model)
    if model is None:
        # First use of this checkpoint: load it and remember the instance.
        model = ImageClassifier.load_from_checkpoint(diffusion_model)
        _DIFFUSION_MODELS[diffusion_model] = model

    prediction = predict_single_image(image, model)
    print(prediction)  # Kept from the original: writes the raw score to stdout.
    return (prediction >= 0.001, prediction)
def get_prediction_cnn(image):
    """Score an image with the CNN classifier.

    Args:
        image: Preprocessed image, passed unchanged to ``predict_cnn``.

    Returns:
        A ``(is_match, score)`` tuple, where ``score`` is the value returned
        by ``predict_cnn`` and ``is_match`` is ``score >= 0.5``.
    """
    score = predict_cnn(image, cnn_model)
    is_match = score >= 0.5
    return (is_match, score)
def predict(inp):
    """Run both detectors on an image and render the result as HTML.

    Args:
        inp: Input image (the interface is configured to supply a PIL
            image), or ``None`` when no image was provided.

    Returns:
        An HTML string: the overall verdict as a heading, followed by a list
        with each detector's score. A detector whose score crossed its
        threshold is tagged ``(MATCH)``. When ``inp`` is ``None`` a short
        notice is returned instead of running the detectors.
    """
    # The image input yields None when the form is submitted empty; without
    # this guard the transform below would raise on a non-image value.
    if inp is None:
        return "<h1>No image provided</h1>"

    # Define the transformations for the image
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Image size expected by ResNet50
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )
    image_tensor = transform(inp)

    pred_diff, prob_diff = get_prediction_diffusion(image_tensor)
    pred_cnn, prob_cnn = get_prediction_cnn(image_tensor)

    # Either detector firing is enough to report the image as generated.
    verdict = (
        "AI Generated" if (pred_diff or pred_cnn) else "No GenAI detected"
    )
    return (
        f"<h1>{verdict}</h1>"
        f"<ul>"
        f"<li>Diffusion detection score: {prob_diff:.2} "
        f"{'(MATCH)' if pred_diff else ''}</li>"
        f"<li>CNN detection score: {prob_cnn:.1%} "
        f"{'(MATCH)' if pred_cnn else ''}</li>"
        f"</ul>"
    )
# Example rows handed to the interface; each pairs a sample image path with a
# short description of where the image came from.
# NOTE(review): only one input component is declared, so confirm how Gradio
# treats the second column of each row.
_EXAMPLES = [
    ["samples/fake_dalle.jpg", "Generated (Dall-E)"],
    ["samples/fake_midjourney.png", "Generated (MidJourney)"],
    ["samples/fake_stable.jpg", "Generated (Stable Diffusion)"],
    ["samples/fake_cnn.png", "Generated (GAN)"],
    ["samples/real.png", "Organic"],
]

# Web UI: one PIL image in, the HTML produced by ``predict`` out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.HTML(),
    title="AI-generated image detection",
    description="Demo by NICT & Tokyo Techies ",
    examples=_EXAMPLES,
)

demo.launch()