# Required libraries:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt
# Define the CIFAR-10 datasets and data loaders
# Preprocessing applied to every CIFAR-10 image: convert to a tensor, then
# normalize each RGB channel with the ImageNet per-channel mean/std (the
# statistics matching the ImageNet-pretrained ResNet-50 backbone).
transformRes = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training split (downloaded into ./data if missing); reshuffled every epoch.
trainsetRes = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transformRes,
)
trainloaderRes64 = torch.utils.data.DataLoader(
    trainsetRes,
    batch_size=64,
    shuffle=True,
    num_workers=10,
)

# Test split; kept in a fixed order so evaluation passes are comparable.
testsetRes = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transformRes,
)
testloaderRes64 = torch.utils.data.DataLoader(
    testsetRes,
    batch_size=64,
    shuffle=False,
    num_workers=10,
)
# Model definition and training
import torchvision.models as models
# Select the compute device. The script otherwise uses `device` without ever
# defining it; fall back to the CPU when no CUDA GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ResNet-50 with ImageNet-pretrained weights from torchvision.
# NOTE(review): newer torchvision versions deprecate `pretrained=True` in favour
# of `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept here for compatibility
# with older releases.
resnet50v2 = models.resnet50(pretrained=True)

# Fine-tune the whole network: every backbone parameter stays trainable.
# (This is NOT freezing -- use requires_grad = False to train only the new head.)
for param in resnet50v2.parameters():
    param.requires_grad = True

# Replace the ImageNet classifier with an MLP head producing the 10 CIFAR-10
# class scores.
num_ftrs = resnet50v2.fc.in_features
resnet50v2.fc = nn.Sequential(
    nn.Linear(num_ftrs, 500),
    nn.ReLU(),
    nn.Linear(500, 200),
    # Without a nonlinearity here, Linear(500, 200) -> Linear(200, 40) would
    # collapse into a single linear map.
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(200, 40),
    nn.ReLU(),
    nn.Dropout(0.3),
    # Emit raw logits: nn.CrossEntropyLoss applies log-softmax itself, and a
    # trailing ReLU would clamp every negative logit to 0, crippling training.
    nn.Linear(40, 10),
)

# Move the model to the selected device.
resnet50v2 = resnet50v2.to(device, dtype=torch.float32)

optimizer = torch.optim.SGD(
    resnet50v2.parameters(), lr=0.01, momentum=0.3, weight_decay=0.0001
)
criterion = nn.CrossEntropyLoss()
# Multiply the learning rate by 0.5 every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Report the configuration read back from the optimizer/scheduler objects, so
# the log cannot drift from the settings actually in use (it previously claimed
# step=1, gamma=0.7).
print("Model Info:")
print(
    f"ResNet50,Pretrained,weight adj. LR={optimizer.defaults['lr']},"
    f"Mom={optimizer.defaults['momentum']},WD={optimizer.defaults['weight_decay']}"
)
print(f"Schedule step={scheduler.step_size},gamma={scheduler.gamma}, 20 epoches")
# Per-epoch history, collected for the loss/accuracy plots.
train_losses = []
test_losses = []
accuracies = []
train_acc = []

for epoch in range(20):  # loop over the dataset multiple times
    # ---- Training phase ----
    # Enter train mode at the start of every epoch so dropout/batch-norm are
    # active regardless of the mode the model was left in previously.
    resnet50v2.train()
    running_loss = 0.0
    correctTrain = 0
    totalTrain = 0
    # The loaders defined earlier are the batch-size-64 ones
    # (trainloaderRes64 / testloaderRes64); *Res16 names were never defined.
    pbar = tqdm(
        enumerate(trainloaderRes64, 0),
        total=len(trainloaderRes64),
        desc=f"Epoch {epoch + 1}",
    )
    for i, data in pbar:
        # data is an [inputs, labels] pair yielded by the DataLoader
        inputs = data[0].to(device, dtype=torch.float32)
        labels = data[1].to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = resnet50v2(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # running statistics for the progress bar and epoch accuracy
        running_loss += loss.item()
        predicted_train = outputs.argmax(dim=1)
        totalTrain += labels.size(0)
        correctTrain += (predicted_train == labels).sum().item()
        pbar.set_postfix({'loss': running_loss / (i + 1)})

    epoch_train_loss = running_loss / len(trainloaderRes64)
    train_accuracy = 100 * correctTrain / totalTrain
    train_acc.append(train_accuracy)
    print(f'Epoch {epoch + 1} loss: {epoch_train_loss:.3f}')

    # ---- Evaluation phase ----
    resnet50v2.eval()  # disable dropout, use running batch-norm statistics
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloaderRes64:
            images = data[0].to(device, dtype=torch.float32)
            labels = data[1].to(device)
            outputs = resnet50v2(images)
            test_loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_test_loss = test_loss / len(testloaderRes64)
    epoch_accuracy = 100 * correct / total
    print(f'Epoch {epoch + 1} Test loss: {epoch_test_loss:.3f}, Accuracy: {epoch_accuracy:.2f}%')
    # Learning rate used during this epoch (before the scheduler updates it).
    print(f'Epoch {epoch + 1} Learning rate: {optimizer.param_groups[0]["lr"]}')

    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)
    accuracies.append(epoch_accuracy)

    # Advance the StepLR schedule once per epoch.
    scheduler.step()

# The model is left in eval mode after the final evaluation pass.
print('Finished Training')
# Draw two figures from the per-epoch history: first the training and test
# losses, then the test and training accuracies (in percent). Each entry
# pairs the (series, legend label) curves with the figure's y-axis label.
for curves, y_label in (
    ([(train_losses, 'Training Loss'), (test_losses, 'Test Loss')], 'Loss'),
    ([(accuracies, 'Accuracy'), (train_acc, 'Training Accuracy')], 'Accuracy (%)'),
):
    plt.figure(figsize=(10, 5))
    for values, legend_label in curves:
        plt.plot(values, label=legend_label)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()
# Downloads last month: 0
# Unable to determine this model's library. Check the docs.