| import os | |
| import torch | |
| import sys | |
| from torch import nn | |
| import torchvision | |
| from datasets import load_dataset | |
| from torch.utils.data import DataLoader | |
| from model import MiniVisionV2 | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
# --- Run configuration --------------------------------------------------
save_path = "minivisionv2_model"   # directory for per-epoch checkpoints
batchsize = 256                    # samples per optimization step
learningrate = 1e-2                # initial SGD learning rate
epoch = 50                         # total number of training epochs

# Make sure the checkpoint directory exists before training starts.
if not os.path.exists(save_path):
    os.mkdir(save_path)

# TensorBoard event files are written under ./minivisionv2_logs
writer = SummaryWriter("minivisionv2_logs")

# MNIST from the Hugging Face hub: a DatasetDict with "train"/"test" splits.
dataset = load_dataset("ylecun/mnist")

# Light augmentation for training; evaluation only converts PIL -> tensor.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(28, padding=2),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ToTensor(),
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
def transforms_train(data):
    """Attach augmented image tensors to a batched column dict.

    Used via `Dataset.with_transform`, so `data` maps column names to
    lists; a new "tensor" column is added next to the existing ones.
    """
    data["tensor"] = list(map(transform_train, data["image"]))
    return data
def transforms_test(data):
    """Attach plain (un-augmented) image tensors to a batched column dict.

    Used via `Dataset.with_transform`, so `data` maps column names to
    lists; a new "tensor" column is added next to the existing ones.
    """
    data["tensor"] = list(map(transform_test, data["image"]))
    return data
# Attach the per-split transforms lazily: with_transform runs on access,
# so training augmentation is re-sampled every epoch rather than baked in.
train_dataset = dataset["train"].with_transform(transforms_train)
test_dataset = dataset["test"].with_transform(transforms_test)
def collate_fn(batch):
    """Collate a list of example dicts into batched tensors.

    Returns a dict with "tensor" (images stacked along a new batch dim)
    and "label" (1-D tensor of integer class ids).
    """
    images = [example["tensor"] for example in batch]
    labels = [example["label"] for example in batch]
    return {"tensor": torch.stack(images), "label": torch.tensor(labels)}
# Keyword arguments make the bare positional booleans self-documenting:
# shuffle the training split every epoch, keep the test split in order.
train_loader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batchsize, shuffle=False, collate_fn=collate_fn)
# Model, loss, optimizer, and learning-rate schedule.
minivisionv2 = MiniVisionV2()
loss_fn = nn.CrossEntropyLoss()
# Positional magic numbers made explicit: 0.8 is the SGD momentum.
optimizer = torch.optim.SGD(minivisionv2.parameters(), lr=learningrate, momentum=0.8)
# Halve the learning rate every 3 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
for i in range(epoch):
    # Single quotes inside the f-string: re-using the outer double quotes
    # is a SyntaxError on Python < 3.12 (PEP 701 only relaxed this in 3.12).
    print(f"=============== Epoch {i} Start | LR: {optimizer.param_groups[0]['lr']} ===============")

    # ---- Training pass ----
    minivisionv2.train()
    total_train_loss = 0.0
    for data in tqdm(train_loader, file=sys.stdout):
        optimizer.zero_grad()
        imgs = data["tensor"]
        labels = data["label"]
        output = minivisionv2(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    total_avg_train_loss = total_train_loss / len(train_loader)
    print(f"Train loss: {total_avg_train_loss}")
    writer.add_scalar("Train Loss", total_avg_train_loss, i)

    # ---- Evaluation pass ----
    minivisionv2.eval()
    with torch.no_grad():
        total_accuracy = 0
        total_test_loss = 0.0
        for data in tqdm(test_loader, file=sys.stdout):
            imgs = data["tensor"]
            labels = data["label"]
            output = minivisionv2(imgs)
            loss = loss_fn(output, labels)
            # .item() so we accumulate a Python float instead of a 0-d
            # tensor — keeps the printed/logged value consistent with the
            # train loop (the original printed a tensor repr here).
            total_test_loss += loss.item()
            total_accuracy += (output.argmax(1) == labels).sum().item()
        total_avg_test_loss = total_test_loss / len(test_loader)
        total_accuracy_percentage = round(float(total_accuracy / len(test_dataset) * 100), 2)
        print(f"Test loss: {total_avg_test_loss}")
        print(f"Test Accuracy Percentage: {total_accuracy_percentage}%")
        writer.add_scalar("Test Loss", total_avg_test_loss, i)
        writer.add_scalar("Test Accuracy Percentage", total_accuracy_percentage, i)

    # NOTE(review): torch.save on the full module pickles the class by
    # reference, so loading requires the original `model` module on the
    # path; saving `minivisionv2.state_dict()` would be more portable.
    # Kept as-is so existing loading code keeps working.
    torch.save(minivisionv2, f"./{save_path}/Mini-Vision-V2-Epoch-{i}.pth")
    print("Model Saved!")
    scheduler.step()

writer.close()