#!/usr/bin/env python
# coding: utf-8
# In[29]:
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision import transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Compatibility shim: newer pandas versions removed DataFrame.iteritems, so
# alias it to DataFrame.items for code that still expects the old name.
pd.DataFrame.iteritems = pd.DataFrame.items
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gradio as gr
# In[11]:
# Ensemble output classes: images generated by Copilot, DreamStudio, or Gemini,
# plus real (non-generated) images.
classes = ['Fake_Copilot', 'Fake_DreamStudio', 'Fake_Gemini', 'Real']
# In[16]:
d_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/dense.pth'
g_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/google.pth'
r_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/resnet.pth'
v_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/vgg13.pth'
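# Hedged sketch (not part of the original script): if the .pth checkpoints are
# stored alongside this file (e.g. in the same repository), paths built relative
# to the script are more portable than the absolute paths above. The filenames
# below mirror the originals and are assumptions.
#
#   import os
#   BASE_DIR = os.path.dirname(os.path.abspath(__file__))
#   d_path = os.path.join(BASE_DIR, 'dense.pth')
#   g_path = os.path.join(BASE_DIR, 'google.pth')
#   r_path = os.path.join(BASE_DIR, 'resnet.pth')
#   v_path = os.path.join(BASE_DIR, 'vgg13.pth')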
# In[17]:
# Each backbone is instantiated without pretrained weights, its final layer is
# replaced with a 4-class head, and the fine-tuned weights are loaded.
# map_location='cpu' keeps loading portable on machines without a GPU.
dense_net = models.densenet161()
dense_net.classifier = nn.Linear(2208, len(classes), bias=True)
dense_net.load_state_dict(torch.load(d_path, map_location='cpu'))
# In[18]:
googlenet = models.googlenet()
googlenet.fc = nn.Linear(1024, len(classes), bias=True)
googlenet.load_state_dict(torch.load(g_path, map_location='cpu'))
# In[19]:
vgg13 = models.vgg13()
vgg13.classifier[6] = nn.Linear(4096, len(classes), bias=True)
vgg13.load_state_dict(torch.load(v_path, map_location='cpu'))
# In[20]:
res_net = models.resnet101()
res_net.fc = nn.Linear(2048, len(classes), bias=True)
res_net.load_state_dict(torch.load(r_path, map_location='cpu'))
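# In[ ]:
# Hedged sketch (optional, not in the original): inference below runs on the CPU.
# If a GPU is available, the whole ensemble could be moved to it, provided the
# input tensor inside one_prediction() is moved to the same device as well.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   for m in (dense_net, googlenet, vgg13, res_net):
#       m.to(device)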
# In[24]:
# Preprocessing assumed to match training: resize to 224x224, convert to a
# tensor, and normalize with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# In[27]:
def one_prediction(img):
    """Majority-vote ensemble: each model casts one vote for its predicted class."""
    preds = {classname: 0 for classname in classes}
    img = transform(img)
    img.unsqueeze_(0)  # add a batch dimension: (C, H, W) -> (1, C, H, W)
    model_list = [dense_net, googlenet, vgg13, res_net]
    with torch.no_grad():
        for model in model_list:
            model.eval()
            output = model(img)
            _, predicted = torch.max(output.data, 1)
            preds[classes[predicted.item()]] += 1
    # Convert vote counts into fractions of the ensemble.
    for classname, count in preds.items():
        preds[classname] = float(count) / len(model_list)
    return preds
# In[28]:
#path = 'C:/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/One/24June Batch (80).png'
#img = Image.open(path).convert('RGB')
#one_prediction(img)
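# In[ ]:
# Hedged sketch (assumption, not in the original): a self-contained smoke test
# using a synthetic solid-gray image, so the ensemble can be exercised without
# any local files. The expected output is a dict mapping each class name to the
# fraction of the four models that voted for it (multiples of 0.25).
#
#   sample = Image.new('RGB', (224, 224), color=(128, 128, 128))
#   print(one_prediction(sample))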
# In[30]:
title = "Real vs Fake Image Classification"
description = "Test."
article = "Test"
examples = [['C:/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/One/24June Batch (80).png'],
            ['C:/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/One/antarctica_0231.png']]
demo = gr.Interface(fn=one_prediction,
                    inputs=gr.Image(type="pil"),
                    outputs=gr.Label(num_top_classes=4, label="Predictions"),
                    examples=examples,
                    title=title,
                    description=description,
                    article=article)
demo.launch(debug=False,
            share=True)
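# Hedged note (assumption): when this app is hosted on Hugging Face Spaces, the
# Space already serves a public URL and share=True is typically unnecessary;
# for a purely local run, demo.launch() with no arguments is sufficient.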
# In[ ]: