jannatulferdaws
commited on
Commit
•
a74b1a9
1
Parent(s):
c3c087d
code .py file
Browse files- notebook_code.py +149 -0
notebook_code.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# In[2]:
|
5 |
+
|
6 |
+
|
7 |
+
# import pytorch and machine learning stuff
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.optim as optim
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
# import other stuff
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
# In[3]:
|
19 |
+
|
20 |
+
|
21 |
+
# import resnet from torch
|
22 |
+
import torch.library
|
23 |
+
from torchvision.models import squeezenet1_1
|
24 |
+
from torchvision.models import resnet50
|
25 |
+
from torchvision.models import resnet18
|
26 |
+
from torchvision.models import mobilenet_v2
|
27 |
+
from torchvision import transforms
|
28 |
+
from torchvision.datasets import ImageFolder
|
29 |
+
from torch.utils.data import DataLoader
|
30 |
+
|
31 |
+
|
32 |
+
# In[4]:
|
33 |
+
|
34 |
+
|
35 |
+
class_num = 5
|
36 |
+
classes = ['Ak', 'Ala_Idris', 'Buzgulu', 'Dimnit', 'Nazli']
|
37 |
+
|
38 |
+
|
39 |
+
# In[5]:
|
40 |
+
|
41 |
+
|
42 |
+
model = mobilenet_v2(pretrained=True)
|
43 |
+
|
44 |
+
|
45 |
+
# In[9]:
|
46 |
+
|
47 |
+
|
48 |
+
print(model)
|
49 |
+
|
50 |
+
|
51 |
+
# In[7]:
|
52 |
+
|
53 |
+
|
54 |
+
transform = transforms.Compose([
|
55 |
+
transforms.RandomResizedCrop(224),
|
56 |
+
transforms.RandomHorizontalFlip(),
|
57 |
+
transforms.ToTensor(),
|
58 |
+
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
|
59 |
+
])
|
60 |
+
|
61 |
+
|
62 |
+
# In[8]:
|
63 |
+
|
64 |
+
|
65 |
+
training_set = ImageFolder('../data/train', transform=transform)
|
66 |
+
test_set = ImageFolder('../data/test', transform=transform)
|
67 |
+
val_set = ImageFolder('../data/val', transform=transform)
|
68 |
+
|
69 |
+
|
70 |
+
# In[47]:
|
71 |
+
|
72 |
+
|
73 |
+
batch_size = 8
|
74 |
+
epochs = 5
|
75 |
+
lr = 1e-5
|
76 |
+
loss_fn = nn.CrossEntropyLoss()
|
77 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
78 |
+
|
79 |
+
|
80 |
+
# In[48]:
|
81 |
+
|
82 |
+
|
83 |
+
train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
84 |
+
test_loader = DataLoader(test_set, batch_size=batch_size)
|
85 |
+
val_loader = DataLoader(val_set, batch_size=batch_size)
|
86 |
+
|
87 |
+
|
88 |
+
# In[49]:
|
89 |
+
|
90 |
+
|
91 |
+
model.classifier[1] = nn.Linear(in_features=1280, out_features=class_num)
|
92 |
+
|
93 |
+
|
94 |
+
# In[52]:
|
95 |
+
|
96 |
+
|
97 |
+
epochs = 1
|
98 |
+
|
99 |
+
|
100 |
+
# In[55]:
|
101 |
+
|
102 |
+
|
103 |
+
# train the model
|
104 |
+
for epoch in range(epochs):
|
105 |
+
model.train()
|
106 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
107 |
+
optimizer.zero_grad()
|
108 |
+
output = model(data)
|
109 |
+
print("Out: ", [a.argmax().item() for a in output])
|
110 |
+
print("Target: ", target)
|
111 |
+
loss = loss_fn(output, target)
|
112 |
+
loss.backward()
|
113 |
+
optimizer.step()
|
114 |
+
if batch_idx % 10 == 0:
|
115 |
+
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
116 |
+
epoch, batch_idx * len(data), len(train_loader.dataset),
|
117 |
+
100. * batch_idx / len(train_loader), loss.item()
|
118 |
+
))
|
119 |
+
|
120 |
+
# test the model
|
121 |
+
model.eval()
|
122 |
+
test_loss = 0
|
123 |
+
correct = 0
|
124 |
+
with torch.no_grad():
|
125 |
+
for data, target in test_loader:
|
126 |
+
output = model(data)
|
127 |
+
test_loss += loss_fn(output, target).item()
|
128 |
+
pred = output.argmax(dim=1, keepdim=True)
|
129 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
130 |
+
|
131 |
+
test_loss /= len(test_loader.dataset)
|
132 |
+
print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
133 |
+
test_loss, correct, len(test_loader.dataset),
|
134 |
+
100. * correct / len(test_loader.dataset)
|
135 |
+
))
|
136 |
+
|
137 |
+
|
138 |
+
# In[51]:
|
139 |
+
|
140 |
+
|
141 |
+
model_scripted = torch.jit.script(model)
|
142 |
+
model_scripted.save('../models/mobilenet.pt')
|
143 |
+
|
144 |
+
|
145 |
+
# In[ ]:
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|