Rbcloud commited on
Commit
a189a79
1 Parent(s): fd36140

Upload 4 files

Browse files
Files changed (4) hide show
  1. index.py +28 -0
  2. model.pt +3 -0
  3. train.py +121 -0
  4. vector_dict.py +1 -0
index.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image, ImageFile
4
+ import pandas as pd
5
+ import os
6
+ import math
7
+ from model import ConvolutionalNet
8
+ from collections import Counter
9
+ from vector_dict import vector_dict
10
+
11
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
12
+
13
+ model = ConvolutionalNet()
14
+ model.load_state_dict(torch.load('model.pt'))
15
+
16
+ transform = transforms.Compose([
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
19
+ transforms.Resize((256, 256))
20
+ ])
21
+
22
+ def get_prediction(path):
23
+ img = Image.open(path)
24
+ with torch.no_grad():
25
+ pred = model(transform(img))
26
+ return vector_dict[torch.max(pred, 1)[1].item()]
27
+
28
+ print(get_prediction('data/test/Afghanistan/39841.png'))
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db39f0c877ce3ec72e5a1930bcd738c611867ae83aa3dc454090f7de4b843037
3
+ size 123517583
train.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms, datasets, models
6
+ from torchvision.utils import make_grid
7
+ import os
8
+ import time
9
+ from PIL import ImageFile
10
+ import math
11
+ from model import ConvolutionalNet
12
+
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ import matplotlib.pyplot as plt
18
+
19
+ train_transforms = transforms.Compose([
20
+ transforms.RandomRotation(10),
21
+ transforms.RandomHorizontalFlip(),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24
+ ])
25
+
26
+ test_transform = transforms.Compose([
27
+ transforms.ToTensor(),
28
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
29
+ ])
30
+
31
+ train_dataset = datasets.ImageFolder(root='./data/train', transform=train_transforms)
32
+ test_dataset = datasets.ImageFolder(root='./data/test', transform=test_transform)
33
+
34
+ torch.manual_seed(42)
35
+ train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
36
+ test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
37
+
38
+ class_names = train_dataset.classes
39
+
40
+ for images, labels in train_loader:
41
+ break
42
+
43
+ torch.manual_seed(101)
44
+ model = ConvolutionalNet()
45
+ criterion = nn.CrossEntropyLoss()
46
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
47
+
48
+ start_time = time.time()
49
+ epochs = 5
50
+
51
+ # BATCH LIMITS
52
+ max_trn_batch = 800
53
+ max_tst_batch = 300
54
+
55
+ train_losses = []
56
+ test_losses = []
57
+ train_correct = []
58
+ test_correct = []
59
+
60
+ for epoch in range(epochs):
61
+ trn_corr = 0
62
+ tst_corr = 0
63
+
64
+ for b, (X_train, y_train) in enumerate(train_loader):
65
+ # if b == max_trn_batch:
66
+ # break
67
+
68
+ y_pred = model(X_train)
69
+ loss = criterion(y_pred, y_train)
70
+
71
+ if b % 200 == 0:
72
+ print(f"Epoch: {epoch+1}/{epochs}\tBatch: {b+1}\tLoss: {loss.item()}")
73
+
74
+ predicted = torch.max(y_pred, 1)[1]
75
+ batch_corr = (predicted == y_train).sum()
76
+ trn_corr += batch_corr
77
+
78
+ optimizer.zero_grad()
79
+ loss.backward()
80
+ optimizer.step()
81
+
82
+ train_losses.append(loss)
83
+ train_correct.append(trn_corr)
84
+
85
+ # TEST
86
+ with torch.no_grad():
87
+ for b, (X_test, y_test) in enumerate(test_loader):
88
+ # if b == max_tst_batch:
89
+ # break
90
+
91
+ try:
92
+ y_pred = model(X_test)
93
+ except:
94
+ print("Error testing images")
95
+ continue
96
+ loss = criterion(y_pred, y_test)
97
+
98
+ predicted = torch.max(y_pred, 1)[1]
99
+ batch_corr = (predicted == y_test).sum()
100
+ tst_corr += batch_corr
101
+
102
+ test_losses.append(loss)
103
+ test_correct.append(tst_corr)
104
+
105
+ end_time = time.time()
106
+ total_time = end_time - start_time
107
+ print(f"Time taken: minutes: {math.floor(total_time / 60)} seconds: {math.floor(total_time % 60)}")
108
+
109
+ torch.save(model.state_dict(), 'model.pt')
110
+
111
+ plt.plot([x.detach().numpy() for x in train_losses], label='train loss')
112
+ plt.plot(test_losses, label='test loss')
113
+ plt.legend()
114
+ plt.plot()
115
+
116
+ plt.plot([t/80 for t in train_correct], label='train accuracy')
117
+ plt.plot([t/30 for t in test_correct], label='test accuracy')
118
+ plt.legend()
119
+ plt.plot()
120
+
121
+ print(f'Accuracy: {100*test_correct[-1].item()/1000}')
vector_dict.py ADDED
@@ -0,0 +1 @@
 
 
1
+ vector_dict = {0: 'Afghanistan', 1: 'Albania', 2: 'Algeria', 3: 'American Samoa', 4: 'Andorra', 5: 'Angola', 6: 'Anguilla', 7: 'Antarctica', 8: 'Antigua and Barbuda', 9: 'Argentina', 10: 'Armenia', 11: 'Aruba', 12: 'Australia', 13: 'Austria', 14: 'Azerbaijan', 15: 'Bahamas', 16: 'Bahrain', 17: 'Bangladesh', 18: 'Barbados', 19: 'Belarus', 20: 'Belgium', 21: 'Belize', 22: 'Benin', 23: 'Bermuda', 24: 'Bhutan', 25: 'Bolivia', 26: 'Bosnia and Herzegovina', 27: 'Botswana', 28: 'Bouvet Island', 29: 'Brazil', 30: 'British Indian Ocean Territory', 31: 'British Virgin Islands', 32: 'Brunei', 33: 'Bulgaria', 34: 'Burkina Faso', 35: 'Burundi', 36: 'Cambodia', 37: 'Cameroon', 38: 'Canada', 39: 'Cape Verde', 40: 'Caribbean Netherlands', 41: 'Cayman Islands', 42: 'Central African Republic', 43: 'Chad', 44: 'Chile', 45: 'China', 46: 'Christmas Island', 47: 'Cocos (Keeling) Islands', 48: 'Colombia', 49: 'Comoros', 50: 'Cook Islands', 51: 'Costa Rica', 52: 'Croatia', 53: 'Cuba', 54: 'Curaçao', 55: 'Cyprus', 56: 'Czechia', 57: 'DR Congo', 58: 'Denmark', 59: 'Djibouti', 60: 'Dominica', 61: 'Dominican Republic', 62: 'Ecuador', 63: 'Egypt', 64: 'El Salvador', 65: 'Equatorial Guinea', 66: 'Eritrea', 67: 'Estonia', 68: 'Eswatini', 69: 'Ethiopia', 70: 'Falkland Islands', 71: 'Faroe Islands', 72: 'Fiji', 73: 'Finland', 74: 'France', 75: 'French Guiana', 76: 'French Polynesia', 77: 'French Southern and Antarctic Lands', 78: 'Gabon', 79: 'Gambia', 80: 'Georgia', 81: 'Germany', 82: 'Ghana', 83: 'Gibraltar', 84: 'Greece', 85: 'Greenland', 86: 'Grenada', 87: 'Guadeloupe', 88: 'Guam', 89: 'Guatemala', 90: 'Guernsey', 91: 'Guinea', 92: 'Guinea-Bissau', 93: 'Guyana', 94: 'Haiti', 95: 'Heard Island and McDonald Islands', 96: 'Honduras', 97: 'Hong Kong', 98: 'Hungary', 99: 'Iceland', 100: 'India', 101: 'Indonesia', 102: 'Iran', 103: 'Iraq', 104: 'Ireland', 105: 'Isle of Man', 106: 'Israel', 107: 'Italy', 108: 'Ivory Coast', 109: 'Jamaica', 110: 'Japan', 111: 'Jersey', 112: 'Jordan', 113: 'Kazakhstan', 114: 'Kenya', 115: 'Kiribati', 116: 'Kosovo', 117: 'Kuwait', 118: 'Kyrgyzstan', 119: 'Laos', 120: 'Latvia', 121: 'Lebanon', 122: 'Lesotho', 123: 'Liberia', 124: 'Libya', 125: 'Liechtenstein', 126: 'Lithuania', 127: 'Luxembourg', 128: 'Macau', 129: 'Madagascar', 130: 'Malawi', 131: 'Malaysia', 132: 'Maldives', 133: 'Mali', 134: 'Malta', 135: 'Marshall Islands', 136: 'Martinique', 137: 'Mauritania', 138: 'Mauritius', 139: 'Mayotte', 140: 'Mexico', 141: 'Micronesia', 142: 'Moldova', 143: 'Monaco', 144: 'Mongolia', 145: 'Montenegro', 146: 'Montserrat', 147: 'Morocco', 148: 'Mozambique', 149: 'Myanmar', 150: 'Namibia', 151: 'Nauru', 152: 'Nepal', 153: 'Netherlands', 154: 'New Caledonia', 155: 'New Zealand', 156: 'Nicaragua', 157: 'Niger', 158: 'Nigeria', 159: 'Niue', 160: 'Norfolk Island', 161: 'North Korea', 162: 'North Macedonia', 163: 'Northern Mariana Islands', 164: 'Norway', 165: 'Oman', 166: 'Pakistan', 167: 'Palau', 168: 'Palestine', 169: 'Panama', 170: 'Papua New Guinea', 171: 'Paraguay', 172: 'Peru', 173: 'Philippines', 174: 'Pitcairn Islands', 175: 'Poland', 176: 'Portugal', 177: 'Puerto Rico', 178: 'Qatar', 179: 'Republic of the Congo', 180: 'Romania', 181: 'Russia', 182: 'Rwanda', 183: 'Réunion', 184: 'Saint Barthélemy', 185: 'Saint Helena Ascension and Tristan da Cunha', 186: 'Saint Kitts and Nevis', 187: 'Saint Lucia', 188: 'Saint Martin', 189: 'Saint Pierre and Miquelon', 190: 'Saint Vincent and the Grenadines', 191: 'Samoa', 192: 'San Marino', 193: 'Saudi Arabia', 194: 'Senegal', 195: 'Serbia', 196: 'Seychelles', 197: 'Sierra Leone', 198: 'Singapore', 199: 'Sint Maarten', 200: 'Slovakia', 201: 'Slovenia', 202: 'Solomon Islands', 203: 'Somalia', 204: 'South Africa', 205: 'South Georgia', 206: 'South Korea', 207: 'South Sudan', 208: 'Spain', 209: 'Sri Lanka', 210: 'Sudan', 211: 'Suriname', 212: 'Svalbard and Jan Mayen', 213: 'Sweden', 214: 'Switzerland', 215: 'Syria', 216: 'São Tomé and Príncipe', 217: 'Taiwan', 218: 'Tajikistan', 219: 'Tanzania', 220: 'Thailand', 221: 'Timor-Leste', 222: 'Togo', 223: 'Tokelau', 224: 'Tonga', 225: 'Trinidad and Tobago', 226: 'Tunisia', 227: 'Turkey', 228: 'Turkmenistan', 229: 'Turks and Caicos Islands', 230: 'Tuvalu', 231: 'Uganda', 232: 'Ukraine', 233: 'United Arab Emirates', 234: 'United Kingdom', 235: 'United States', 236: 'United States Minor Outlying Islands', 237: 'United States Virgin Islands', 238: 'Uruguay', 239: 'Uzbekistan', 240: 'Vanuatu', 241: 'Vatican City', 242: 'Venezuela', 243: 'Vietnam', 244: 'Wallis and Futuna', 245: 'Western Sahara', 246: 'Yemen', 247: 'Zambia', 248: 'Zimbabwe', 249: 'Åland Islands'}