Robb49 commited on
Commit
11b818b
1 Parent(s): 9f69f60

Upload Image_Origin_Classification.py

Browse files
Files changed (1) hide show
  1. Image_Origin_Classification.py +133 -0
Image_Origin_Classification.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[29]:
5
+
6
+
7
+ from PIL import Image
8
+ import torchvision.transforms.functional as TF
9
+ from torchvision import transforms
10
+ import torchvision.models as models
11
+ from torchvision.datasets import ImageFolder
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.data import DataLoader, random_split
14
+ import torchvision
15
+ import pandas as pd
16
+ import numpy as np
17
+ import matplotlib.pyplot as plt
18
+ pd.DataFrame.iteritems = pd.DataFrame.items
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.optim as optim
23
+ import gradio as gr
24
+
25
+
26
+ # In[11]:
27
+
28
+
29
+ classes = ['Fake_Copilot', 'Fake_DreamStudio', 'Fake_Gemini', 'Real']
30
+
31
+
32
+ # In[16]:
33
+
34
+
35
+ d_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/dense.pth'
36
+ g_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/google.pth'
37
+ r_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/resnet.pth'
38
+ v_path = '/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/vgg13.pth'
39
+
40
+
41
+ # In[17]:
42
+
43
+
44
+ dense_net = models.densenet161()
45
+ dense_net.classifier = nn.Linear(2208, len(classes), bias = True)
46
+ dense_net.load_state_dict(torch.load(d_path))
47
+
48
+
49
+ # In[18]:
50
+
51
+
52
+ googlenet = models.googlenet()
53
+ googlenet.fc = nn.Linear(1024, len(classes), bias = True)
54
+ googlenet.load_state_dict(torch.load(g_path))
55
+
56
+
57
+ # In[19]:
58
+
59
+
60
+ vgg13 = models.vgg13()
61
+ vgg13.classifier[6] = nn.Linear(4096, len(classes), bias = True)
62
+ vgg13.load_state_dict(torch.load(v_path))
63
+
64
+
65
+ # In[20]:
66
+
67
+
68
+ res_net = models.resnet101()
69
+ res_net.fc = nn.Linear(2048, len(classes), bias = True)
70
+ res_net.load_state_dict(torch.load(r_path))
71
+
72
+
73
+ # In[24]:
74
+
75
+
76
+ transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
77
+
78
+
79
+ # In[27]:
80
+
81
+
82
+ def one_prediction(img):
83
+ preds = {classname: 0 for classname in classes}
84
+ #img = Image.open(path).convert('RGB')
85
+ img = transform(img)
86
+ img.unsqueeze_(0)
87
+ models = [dense_net, googlenet, vgg13, res_net]
88
+ #dense_net.eval()
89
+ with torch.no_grad():
90
+ for model in models:
91
+ model.eval()
92
+ output = model(img)
93
+ _, predicted = torch.max(output.data, 1)
94
+ preds[classes[predicted]] += 1
95
+ for classname, count in preds.items():
96
+ chance = float(count) / len(classes)
97
+ preds[classname] = chance
98
+ return preds
99
+
100
+
101
+ # In[28]:
102
+
103
+
104
+ #path = 'C:/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/One/24June Batch (80).png'
105
+ #img = Image.open(path).convert('RGB')
106
+ #one_prediction(img)
107
+
108
+
109
+ # In[30]:
110
+
111
+
112
+ title = "Real vs Fake Image Classification"
113
+ description = "Test."
114
+ article = "Test"
115
+ examples = [['C:/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/One/24June Batch (80).png'],
116
+ ['C:/Users/robb4/OneDrive/Desktop/DATA SCIENCE MPS/Summer 24/Deep Learning/Final Project/One/antarctica_0231.png']]
117
+
118
+ demo = gr.Interface(fn=one_prediction,
119
+ inputs=gr.Image(type="pil"),
120
+ outputs=gr.Label(num_top_classes=4, label="Predictions"),
121
+ examples=examples,
122
+ title=title,
123
+ description=description,
124
+ article=article)
125
+ demo.launch(debug=False,
126
+ share=True)
127
+
128
+
129
+ # In[ ]:
130
+
131
+
132
+
133
+