|
|
|
|
|
|
|
|
import ClassUtils |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Subset |
|
|
import random |
|
|
from torchvision import models, transforms |
|
|
from torch.utils.data import DataLoader |
|
|
import time |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
# Run on a CUDA device when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# ---- Hyperparameters --------------------------------------------------------
num_classes = 2          # output width of the replacement classifier layer
batch_size = 256         # samples per batch for both train and test loaders
epochs = 25              # passes over the training subset in train_model()
learning_rate = 5e-4     # Adam step size
train_data_size = 25000  # number of indices drawn for the training subset

# Weights file read when evaluating (train = False) and when imported as a module.
saved_state_dict_path = "MobileNetV3_test.pth"
|
|
|
|
|
# Start from torchvision's MobileNetV3-Small with its default pretrained weights.
model = models.mobilenet_v3_small(
    weights=models.MobileNet_V3_Small_Weights.DEFAULT,
)

# Replace the classifier layer at index 3 with a freshly initialised nn.Linear
# that keeps the same input width but produces `num_classes` outputs.
model.classifier[3] = nn.Linear(
    in_features=model.classifier[3].in_features,
    out_features=num_classes,
)

model = model.to(device)
|
|
|
|
|
# Project dataset; ClassUtils.CrosswalkDataset is defined outside this file.
dataset = ClassUtils.CrosswalkDataset("zebra_annotations/classification_data")

# Positional split: indices [0, split_index) form the training pool and
# [split_index, len(dataset)) the held-out pool. The split is by position,
# not shuffled.
split_index = int(len(dataset) * 0.95)
# Number of held-out samples drawn for the test loader.
test_data_size = 12

# NOTE(review): random.sample is unseeded, so both subsets differ between runs.
# Each draw is capped at its pool size; random.sample raises ValueError when
# asked for more items than the population holds.
train_loader = DataLoader(
    Subset(dataset, random.sample(range(split_index), min(train_data_size, split_index))),
    batch_size=batch_size,
    shuffle=True,
)

test_loader = DataLoader(
    Subset(
        dataset,
        random.sample(
            range(split_index, len(dataset)),
            min(test_data_size, len(dataset) - split_index),
        ),
    ),
    batch_size=batch_size,
    shuffle=False,
)
|
|
|
|
|
# Binary cross-entropy on probabilities: the training loop applies
# torch.sigmoid to the model output before calling this loss.
# NOTE(review): nn.BCEWithLogitsLoss on raw logits is the numerically stabler
# equivalent; switching requires removing the sigmoid from the training loop.
criterion = nn.BCELoss(reduction="mean")

# Adam over every model parameter (pretrained backbone included).
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model():
    """Train the module-level ``model`` on ``train_loader`` for ``epochs`` epochs.

    Each batch is moved to ``device``, passed through the model, squashed with
    ``torch.sigmoid`` and scored by ``criterion`` against the batch labels.
    ``optimizer`` then takes one step. After every epoch the mean batch loss
    and the epoch's wall-clock time are printed.

    A batch that cannot be moved to ``device`` is reported and skipped rather
    than silently ignored. Skipped batches are excluded from the mean loss. An
    epoch with no processed batches reports a loss of ``nan``.
    """
    model.train()
    start_time = time.time()

    for epoch in range(epochs):
        running_loss = 0.0
        batches_done = 0

        for inputs, labels in train_loader:
            try:
                inputs, labels = inputs.to(device), labels.to(device)
            except (RuntimeError, AttributeError) as err:
                # RuntimeError covers device/CUDA failures; AttributeError covers
                # a non-tensor batch. Anything else (e.g. Ctrl-C) now propagates.
                print(f"Skipping batch that could not be moved to {device}: {err}")
                continue

            optimizer.zero_grad()
            outputs = torch.sigmoid(model(inputs))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_done += 1

        # Average over batches actually processed, not len(train_loader).
        mean_loss = running_loss / batches_done if batches_done else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}, time {time.time()- start_time}")
        start_time = time.time()
|
|
|
|
|
|
|
|
def test_model(): |
|
|
model.eval() |
|
|
correct = 0 |
|
|
total = 0 |
|
|
with torch.no_grad(): |
|
|
for inputs, labels in test_loader: |
|
|
try: |
|
|
inputs, labels = inputs.to(device), labels.to(device) |
|
|
except: |
|
|
continue |
|
|
outputs = torch.sigmoid(model(inputs)) |
|
|
|
|
|
predicted = (outputs/100) > 0.5 |
|
|
for i in range(len(inputs)): |
|
|
plt.close() |
|
|
plt.imshow(torch.permute(inputs[i], (1, 2, 0)).cpu().detach().numpy()) |
|
|
plt.title(f"prediction of {outputs[i].tolist()[0]:.3f}%, {100 * predicted[i].tolist()[0]:.3f}%,\nactual: {labels[i].tolist()}") |
|
|
plt.axis("off") |
|
|
plt.show() |
|
|
|
|
|
total += labels.size(0) |
|
|
|
|
|
|
|
|
for prediction, label in zip(predicted, labels): |
|
|
correct += ((prediction[0]>50) == label[0]) |
|
|
|
|
|
print(f"Accuracy: {100 * correct / total}%") |
|
|
|
|
|
|
|
|
|
|
|
# True: train and save new weights. False: evaluate the weights stored at
# saved_state_dict_path.
train = True

if __name__ == "__main__":
    if train:
        train_model()
        # NOTE(review): this writes "mn3_vs55.pth", but evaluation and import both
        # read saved_state_dict_path ("MobileNetV3_test.pth"). Rename or copy the file
        # if the freshly trained weights should be picked up — confirm intent.
        torch.save(model.state_dict(), "mn3_vs55.pth")
    else:
        # map_location lets weights saved from a GPU run load on a CPU-only host.
        state_dictionairy = torch.load(saved_state_dict_path, map_location=device, weights_only=True)
        print(type(state_dictionairy))
        model.load_state_dict(state_dictionairy)

    test_model()

else:
    # Imported as a module: restore the saved weights into the module-level model.
    state_dictionairy = torch.load(saved_state_dict_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dictionairy)
    print(f"Module: [{__name__}] has been loaded")
|
|
|