Robb49's picture
Update app.py
94777e9 verified
raw
history blame contribute delete
No virus
5.28 kB
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision import transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Compatibility shim: make DataFrame.iteritems an alias of DataFrame.items.
# NOTE(review): nothing visible in this file calls iteritems; this presumably
# exists for a dependency (e.g. an older gradio) that still calls the method
# removed in newer pandas releases — confirm which one before removing.
pd.DataFrame.iteritems = pd.DataFrame.items
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gradio as gr
# Output labels. Every model's logit at index i is interpreted as classes[i].
classes = [
    "Fake_Copilot",
    "Fake_DreamStudio",
    "Fake_Gemini",
    "Real",
]

# State-dict checkpoint for each ensemble member (relative paths, resolved
# against the current working directory).
d_path = "dense.pth"      # DenseNet-161
g_path = "google.pth"     # GoogLeNet
r_path = "resnet.pth"     # ResNet-101
v_path = "vgg13.pth"      # VGG-13
cust_path = "cust.pth"    # custom CNN
def _load_weights(model, path):
    """Load the state dict saved at *path* into *model* and return it in eval mode.

    Tensors are mapped to CPU, so no GPU is required.

    NOTE(review): torch.load unpickles the file, so only load checkpoints you
    trust (these are the app's own files). On PyTorch versions that support
    it, passing weights_only=True is the safer choice.
    """
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference-only model: turn off dropout and use running batch-norm
    # statistics from the moment it is loaded.
    model.eval()
    return model


# torchvision architectures whose final layer is replaced by a
# len(classes)-way head so the shapes match the fine-tuned checkpoints.
dense_net = models.densenet161()
dense_net.classifier = nn.Linear(2208, len(classes), bias=True)
_load_weights(dense_net, d_path)

googlenet = models.googlenet()
googlenet.fc = nn.Linear(1024, len(classes), bias=True)
_load_weights(googlenet, g_path)

vgg13 = models.vgg13()
vgg13.classifier[6] = nn.Linear(4096, len(classes), bias=True)
_load_weights(vgg13, v_path)

res_net = models.resnet101()
res_net.fc = nn.Linear(2048, len(classes), bias=True)
_load_weights(res_net, r_path)
class CNN(nn.Module):
    """Custom convolutional classifier used as the fifth ensemble member.

    Twelve unpadded 3x3 convolutions with SELU activations and four 2x2
    max-pools reduce a 3x224x224 input to 64x9x9 = 5184 features. A
    fully connected head then maps these to ``num_classes`` logits. Other
    input sizes will not match the first Linear layer.

    The layer order inside ``self.cnn`` must stay as-is: the saved
    checkpoint's state-dict keys (``cnn.<index>.weight``/``.bias``) are
    positional.

    Args:
        num_classes: size of the output layer. Defaults to 4, the size the
            shipped checkpoint is loaded with.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.cnn = nn.Sequential(
            # 3x224x224 -> 32x218x218 -> pool -> 32x109x109
            nn.Conv2d(3, 16, 3),
            nn.SELU(),
            nn.Conv2d(16, 16, 3),
            nn.SELU(),
            nn.Conv2d(16, 32, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # -> 128x101x101 -> pool -> 128x50x50
            nn.Conv2d(32, 32, 3),
            nn.SELU(),
            nn.Conv2d(32, 64, 3),
            nn.SELU(),
            nn.Conv2d(64, 64, 3),
            nn.SELU(),
            nn.Conv2d(64, 128, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # -> 256x44x44 -> pool -> 256x22x22
            nn.Conv2d(128, 128, 3),
            nn.SELU(),
            nn.Conv2d(128, 256, 3),
            nn.SELU(),
            nn.Conv2d(256, 256, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            # -> 64x18x18 -> pool -> 64x9x9
            nn.Conv2d(256, 128, 3),
            nn.SELU(),
            nn.Conv2d(128, 64, 3),
            nn.SELU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            # NOTE: there is no activation between these Linear layers, so
            # (apart from dropout) they compose to one affine map. This is kept
            # unchanged so the trained weights still load and behave the same.
            nn.Linear(5184, 3200),
            nn.Dropout(p=0.2),
            nn.Linear(3200, 1000),
            nn.Dropout(p=0.2),
            nn.Linear(1000, num_classes),
        )

    def forward(self, x):
        """Return the logits for a batch of images *x*."""
        return self.cnn(x)
# Fifth ensemble member: the custom CNN, loaded from its checkpoint onto CPU.
model_5 = CNN()
model_5.load_state_dict(torch.load(cust_path, map_location=torch.device('cpu')))
# This model is only used for inference. Disable its dropout layers at load
# time instead of relying on every caller to call eval() first.
model_5.eval()
# Preprocessing shared by every ensemble member.
# - Resize to 224x224; the custom CNN's 5184-feature head requires this size.
# - Convert to a float tensor.
# - Normalise each channel (presumably the standard ImageNet mean/std).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Per-model, per-class counts used to weight each model's vote in
# one_prediction. Model i's vote for class c counts
# correct_pred_dict_list[i][c] / total_pred_dict_list[i][c].
# Presumably these are held-out evaluation results (correct / total samples
# per class) — TODO confirm.
# Row order must match the model list in one_prediction:
# DenseNet-161, GoogLeNet, VGG-13, ResNet-101, custom CNN.
correct_pred_dict_list = [
    dict(Fake_Copilot=80, Fake_DreamStudio=75, Fake_Gemini=104, Real=402),  # DenseNet-161
    dict(Fake_Copilot=78, Fake_DreamStudio=87, Fake_Gemini=112, Real=398),  # GoogLeNet
    dict(Fake_Copilot=93, Fake_DreamStudio=92, Fake_Gemini=95, Real=456),   # VGG-13
    dict(Fake_Copilot=89, Fake_DreamStudio=72, Fake_Gemini=108, Real=388),  # ResNet-101
    dict(Fake_Copilot=48, Fake_DreamStudio=89, Fake_Gemini=102, Real=418),  # custom CNN
]

# Every model was scored against the same per-class totals. Build one
# independent dict per model.
total_pred_dict_list = [
    dict(Fake_Copilot=124, Fake_DreamStudio=137, Fake_Gemini=119, Real=464)
    for _model_index in range(5)
]
# Kept only for backward compatibility. one_prediction used to write its result
# into this shared module-level dict, so concurrent requests overwrote each
# other. It now returns a fresh dict per call and no longer touches this one.
chance_dict = {}


def one_prediction(img):
    """Classify *img* with the five-model ensemble and return per-class scores.

    Each model casts one vote for its top-scoring class. The vote is weighted
    by that model's ratio correct_pred_dict_list[i][c] / total_pred_dict_list[i][c]
    for the chosen class c. The weighted votes are then normalised to sum to 1.

    Args:
        img: a PIL image, as supplied by gr.Image(type="pil"). Non-RGB PIL
            images (e.g. RGBA PNGs, greyscale) are converted to RGB so the
            3-channel Normalize in ``transform`` applies.

    Returns:
        A new dict mapping every name in ``classes`` to its score rounded to 3
        decimals, in ``classes`` order, which is the format gr.Label accepts.
        All scores are 0.0 if every vote had zero weight.

    Raises:
        ValueError: if *img* is None (e.g. submitted without an image).
    """
    if img is None:
        raise ValueError("No image provided")
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    # Add a batch dimension of size one.
    batch = transform(img).unsqueeze(0)

    # Order must match the rows of correct_pred_dict_list / total_pred_dict_list.
    ensemble = [dense_net, googlenet, vgg13, res_net, model_5]

    votes = {classname: 0.0 for classname in classes}
    with torch.no_grad():
        for model, correct, total in zip(
            ensemble, correct_pred_dict_list, total_pred_dict_list
        ):
            model.eval()
            label = classes[model(batch).argmax(dim=1).item()]
            votes[label] += correct[label] / total[label]

    vote_sum = sum(votes.values())
    if vote_sum == 0:
        return {classname: 0.0 for classname in classes}
    return {
        classname: round(weight / vote_sum, 3)
        for classname, weight in votes.items()
    }
# Text and sample inputs shown by the Gradio interface.
title = "Authentic vs Fake Image Classification"
description = (
    "The purpose of this classifier is to identify if an image is real or "
    "fake, and if it is fake, to determine which generator created the image "
    "(Copilot, Gemini, or Dream Studio)"
)
article = (
    "To determine the classification of the image, we utilize five models "
    "(DenseNet161, GoogleNet, VGG-13, ResNet101, and a custom created model)"
)
# One inner list per example row (a single image input each). Paths are relative.
examples = [
    ["real_example.png"],
    ["copilot_example.png"],
    ["dream_example.png"],
    ["gemini_example.jpg"],
]
# Gradio UI: one PIL image in, the ensemble's per-class scores out.
demo = gr.Interface(
    fn=one_prediction,
    inputs=gr.Image(type="pil"),
    # Show every class score (previously hard-coded as 4).
    outputs=gr.Label(num_top_classes=len(classes), label="Predictions"),
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Start the server only when run as a script (e.g. `python app.py`). This lets
# the module (and `demo`) be imported without launching a server.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link. It is presumably
    # unnecessary if this is hosted on Hugging Face Spaces — confirm.
    demo.launch(debug=False, share=True)