Spaces:
Sleeping
Sleeping
import torch | |
import torchvision | |
from torchvision.models import ResNet50_Weights | |
import swanlab | |
from torch.utils.data import DataLoader | |
from load_datasets import DatasetLoader | |
import os | |
def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs=None):
    """Run one training epoch over ``train_dataloader``.

    For every batch: move it to ``device``, do a forward/backward pass, take
    one optimizer step, then print the loss and log it to swanlab under
    ``"train_loss"``.

    Args:
        model: The ``torch.nn.Module`` being trained (switched to train mode).
        device: Device string or ``torch.device`` that batches are moved to.
        train_dataloader: Iterable of ``(inputs, labels)`` batches that also
            supports ``len()`` (e.g. a ``torch.utils.data.DataLoader``).
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        epoch: Index of the current epoch, used only in the progress line.
        num_epochs: Total number of epochs, shown in the progress line. When
            omitted, a module-level ``num_epochs`` is used if one exists
            (the original script's behaviour), otherwise ``"?"``.

    Returns:
        The mean per-batch training loss for this epoch, or ``float('nan')``
        if the dataloader yielded no batches.
    """
    if num_epochs is None:
        # Backward compatibility: the original read the script-level global.
        num_epochs = globals().get("num_epochs", "?")
    # Use the loader we were given, not a script global, so the function
    # works when imported or called with a different dataloader.
    num_iters = len(train_dataloader)

    model.train()
    running_loss = 0.0
    step = 0
    for step, (inputs, labels) in enumerate(train_dataloader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() copies the scalar to the host (a device sync on GPU);
        # do it once per step and reuse the value.
        loss_value = loss.item()
        running_loss += loss_value
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(
            epoch, num_epochs, step, num_iters, loss_value))
        swanlab.log({"train_loss": loss_value})

    return running_loss / step if step else float("nan")
def test(model, device, test_dataloader, epoch):
    """Evaluate ``model`` on ``test_dataloader`` and return top-1 accuracy.

    The model is put in eval mode and run under ``torch.no_grad()``. The
    predicted class is the argmax of the model output along dim 1. The
    accuracy is printed and logged to swanlab under ``"test_acc"``.

    Args:
        model: The ``torch.nn.Module`` to evaluate.
        device: Device string or ``torch.device`` that batches are moved to.
        test_dataloader: Iterable of ``(inputs, labels)`` batches.
        epoch: Current epoch index. Accepted for interface compatibility;
            not used by this function.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``test_dataloader`` yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # argmax over the class dimension (assumes outputs are
            # (batch, num_classes) logits -- TODO confirm for other models).
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_dataloader yielded no samples; cannot compute accuracy")

    accuracy = correct / total * 100
    print('Accuracy: {:.2f}%'.format(accuracy))
    swanlab.log({"test_acc": accuracy})
    # The caller assigns the result, so return it rather than None.
    return accuracy
if __name__ == "__main__":
    # Hyperparameters. num_epochs and the dataloader names below stay at
    # module level because train() may read them as globals.
    num_epochs = 20
    lr = 1e-4
    batch_size = 16
    num_classes = 2
    eval_interval = 4  # run validation every N epochs
    checkpoint_dir = "checkpoint"

    # Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
    # Older torch builds lack torch.backends.mps, hence the AttributeError guard.
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False
    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    # Initialize swanlab experiment tracking.
    swanlab.init(
        experiment_name="ResNet50",
        description="Train ResNet50 for cat and dog classification.",
        config={
            "model": "resnet50",
            "optim": "Adam",
            "lr": lr,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "num_class": num_classes,
            "device": device,
        },
    )

    TrainDataset = DatasetLoader("datasets/train.csv")
    ValDataset = DatasetLoader("datasets/val.csv")
    TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
    ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)

    # Load ImageNet-pretrained ResNet50 and replace its classification head
    # with a fresh linear layer sized for num_classes.
    model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    model.to(torch.device(device))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        train(model, device, TrainDataLoader, optimizer, criterion, epoch)
        # Validate every eval_interval epochs, and always on the final epoch
        # so the saved checkpoint has a reported accuracy.
        if epoch % eval_interval == 0 or epoch == num_epochs:
            test(model, device, ValDataLoader, epoch)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "latest_checkpoint.pth"))
    print("Training complete")