diallomama commited on
Commit
3f5dbfd
1 Parent(s): 7748314

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -88
app.py DELETED
@@ -1,88 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
-
5
- class CNN(nn.Module):
6
- def __init__(self):
7
- super(CNN, self).__init__()
8
-
9
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
10
- self.relu1 = nn.ReLU()
11
- self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
12
-
13
- self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
14
- self.relu2 = nn.ReLU()
15
- self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
16
-
17
- self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
18
- self.relu3 = nn.ReLU()
19
- self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
20
-
21
- self.fc1 = nn.Linear(in_features=262144, out_features=512)
22
- self.relu4 = nn.ReLU()
23
-
24
- self.fc2 = nn.Linear(in_features=512, out_features=2)
25
-
26
- def forward(self, x):
27
- x = self.conv1(x)
28
- x = self.relu1(x)
29
- x = self.pool1(x)
30
-
31
- x = self.conv2(x)
32
- x = self.relu2(x)
33
- x = self.pool2(x)
34
-
35
- x = self.conv3(x)
36
- x = self.relu3(x)
37
- x = self.pool3(x)
38
-
39
- # Flatten
40
- x = x.reshape(x.shape[0], -1) #this work
41
-
42
- x = self.fc1(x)
43
- x = self.relu4(x)
44
-
45
- x = self.fc2(x)
46
-
47
- return x
48
-
49
- # Convert dataset to PyTorch dataset
50
- class AiornotDataset(Dataset):
51
- def __init__(self, image, transform=None):
52
- self.image = image
53
- self.transform = transform
54
-
55
- def __getitem__(self, idx):
56
- # Load image
57
- #img_byte = BytesIO(self.dataset[idx]['image'].tobytes())
58
- #img = self.dataset[idx]['image']
59
- # Apply transform
60
- if self.transform:
61
- img = self.transform(img)
62
-
63
- # Load label
64
- #label = self.dataset[idx]['label']
65
-
66
- return img
67
-
68
- def predict(image, model):
69
- img = AiornotDataset(image)
70
- model = CNN()
71
- model.load_state_dict(torch.load('./best_model.nn'))
72
- model.eval()
73
-
74
- pred = model(img)
75
- is_ai = torch.max(pred.data, 0)[1]
76
- if is_ai == 1:
77
- return "The input image is generated by an AI"
78
- return "The input image is not generated by an AI"
79
-
80
-
81
-
82
-
83
- gr.Interface(
84
- predict,
85
- inputs = gr.inpust.Image(label="Uploat an image", type="filepath"),
86
- #outputs = gr.outputs.Label(num_top_classes=2)
87
- outputs = "text"
88
- ).lunch()