# Spaces: Sleeping
# (status banner captured from the hosting page; not part of the program)
from PIL import Image | |
import torchvision.transforms.functional as TF | |
from torchvision import transforms | |
import torchvision.models as models | |
from torchvision.datasets import ImageFolder | |
from torch.utils.data import DataLoader | |
from torch.utils.data import DataLoader, random_split | |
import torchvision | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Compatibility shim: make DataFrame.iteritems an alias of DataFrame.items.
# NOTE(review): presumably needed because some dependency still calls
# iteritems, which pandas 2.0 removed -- confirm which caller needs it.
setattr(pd.DataFrame, "iteritems", pd.DataFrame.items)
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import gradio as gr | |
# Class labels; output index i of every classifier maps to classes[i].
classes = [
    'Fake_Copilot',
    'Fake_DreamStudio',
    'Fake_Gemini',
    'Real',
]

# State-dict checkpoint files for DenseNet161, GoogLeNet, ResNet101, VGG-13
# and the custom CNN, respectively.
d_path, g_path, r_path, v_path, cust_path = (
    'dense.pth',
    'google.pth',
    'resnet.pth',
    'vgg13.pth',
    'cust.pth',
)
def _restore_weights(net, ckpt_path):
    """Load the state dict saved at ``ckpt_path`` into ``net``, mapped to the CPU.

    ``load_state_dict`` runs in its default strict mode, so a key or shape
    mismatch raises. Returns ``net``.
    """
    net.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')))
    return net


# Each torchvision backbone is built with its default constructor. Its final
# layer is swapped for a len(classes)-way head whose input width matches the
# replaced layer. The saved state dict is then loaded on the CPU.
dense_net = models.densenet161()
dense_net.classifier = nn.Linear(2208, len(classes), bias=True)
_restore_weights(dense_net, d_path)

googlenet = models.googlenet()
googlenet.fc = nn.Linear(1024, len(classes), bias=True)
_restore_weights(googlenet, g_path)

vgg13 = models.vgg13()
vgg13.classifier[6] = nn.Linear(4096, len(classes), bias=True)
_restore_weights(vgg13, v_path)

res_net = models.resnet101()
res_net.fc = nn.Linear(2048, len(classes), bias=True)
_restore_weights(res_net, r_path)
class CNN(nn.Module):
    """Custom convolutional classifier used as the fifth ensemble member.

    The network has twelve unpadded 3x3 convolutions with SELU activations
    and four 2x2 max-pools, followed by a fully connected head. With a
    3-channel 224x224 input, the spatial size shrinks
    224 -> 109 -> 50 -> 22 -> 9. That gives 64 * 9 * 9 = 5184 flattened
    features, which is the width the first Linear layer expects, so other
    input sizes will fail at that layer.

    The order of modules inside ``self.cnn`` must not change. Saved state
    dicts address parameters by position (e.g. ``cnn.29.weight``), so
    inserting or removing a module breaks loading existing checkpoints.

    Args:
        num_classes: Width of the final layer. Defaults to 4, the size the
            original model was built with.
    """

    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.cnn = nn.Sequential(
            # Stage 1: 224 -> 218 (three 3x3 convs) -> 109 after pooling.
            nn.Conv2d(3, 16, 3),
            nn.SELU(),
            nn.Conv2d(16, 16, 3),
            nn.SELU(),
            nn.Conv2d(16, 32, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # Stage 2: 109 -> 101 (four 3x3 convs) -> 50 after pooling.
            nn.Conv2d(32, 32, 3),
            nn.SELU(),
            nn.Conv2d(32, 64, 3),
            nn.SELU(),
            nn.Conv2d(64, 64, 3),
            nn.SELU(),
            nn.Conv2d(64, 128, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # Stage 3: 50 -> 44 (three 3x3 convs) -> 22 after pooling.
            nn.Conv2d(128, 128, 3),
            nn.SELU(),
            nn.Conv2d(128, 256, 3),
            nn.SELU(),
            nn.Conv2d(256, 256, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # Stage 4: 22 -> 18 (two 3x3 convs) -> 9 after pooling.
            nn.Conv2d(256, 128, 3),
            nn.SELU(),
            nn.Conv2d(128, 64, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # Head: 64 * 9 * 9 = 5184 features.
            # NOTE: there is no activation between these Linear layers.
            # Dropout is the identity in eval mode, so at inference the head
            # is a single affine map. It is kept this way to match the trained
            # checkpoint layout.
            nn.Flatten(),
            nn.Linear(5184, 3200),
            nn.Dropout(p=0.2),
            nn.Linear(3200, 1000),
            nn.Dropout(p=0.2),
            nn.Linear(1000, num_classes),
        )

    def forward(self, x):
        """Return raw, un-normalised class scores of shape (batch, num_classes)."""
        return self.cnn(x)
# Fifth ensemble member: the custom CNN, with its saved state dict
# deserialised onto the CPU before being copied into a fresh instance.
_cust_state = torch.load(cust_path, map_location='cpu')
model_5 = CNN()
model_5.load_state_dict(_cust_state)
del _cust_state  # don't keep a second copy of the weights alive at module level
# Preprocessing shared by every model. The image is first resized to a fixed
# 224x224 and converted to a tensor. Each of the 3 channels is then
# normalised with the given mean/std (the usual ImageNet statistics).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Per-model, per-class counts used to weight each model's vote.
# Entry i corresponds to the i-th model in one_prediction's list:
# DenseNet161, GoogLeNet, VGG-13, ResNet101, custom CNN.
# For a predicted class c, model i's vote weight is
# correct_pred_dict_list[i][c] / total_pred_dict_list[i][c].
correct_pred_dict_list = [
    {'Fake_Copilot': 80, 'Fake_DreamStudio': 75, 'Fake_Gemini': 104, 'Real': 402},  # DenseNet161
    {'Fake_Copilot': 78, 'Fake_DreamStudio': 87, 'Fake_Gemini': 112, 'Real': 398},  # GoogLeNet
    {'Fake_Copilot': 93, 'Fake_DreamStudio': 92, 'Fake_Gemini': 95, 'Real': 456},   # VGG-13
    {'Fake_Copilot': 89, 'Fake_DreamStudio': 72, 'Fake_Gemini': 108, 'Real': 388},  # ResNet101
    {'Fake_Copilot': 48, 'Fake_DreamStudio': 89, 'Fake_Gemini': 102, 'Real': 418},  # custom CNN
]

# The denominators are identical for every model. They are presumably the
# per-class sizes of a shared evaluation set (confirm). One independent dict
# is built per model.
total_pred_dict_list = [
    {'Fake_Copilot': 124, 'Fake_DreamStudio': 137, 'Fake_Gemini': 119, 'Real': 464}
    for _ in correct_pred_dict_list
]
# Retained for backward compatibility only. Earlier versions of
# one_prediction() wrote every result into this shared module-level dict.
# It is no longer updated.
chance_dict = {}


def one_prediction(img):
    """Classify one image with a weighted vote of the five loaded models.

    Each model votes for its arg-max class ``c``. The vote is weighted by
    ``correct_pred_dict_list[i][c] / total_pred_dict_list[i][c]`` for model
    ``i`` (presumably its per-class accuracy on an evaluation set; confirm).
    The weighted votes are normalised so they sum to 1.

    Args:
        img: A PIL image. The Gradio input component uses ``type="pil"``.
            Images not in RGB mode are converted to RGB first, because the
            preprocessing normalises exactly 3 channels.

    Returns:
        A new dict mapping every name in ``classes`` to its share of the
        weighted vote, rounded to 3 decimals. A fresh dict is returned on
        each call, so earlier results are never mutated.

    Raises:
        ValueError: If ``img`` is None (e.g. the form was submitted empty).
    """
    if img is None:
        raise ValueError("No image was provided.")
    if isinstance(img, Image.Image) and img.mode != "RGB":
        # ToTensor keeps the image's channel count (e.g. 4 for RGBA), which
        # would not match the 3-channel Normalize in `transform`.
        img = img.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a leading batch dimension of size 1

    # Named `ensemble` so it does not shadow the torchvision `models` module.
    ensemble = [dense_net, googlenet, vgg13, res_net, model_5]
    votes = {classname: 0.0 for classname in classes}
    with torch.no_grad():
        for i, model in enumerate(ensemble):
            model.eval()
            predicted = classes[model(batch).argmax(dim=1).item()]
            votes[predicted] += (
                correct_pred_dict_list[i][predicted] / total_pred_dict_list[i][predicted]
            )

    total = sum(votes.values())  # computed once, not per class
    if total == 0:
        # Every vote carried zero weight; avoid dividing by zero.
        return {classname: 0.0 for classname in classes}
    return {classname: round(weight / total, 3) for classname, weight in votes.items()}
# ---- Gradio user interface -------------------------------------------------
title = "Authentic vs Fake Image Classification"
description = (
    "The purpose of this classifier is to identify if an image is real or fake, "
    "and if it is fake, to determine which generator created the image "
    "(Copilot, Gemini, or Dream Studio)"
)
article = (
    "To determine the classification of the image, we utilize five models "
    "(DenseNet161, GoogleNet, VGG-13, ResNet101, and a custom created model)"
)

# One example image per class (judging by the file names). Each inner list
# holds the value for the interface's single input.
examples = [
    [example_file]
    for example_file in (
        'real_example.png',
        'copilot_example.png',
        'dream_example.png',
        'gemini_example.jpg',
    )
]

demo = gr.Interface(
    fn=one_prediction,
    inputs=gr.Image(type="pil"),  # the handler receives a PIL image
    outputs=gr.Label(num_top_classes=4, label="Predictions"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(share=True, debug=False)