Upload 4 files
Browse files
index.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
from PIL import Image, ImageFile
|
4 |
+
import pandas as pd
|
5 |
+
import os
|
6 |
+
import math
|
7 |
+
from model import ConvolutionalNet
|
8 |
+
from collections import Counter
|
9 |
+
from vector_dict import vector_dict
|
10 |
+
|
11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
12 |
+
|
13 |
+
model = ConvolutionalNet()
|
14 |
+
model.load_state_dict(torch.load('model.pt'))
|
15 |
+
|
16 |
+
transform = transforms.Compose([
|
17 |
+
transforms.ToTensor(),
|
18 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
19 |
+
transforms.Resize((256, 256))
|
20 |
+
])
|
21 |
+
|
22 |
+
def get_prediction(path):
|
23 |
+
img = Image.open(path)
|
24 |
+
with torch.no_grad():
|
25 |
+
pred = model(transform(img))
|
26 |
+
return vector_dict[torch.max(pred, 1)[1].item()]
|
27 |
+
|
28 |
+
print(get_prediction('data/test/Afghanistan/39841.png'))
|
model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:db39f0c877ce3ec72e5a1930bcd738c611867ae83aa3dc454090f7de4b843037
|
3 |
+
size 123517583
|
train.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torchvision import transforms, datasets, models
|
6 |
+
from torchvision.utils import make_grid
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
from PIL import ImageFile
|
10 |
+
import math
|
11 |
+
from model import ConvolutionalNet
|
12 |
+
|
13 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
train_transforms = transforms.Compose([
|
20 |
+
transforms.RandomRotation(10),
|
21 |
+
transforms.RandomHorizontalFlip(),
|
22 |
+
transforms.ToTensor(),
|
23 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
24 |
+
])
|
25 |
+
|
26 |
+
test_transform = transforms.Compose([
|
27 |
+
transforms.ToTensor(),
|
28 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
29 |
+
])
|
30 |
+
|
31 |
+
train_dataset = datasets.ImageFolder(root='./data/train', transform=train_transforms)
|
32 |
+
test_dataset = datasets.ImageFolder(root='./data/test', transform=test_transform)
|
33 |
+
|
34 |
+
torch.manual_seed(42)
|
35 |
+
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
|
36 |
+
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
|
37 |
+
|
38 |
+
class_names = train_dataset.classes
|
39 |
+
|
40 |
+
for images, labels in train_loader:
|
41 |
+
break
|
42 |
+
|
43 |
+
torch.manual_seed(101)
|
44 |
+
model = ConvolutionalNet()
|
45 |
+
criterion = nn.CrossEntropyLoss()
|
46 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
47 |
+
|
48 |
+
start_time = time.time()
|
49 |
+
epochs = 5
|
50 |
+
|
51 |
+
# BATCH LIMITS
|
52 |
+
max_trn_batch = 800
|
53 |
+
max_tst_batch = 300
|
54 |
+
|
55 |
+
train_losses = []
|
56 |
+
test_losses = []
|
57 |
+
train_correct = []
|
58 |
+
test_correct = []
|
59 |
+
|
60 |
+
for epoch in range(epochs):
|
61 |
+
trn_corr = 0
|
62 |
+
tst_corr = 0
|
63 |
+
|
64 |
+
for b, (X_train, y_train) in enumerate(train_loader):
|
65 |
+
# if b == max_trn_batch:
|
66 |
+
# break
|
67 |
+
|
68 |
+
y_pred = model(X_train)
|
69 |
+
loss = criterion(y_pred, y_train)
|
70 |
+
|
71 |
+
if b % 200 == 0:
|
72 |
+
print(f"Epoch: {epoch+1}/{epochs}\tBatch: {b+1}\tLoss: {loss.item()}")
|
73 |
+
|
74 |
+
predicted = torch.max(y_pred, 1)[1]
|
75 |
+
batch_corr = (predicted == y_train).sum()
|
76 |
+
trn_corr += batch_corr
|
77 |
+
|
78 |
+
optimizer.zero_grad()
|
79 |
+
loss.backward()
|
80 |
+
optimizer.step()
|
81 |
+
|
82 |
+
train_losses.append(loss)
|
83 |
+
train_correct.append(trn_corr)
|
84 |
+
|
85 |
+
# TEST
|
86 |
+
with torch.no_grad():
|
87 |
+
for b, (X_test, y_test) in enumerate(test_loader):
|
88 |
+
# if b == max_tst_batch:
|
89 |
+
# break
|
90 |
+
|
91 |
+
try:
|
92 |
+
y_pred = model(X_test)
|
93 |
+
except:
|
94 |
+
print("Error testing images")
|
95 |
+
continue
|
96 |
+
loss = criterion(y_pred, y_test)
|
97 |
+
|
98 |
+
predicted = torch.max(y_pred, 1)[1]
|
99 |
+
batch_corr = (predicted == y_test).sum()
|
100 |
+
tst_corr += batch_corr
|
101 |
+
|
102 |
+
test_losses.append(loss)
|
103 |
+
test_correct.append(tst_corr)
|
104 |
+
|
105 |
+
end_time = time.time()
|
106 |
+
total_time = end_time - start_time
|
107 |
+
print(f"Time taken: minutes: {math.floor(total_time / 60)} seconds: {math.floor(total_time % 60)}")
|
108 |
+
|
109 |
+
torch.save(model.state_dict(), 'model.pt')
|
110 |
+
|
111 |
+
plt.plot([x.detach().numpy() for x in train_losses], label='train loss')
|
112 |
+
plt.plot(test_losses, label='test loss')
|
113 |
+
plt.legend()
|
114 |
+
plt.plot()
|
115 |
+
|
116 |
+
plt.plot([t/80 for t in train_correct], label='train accuracy')
|
117 |
+
plt.plot([t/30 for t in test_correct], label='test accuracy')
|
118 |
+
plt.legend()
|
119 |
+
plt.plot()
|
120 |
+
|
121 |
+
print(f'Accuracy: {100*test_correct[-1].item()/1000}')
|
vector_dict.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
vector_dict = {0: 'Afghanistan', 1: 'Albania', 2: 'Algeria', 3: 'American Samoa', 4: 'Andorra', 5: 'Angola', 6: 'Anguilla', 7: 'Antarctica', 8: 'Antigua and Barbuda', 9: 'Argentina', 10: 'Armenia', 11: 'Aruba', 12: 'Australia', 13: 'Austria', 14: 'Azerbaijan', 15: 'Bahamas', 16: 'Bahrain', 17: 'Bangladesh', 18: 'Barbados', 19: 'Belarus', 20: 'Belgium', 21: 'Belize', 22: 'Benin', 23: 'Bermuda', 24: 'Bhutan', 25: 'Bolivia', 26: 'Bosnia and Herzegovina', 27: 'Botswana', 28: 'Bouvet Island', 29: 'Brazil', 30: 'British Indian Ocean Territory', 31: 'British Virgin Islands', 32: 'Brunei', 33: 'Bulgaria', 34: 'Burkina Faso', 35: 'Burundi', 36: 'Cambodia', 37: 'Cameroon', 38: 'Canada', 39: 'Cape Verde', 40: 'Caribbean Netherlands', 41: 'Cayman Islands', 42: 'Central African Republic', 43: 'Chad', 44: 'Chile', 45: 'China', 46: 'Christmas Island', 47: 'Cocos (Keeling) Islands', 48: 'Colombia', 49: 'Comoros', 50: 'Cook Islands', 51: 'Costa Rica', 52: 'Croatia', 53: 'Cuba', 54: 'Curaçao', 55: 'Cyprus', 56: 'Czechia', 57: 'DR Congo', 58: 'Denmark', 59: 'Djibouti', 60: 'Dominica', 61: 'Dominican Republic', 62: 'Ecuador', 63: 'Egypt', 64: 'El Salvador', 65: 'Equatorial Guinea', 66: 'Eritrea', 67: 'Estonia', 68: 'Eswatini', 69: 'Ethiopia', 70: 'Falkland Islands', 71: 'Faroe Islands', 72: 'Fiji', 73: 'Finland', 74: 'France', 75: 'French Guiana', 76: 'French Polynesia', 77: 'French Southern and Antarctic Lands', 78: 'Gabon', 79: 'Gambia', 80: 'Georgia', 81: 'Germany', 82: 'Ghana', 83: 'Gibraltar', 84: 'Greece', 85: 'Greenland', 86: 'Grenada', 87: 'Guadeloupe', 88: 'Guam', 89: 'Guatemala', 90: 'Guernsey', 91: 'Guinea', 92: 'Guinea-Bissau', 93: 'Guyana', 94: 'Haiti', 95: 'Heard Island and McDonald Islands', 96: 'Honduras', 97: 'Hong Kong', 98: 'Hungary', 99: 'Iceland', 100: 'India', 101: 'Indonesia', 102: 'Iran', 103: 'Iraq', 104: 'Ireland', 105: 'Isle of Man', 106: 'Israel', 107: 'Italy', 108: 'Ivory Coast', 109: 'Jamaica', 110: 'Japan', 111: 'Jersey', 112: 'Jordan', 113: 'Kazakhstan', 114: 'Kenya', 115: 'Kiribati', 116: 'Kosovo', 117: 'Kuwait', 118: 'Kyrgyzstan', 119: 'Laos', 120: 'Latvia', 121: 'Lebanon', 122: 'Lesotho', 123: 'Liberia', 124: 'Libya', 125: 'Liechtenstein', 126: 'Lithuania', 127: 'Luxembourg', 128: 'Macau', 129: 'Madagascar', 130: 'Malawi', 131: 'Malaysia', 132: 'Maldives', 133: 'Mali', 134: 'Malta', 135: 'Marshall Islands', 136: 'Martinique', 137: 'Mauritania', 138: 'Mauritius', 139: 'Mayotte', 140: 'Mexico', 141: 'Micronesia', 142: 'Moldova', 143: 'Monaco', 144: 'Mongolia', 145: 'Montenegro', 146: 'Montserrat', 147: 'Morocco', 148: 'Mozambique', 149: 'Myanmar', 150: 'Namibia', 151: 'Nauru', 152: 'Nepal', 153: 'Netherlands', 154: 'New Caledonia', 155: 'New Zealand', 156: 'Nicaragua', 157: 'Niger', 158: 'Nigeria', 159: 'Niue', 160: 'Norfolk Island', 161: 'North Korea', 162: 'North Macedonia', 163: 'Northern Mariana Islands', 164: 'Norway', 165: 'Oman', 166: 'Pakistan', 167: 'Palau', 168: 'Palestine', 169: 'Panama', 170: 'Papua New Guinea', 171: 'Paraguay', 172: 'Peru', 173: 'Philippines', 174: 'Pitcairn Islands', 175: 'Poland', 176: 'Portugal', 177: 'Puerto Rico', 178: 'Qatar', 179: 'Republic of the Congo', 180: 'Romania', 181: 'Russia', 182: 'Rwanda', 183: 'Réunion', 184: 'Saint Barthélemy', 185: 'Saint Helena Ascension and Tristan da Cunha', 186: 'Saint Kitts and Nevis', 187: 'Saint Lucia', 188: 'Saint Martin', 189: 'Saint Pierre and Miquelon', 190: 'Saint Vincent and the Grenadines', 191: 'Samoa', 192: 'San Marino', 193: 'Saudi Arabia', 194: 'Senegal', 195: 'Serbia', 196: 'Seychelles', 197: 'Sierra Leone', 198: 'Singapore', 199: 'Sint Maarten', 200: 'Slovakia', 201: 'Slovenia', 202: 'Solomon Islands', 203: 'Somalia', 204: 'South Africa', 205: 'South Georgia', 206: 'South Korea', 207: 'South Sudan', 208: 'Spain', 209: 'Sri Lanka', 210: 'Sudan', 211: 'Suriname', 212: 'Svalbard and Jan Mayen', 213: 'Sweden', 214: 'Switzerland', 215: 'Syria', 216: 'São Tomé and Príncipe', 217: 'Taiwan', 218: 'Tajikistan', 219: 'Tanzania', 220: 'Thailand', 221: 'Timor-Leste', 222: 'Togo', 223: 'Tokelau', 224: 'Tonga', 225: 'Trinidad and Tobago', 226: 'Tunisia', 227: 'Turkey', 228: 'Turkmenistan', 229: 'Turks and Caicos Islands', 230: 'Tuvalu', 231: 'Uganda', 232: 'Ukraine', 233: 'United Arab Emirates', 234: 'United Kingdom', 235: 'United States', 236: 'United States Minor Outlying Islands', 237: 'United States Virgin Islands', 238: 'Uruguay', 239: 'Uzbekistan', 240: 'Vanuatu', 241: 'Vatican City', 242: 'Venezuela', 243: 'Vietnam', 244: 'Wallis and Futuna', 245: 'Western Sahara', 246: 'Yemen', 247: 'Zambia', 248: 'Zimbabwe', 249: 'Åland Islands'}
|