|
import os, sys, glob, time, random, shutil, copy |
|
from tqdm import tqdm |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torchvision |
|
from torchvision import datasets, models, transforms |
|
import torch.utils.data as data |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.optim import lr_scheduler |
|
import torch.nn.functional as F |
|
from torchsummary import summary |
|
from matplotlib import pyplot as plt |
|
from torchvision.models import resnet18, ResNet18_Weights |
|
from PIL import Image, ImageFile |
|
from skimage import io |
|
# Let PIL decode partially-written/truncated image files instead of raising OSError.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Dataset roots in torchvision ImageFolder layout (one sub-directory per class).
train_directory = 'dataset/train'

valid_directory = 'dataset/val'


# Mini-batch size used by both the training and validation loaders.
bs = 64

# Total number of training epochs.
num_epochs = 20

# Size of the classifier head; should match the number of class
# sub-directories under the dataset roots (ImageFolder infers classes from them).
num_classes = 4

# Worker processes per DataLoader.
# NOTE(review): num_workers > 0 in a script with no `if __name__ == "__main__"`
# guard breaks under the spawn start method (Windows/macOS) — confirm target platform.
num_cpu = 8
|
|
|
|
|
# ImageNet channel statistics — required by the pretrained ResNet backbone.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Per-split preprocessing. The training pipeline adds light augmentation
# (random crop, rotation, horizontal flip); validation is deterministic
# (resize + centre crop). Both end in tensor conversion + normalisation.
image_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]),
}
|
|
|
|
|
# Build one ImageFolder dataset per split; the split key doubles as the
# lookup key into image_transforms.
_split_roots = {'train': train_directory, 'valid': valid_directory}
dataset = {
    split: datasets.ImageFolder(root=root, transform=image_transforms[split])
    for split, root in _split_roots.items()
}

# Sample counts per split, used later to average the epoch metrics.
dataset_sizes = {split: len(ds) for split, ds in dataset.items()}

# Only the training loader shuffles; both pin host memory to speed up
# host-to-GPU transfers and keep incomplete final batches.
dataloaders = {
    split: data.DataLoader(
        dataset[split],
        batch_size=bs,
        shuffle=(split == 'train'),
        num_workers=num_cpu,
        pin_memory=True,
        drop_last=False,
    )
    for split in ('train', 'valid')
}
|
|
|
|
|
# Class labels as discovered by ImageFolder (alphabetical sub-directory names).
class_names = dataset['train'].classes
print("Classes:", class_names)

# Report how many samples each split contains.
print(
    "Training-set size:", dataset_sizes['train'],
    "\nValidation-set size:", dataset_sizes['valid'],
)
|
|
|
# Architecture identifier (bookkeeping only; not used to select the model).
modelname = 'resnet18'

# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
|
|
|
# Select the pretrained ImageNet weights and load ResNet-18 with them.
# Fix: `weights` must actually be passed to the constructor — calling
# resnet18(weights=None) would silently train from random initialisation,
# defeating the transfer-learning setup (the ImageNet Normalize stats above
# only make sense with a pretrained backbone).
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)

# Swap the 1000-way ImageNet head for a fresh `num_classes`-way linear layer.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Move parameters/buffers to the target device before the optimizer is built.
model = model.to(device)

# Dump every named parameter with its trainable flag, then a per-layer summary.
print('Model Summary:-\n')
for num, (name, param) in enumerate(model.named_parameters()):
    print(num, name, param.requires_grad)
summary(model, input_size=(3, 224, 224))
|
|
|
|
|
# Cross-entropy objective over the raw class logits.
criterion = nn.CrossEntropyLoss()


# Plain SGD with momentum over all model parameters (the whole network is
# fine-tuned; nothing is frozen).
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# Decay the learning rate by 10x every 7 epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
|
|
|
# ---------------------------------------------------------------------------
# Training loop: each epoch runs a full 'train' pass followed by a full
# 'valid' pass, accumulates running loss/accuracy, and keeps a deep copy of
# the weights with the best validation accuracy seen so far.
# ---------------------------------------------------------------------------
since = time.time()

# Best-so-far snapshot, initialised to the starting weights.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(1, num_epochs+1):
    print('Epoch {}/{}'.format(epoch, num_epochs))
    print('-' * 10)

    for phase in ['train', 'valid']:
        if phase == 'train':
            model.train()  # training mode: dropout active, batch-norm updates running stats
        else:
            model.eval()  # eval mode: deterministic layers, frozen batch-norm stats

        running_loss = 0.0
        running_corrects = 0

        # n = samples processed so far in this phase; used for the running
        # averages shown on the progress bar.
        n = 0
        stream = tqdm(dataloaders[phase])
        for i, (inputs, labels) in enumerate(stream, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Also called in the valid phase; harmless there since no
            # backward pass runs.
            optimizer.zero_grad()

            # Autograd graph is only built during training; validation runs
            # gradient-free.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # Accumulate sums weighted by the actual batch size so the epoch
            # averages stay exact even when the last batch is smaller.
            n += inputs.shape[0]
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            stream.set_description(f'Batch {i}/{len(dataloaders[phase])} | Loss: {running_loss/n:.4f}, Acc: {running_corrects/n:.4f}')

        # The LR schedule advances once per epoch, after the training phase.
        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, phase, epoch_loss, epoch_acc))

        # '>=' means ties favour the most recent epoch's weights.
        if phase == 'valid' and epoch_acc >= best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            print('Update best model!')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
# NOTE(review): '{:4f}' (width 4) is probably meant to be '{:.4f}' (4 decimals).
print('Best val Acc: {:4f}'.format(best_acc))
|
|
|
|
|
# Restore the best-validation-accuracy weights before exporting.
model.load_state_dict(best_model_wts)

# Fix: torch.save does not create missing directories — ensure 'logs/' exists
# so the script doesn't crash with FileNotFoundError after a full training run.
os.makedirs('logs', exist_ok=True)

# Export both the fully pickled module (.pth) and the more portable
# state_dict (.tar naming kept for backward compatibility with consumers).
torch.save(model, 'logs/resnet18_4class.pth')
torch.save(model.state_dict(), 'logs/resnet18_4class.tar')
|
|