| import torch.optim as optim |
| import torch |
| import torch.nn as nn |
| import torch.nn.parallel |
| import torch.optim |
| import torch.utils.data |
| import torch.utils.data.distributed |
| import torchvision.transforms as transforms |
| import torchvision.datasets as datasets |
| import torchvision.models |
| |
| from torch.autograd import Variable |
|
|
| |
# --- Hyperparameters ---
modellr = 1e-4   # initial Adam learning rate (see adjust_learning_rate)
BATCH_SIZE = 64  # samples per mini-batch for both loaders
EPOCHS = 20      # total training epochs

# Prefer the GPU when CUDA is available, otherwise fall back to CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Best-checkpoint bookkeeping, updated during validation.
best_accuracy = 0
best_epoch = 0
|
|
|
|
| |
# Per-channel normalization shared by both pipelines: maps [0, 1] tensors
# to roughly [-1, 1] via (x - 0.5) / 0.5.
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]

# Training pipeline: resize to the ResNet input size, convert to tensor,
# then normalize. (No augmentation is applied.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation pipeline: identical preprocessing, kept separate so the two
# can diverge later (e.g. if augmentation is added to training).
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
|
|
|
|
|
|
|
|
|
|
# ImageFolder infers class labels from the sub-directory names under each root.
dataset_train = datasets.ImageFolder(root='datasets/datasets/train', transform=transform)
dataset_test = datasets.ImageFolder(root='datasets/datasets/val', transform=transform_test)

# Debug output: every (path, class-index) pair plus the name -> index mapping.
print(dataset_train.imgs)
print(dataset_train.class_to_idx)

# Shuffle only the training data; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
|
|
| |
# Loss for the 2-class classification task.
criterion = nn.CrossEntropyLoss()

# Start from an ImageNet-pretrained ResNet-18 and replace its classifier
# head with a fresh 2-way linear layer.
# NOTE(review): `pretrained=True` is the legacy torchvision flag; newer
# releases prefer `weights=` — confirm the installed version before changing.
model = torchvision.models.resnet18(pretrained=True)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 2)
model.to(DEVICE)

# Adam over all parameters (backbone is fine-tuned, not frozen).
optimizer = optim.Adam(model.parameters(), lr=modellr)
|
|
| |
def adjust_learning_rate(optimizer, epoch, base_lr=None):
    """Set every param group's LR to the base LR decayed by 10x every 50 epochs.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current 1-based epoch number.
        base_lr: starting learning rate; defaults to the module-level
            ``modellr`` when not given (backward compatible with the old
            two-argument call).

    Note: the original docstring claimed a decay every 30 epochs, but the
    code decays every 50 (``epoch // 50``); with EPOCHS = 20 the first decay
    step is never reached, so the LR stays at ``base_lr`` for the whole run.
    """
    if base_lr is None:
        base_lr = modellr
    modellrnew = base_lr * (0.1 ** (epoch // 50))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew
|
|
|
|
| |
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print progress plus the average batch loss.

    Args:
        model: network being trained (already moved to ``device`` by the caller).
        device: torch.device each batch is moved to.
        train_loader: DataLoader yielding (images, labels) batches.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch number, used only for logging.

    Uses the module-level ``criterion`` for the loss.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # Variable is a deprecated no-op wrapper since PyTorch 0.4;
        # tensors go straight to the device.
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() extracts the Python scalar; avoids the deprecated .data attribute.
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    # Average of per-batch losses (not sample-weighted: the last, possibly
    # smaller, batch counts the same as a full one).
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
|
|
|
|
| |
|
|
| |
|
|
def val(model, device, test_loader, epoch):
    """Evaluate on the validation set and checkpoint when accuracy improves.

    Args:
        model: network to evaluate (already on ``device``).
        device: torch.device each batch is moved to.
        test_loader: DataLoader yielding (images, labels) batches.
        epoch: current epoch number, recorded when a new best is found.

    Side effects: updates the module-level ``best_accuracy``/``best_epoch``
    and saves the full model to disk whenever this epoch beats the best
    accuracy so far. Uses the module-level ``criterion`` for the loss.
    """
    global best_accuracy, best_epoch
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            # Variable is a deprecated no-op since PyTorch 0.4.
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output, 1)
            # Accumulate as a plain Python int instead of a tensor.
            correct += (pred == target).sum().item()
            test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))

    if acc > best_accuracy:
        # (Removed a stray `best_accuracy, best_epoch` expression statement
        # here — it evaluated the tuple and discarded it, a no-op.)
        best_accuracy = acc
        best_epoch = epoch
        # NOTE(review): this pickles the entire module object; saving
        # model.state_dict() is the more portable convention — confirm
        # before changing the checkpoint format, since loaders depend on it.
        torch.save(model, 'model_resnet18_epoch20_lr0.0001_best_epoch.pth')
|
|
|
|
|
|
| |
# Training driver: apply the LR schedule, run one training epoch, then
# validate (which also checkpoints the best model seen so far).
for current_epoch in range(1, EPOCHS + 1):
    adjust_learning_rate(optimizer, current_epoch)
    train(model, DEVICE, train_loader, optimizer, current_epoch)
    val(model, DEVICE, test_loader, current_epoch)

# Final summary of the best validation result across the run.
print(f"Best model achieved at epoch {best_epoch} with accuracy: {best_accuracy * 100:.2f}%")
|
|
|
|
|
|
|
|