news_verification / src /images /image_demo.py
pmkhanh7890's picture
1st
22e1b62
raw
history blame
2.18 kB
import gradio as gr
import torchvision.transforms as transforms
from CNN_model_classifier import predict_cnn
from diffusion_model_classifier import (
ImageClassifier,
predict_single_image,
)
# Register the bundled sample directory as a Gradio static path.
gr.set_static_paths(paths=["samples/"])

# Checkpoint locations (relative paths) consumed by the two detectors below.
diffusion_model = "Diffusion/model_checkpoints/image-classifier-step=7007-val_loss=0.09.ckpt"
cnn_model = "CNN/model_checkpoints/blur_jpg_prob0.5.pth"
# Loaded diffusion classifiers keyed by checkpoint path, so the checkpoint is
# read from disk once per process instead of once per request.
_DIFFUSION_CLASSIFIERS = {}


def _load_diffusion_classifier(checkpoint_path):
    """Return the classifier for ``checkpoint_path``, loading it on first use."""
    classifier = _DIFFUSION_CLASSIFIERS.get(checkpoint_path)
    if classifier is None:
        classifier = ImageClassifier.load_from_checkpoint(checkpoint_path)
        _DIFFUSION_CLASSIFIERS[checkpoint_path] = classifier
    return classifier


def get_prediction_diffusion(image):
    """Score ``image`` with the diffusion-image classifier.

    The classifier is restored from the module-level ``diffusion_model``
    checkpoint path the first time this function runs and reused afterwards.

    Args:
        image: The preprocessed image handed to ``predict_single_image``.

    Returns:
        A ``(flagged, score)`` pair, where ``score`` is the value returned by
        ``predict_single_image`` and ``flagged`` is ``score >= 0.001``.
    """
    model = _load_diffusion_classifier(diffusion_model)
    prediction = predict_single_image(image, model)
    # Kept from the original: echoes the raw score to stdout for inspection.
    print(prediction)
    return (prediction >= 0.001, prediction)
def get_prediction_cnn(image):
    """Score ``image`` with the CNN detector.

    Delegates to ``predict_cnn`` using the module-level ``cnn_model``
    checkpoint path.

    Args:
        image: The preprocessed image handed to ``predict_cnn``.

    Returns:
        A ``(flagged, score)`` pair, where ``score`` is the value returned by
        ``predict_cnn`` and ``flagged`` is ``score >= 0.5``.
    """
    score = predict_cnn(image, cnn_model)
    flagged = score >= 0.5
    return (flagged, score)
def predict(inp):
    """Run both detectors on an uploaded image and render the result as HTML.

    Args:
        inp: PIL image coming from the Gradio image input. ``None`` is
            rejected with a user-facing error (presumably what Gradio passes
            when the form is submitted empty -- confirm against the Gradio
            version in use).

    Returns:
        An HTML string containing the overall verdict as a heading and one
        list item per detector with its score; a detector that flagged the
        image is marked with ``(MATCH)``.

    Raises:
        gr.Error: If ``inp`` is ``None``.
    """
    if inp is None:
        # Fail with a readable message instead of an opaque error raised
        # from inside the transform pipeline.
        raise gr.Error("Please provide an image to analyse.")

    # Define the transformations for the image
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Image size expected by ResNet50
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )

    # Normalize is configured with three per-channel statistics, so make
    # sure the image really has three channels (RGBA, grayscale and palette
    # images would otherwise fail inside the pipeline).
    image_tensor = transform(inp.convert("RGB"))

    pred_diff, prob_diff = get_prediction_diffusion(image_tensor)
    pred_cnn, prob_cnn = get_prediction_cnn(image_tensor)

    # Either detector flagging the image is enough for a positive verdict.
    verdict = (
        "AI Generated" if (pred_diff or pred_cnn) else "No GenAI detected"
    )
    return (
        f"<h1>{verdict}</h1>"
        f"<ul>"
        f"<li>Diffusion detection score: {prob_diff:.2} "
        f"{'(MATCH)' if pred_diff else ''}</li>"
        f"<li>CNN detection score: {prob_cnn:.1%} "
        f"{'(MATCH)' if pred_cnn else ''}</li>"
        f"</ul>"
    )
# Single-image-in, HTML-out interface wired to ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.HTML(),
    title="AI-generated image detection",
    description="Demo by NICT & Tokyo Techies ",
    # NOTE(review): every row carries two values (path, caption) while the
    # interface declares a single input component -- confirm how the Gradio
    # version in use treats the second value.
    examples=[
        ["samples/fake_dalle.jpg", "Generated (Dall-E)"],
        ["samples/fake_midjourney.png", "Generated (MidJourney)"],
        ["samples/fake_stable.jpg", "Generated (Stable Diffusion)"],
        ["samples/fake_cnn.png", "Generated (GAN)"],
        ["samples/real.png", "Organic"],
    ],
)

# Start the web server as soon as the script is executed.
demo.launch()