|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import torchvision |
|
|
|
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import matplotlib.pyplot as plt |
|
|
import torchvision.models as models |
|
|
import torchvision.transforms as transforms |
|
|
import torchvision.datasets as datasets |
|
|
|
|
|
import time |
|
|
import copy |
|
|
import os |
|
|
|
|
|
|
|
|
# Hyperparameters for data loading / optimization.
batch_size = 128
learning_rate = 1e-3


# Pixel-to-tensor conversion only.
# NOTE(review): no Resize or Normalize is applied here, yet imshow() below
# de-normalizes with ImageNet statistics — confirm whether a Normalize
# transform was intended.
# Fix: renamed from `transforms` to `transform` so the variable no longer
# shadows the `torchvision.transforms` module imported above.
transform = transforms.Compose([transforms.ToTensor()])


# Fruits-360 is laid out as one sub-folder per class, which ImageFolder
# maps to integer labels in sorted folder order.
train_dataset = datasets.ImageFolder(
    root="/input/fruits-360-dataset/fruits-360/Training", transform=transform
)

test_dataset = datasets.ImageFolder(
    root="/input/fruits-360-dataset/fruits-360/Test", transform=transform
)


train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): shuffle=True on the test loader is harmless for aggregate
# metrics but makes evaluation order non-deterministic; shuffle=False is
# the usual choice.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Prefer the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
def imshow(inp, title=None):
    """Display a (C, H, W) image tensor with matplotlib.

    Args:
        inp: image tensor, typically the output of
            ``torchvision.utils.make_grid``.
        title: optional title (string or list) for the plot.
    """
    # Move to host memory; .cpu() is a no-op for tensors already on the CPU.
    # (Fix: the original `inp.cpu() if device else inp` always took the
    # first branch, since a torch.device object is always truthy.)
    inp = inp.cpu()
    # matplotlib expects channels-last (H, W, C).
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo ImageNet normalization.
    # NOTE(review): the data pipeline only applies ToTensor(), so this
    # de-normalization skews the displayed colors — confirm whether a
    # matching Normalize transform was intended upstream.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)

    if title is not None:
        plt.title(title)
    # Brief pause so the GUI event loop actually renders the figure.
    plt.pause(0.001)
|
|
|
|
|
|
|
|
# Pull a single batch to sanity-check tensor shapes and visualize samples.
images, labels = next(iter(train_dataloader))

# Expected shape: (batch_size, 3, H, W).
print("images-size:", images.shape)


# Tile the whole batch into one image grid for display.
out = torchvision.utils.make_grid(images)

print("out-size:", out.shape)


# Show the grid, titled with the class name of each sample in the batch.
imshow(out, title=[train_dataset.classes[x] for x in labels])
|
|
|
|
|
|
|
|
# Start from an ImageNet-pretrained ResNet-18 backbone.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in
# favor of `weights=ResNet18_Weights.DEFAULT`; kept here for compatibility.
net = models.resnet18(pretrained=True)

# Move the model to the selected device.
# (Fix: the original `net.cuda() if device else net` always called .cuda()
# because a torch.device is always truthy — crashing on CPU-only machines.)
net = net.to(device)

criterion = nn.CrossEntropyLoss()

# NOTE(review): lr=0.0001 here ignores `learning_rate = 1e-3` defined
# above — confirm which value is intended.
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
|
|
|
|
|
|
|
|
def accuracy(out, labels):
    """Count how many predictions in ``out`` match ``labels``.

    Despite the name, this returns the *number* of correct predictions
    (a Python int), not a fraction — callers divide by the batch size.

    Args:
        out: logits of shape (batch, num_classes).
        labels: integer class targets of shape (batch,).
    """
    predictions = out.argmax(dim=1)
    return (predictions == labels).sum().item()
|
|
|
|
|
|
|
|
# Replace the pretrained 1000-way ImageNet head with one sized for this
# dataset.  Fix: the head was hard-coded to 128 outputs, but fruits-360
# has more than 128 class folders, so CrossEntropyLoss would receive
# out-of-range labels; len(train_dataset.classes) always matches the data.
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, len(train_dataset.classes))
# Fix: the original referenced an undefined name `use_cuda` (NameError);
# move the new head to the same device as the rest of the model.
net.fc = net.fc.to(device)

# Fix: re-create the optimizer AFTER swapping the head.  The optimizer
# built earlier captured the *old* fc's parameters, so the new head would
# otherwise never be updated by optimizer.step().
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
|
|
|
|
|
|
|
|
|
|
|
# Fix: the training loop iterates `range(1, n_epochs + 1)`, but this value
# was originally assigned to `_epochs`, leaving `n_epochs` undefined
# (NameError at loop entry).
n_epochs = 5
# NOTE(review): defined but unused — the loop below logs every 20 batches.
print_every = 10
# Best (lowest) validation loss seen so far; used for checkpointing.
# Fix: np.Inf was removed in NumPy 2.0; np.inf is the canonical spelling.
valid_loss_min = np.inf
# Per-epoch history, useful for plotting learning curves afterwards.
val_loss = []
val_acc = []
train_loss = []
train_acc = []
# Number of batches in one training epoch.
total_step = len(train_dataloader)
|
|
|
|
|
# Train/validate for n_epochs, checkpointing whenever the summed validation
# loss improves on the best seen so far.
# NOTE(review): `n_epochs` is assigned as `_epochs` above — confirm the
# intended name; as written this line raises NameError.
for epoch in range(1, n_epochs + 1):
    # Running totals for this epoch's training statistics.
    running_loss = 0.0
    correct = 0
    total = 0

    print(f"Epoch {epoch}\n")

    for batch_idx, (data_, target_) in enumerate(train_dataloader):
        data_, target_ = data_.to(device), target_.to(device)
        # Standard step: clear grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = net(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predicted class = argmax over the logits.
        _, pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred == target_).item()
        total += target_.size(0)

        # Progress log every 20 batches.
        # NOTE(review): `print_every` above is 10 but unused here —
        # confirm which interval is intended.
        if (batch_idx) % 20 == 0:
            print(
                "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    epoch, n_epochs, batch_idx, total_step, loss.item()
                )
            )

    # Record this epoch's training accuracy (%) and mean per-batch loss.
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss / total_step)
    # NOTE(review): np.mean(train_loss) averages over ALL epochs so far,
    # not just this epoch — train_loss[-1] would report the current epoch.
    print(
        f"\ntrain-loss: {np.mean(train_loss):.4f}, train-acc: {(100 * correct/total):.4f}"
    )

    # Validation-pass accumulators.
    batch_loss = 0
    total_t = 0
    correct_t = 0

    # Disable autograd and switch to eval mode (affects BatchNorm/Dropout).
    with torch.no_grad():
        net.eval()
        for data_t, target_t in test_dataloader:
            data_t, target_t = data_t.to(device), target_t.to(device)
            outputs_t = net(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            _, pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t == target_t).item()
            total_t += target_t.size(0)

        val_acc.append(100 * correct_t / total_t)
        val_loss.append(batch_loss / len(test_dataloader))

        # Checkpoint criterion: summed validation loss beat the best so far.
        network_learned = batch_loss < valid_loss_min
        # NOTE(review): as with train-loss, np.mean(val_loss) averages over
        # all epochs rather than reporting the current one.
        print(
            f"validation loss: {np.mean(val_loss):.4f}, validation acc: {(100 * correct_t/total_t):.4f}\n"
        )

        if network_learned:
            valid_loss_min = batch_loss
            # Save weights only (state_dict), not the full module.
            torch.save(net.state_dict(), "resnet.pt")
            print("Improvement-Detected, save-model")

    # Restore training mode for the next epoch.
    net.train()
|
|
|