{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Dataset({\n", " features: ['image'],\n", " num_rows: 34945\n", "})\n", "Dataset({\n", " features: ['image'],\n", " num_rows: 48837\n", "})\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset_pal = load_dataset(\"imagefolder\", data_dir=\"../data/filtered/pal\", drop_labels=True, split=\"train\")\n", "print(dataset_pal)\n", "dataset_pokemon = load_dataset(\"imagefolder\", data_dir=\"../data/filtered/pokemon\", drop_labels=True, split=\"train\")\n", "print(dataset_pokemon)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "dataset_pal = dataset_pal.map(lambda example: {'label': 0})\n", "dataset_pokemon = dataset_pokemon.map(lambda example: {'label': 1})" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from datasets import concatenate_datasets\n", "\n", "dataset = concatenate_datasets([dataset_pal, dataset_pokemon])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "\n", "compose = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.Lambda(lambda x: x.convert(\"RGB\")),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", "])\n", "transformed = dataset.map(lambda example: {\"image\": compose(example[\"image\"])}, batched=False)\n", "transformed.set_format(\"torch\")\n", "train_test_dataset = transformed.train_test_split(test_size=0.2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "batch_size = 128\n", "train_dataloader = DataLoader(train_test_dataset[\"train\"], batch_size=batch_size, shuffle=True)\n", "test_dataloader = DataLoader(train_test_dataset[\"test\"], batch_size=batch_size, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "label = [\"pal\", \"pokemon\"]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import wandb\n", "\n", "def train(model, optimizer, criterion, train_loader, test_loader, num_epochs, device):\n", " for epoch in range(num_epochs):\n", " model.train()\n", " running_loss = 0.0\n", " running_correct = 0\n", " total = 0\n", " \n", " for batch in train_loader:\n", " images, labels = batch[\"image\"], batch[\"label\"]\n", " images, labels = images.to(device), labels.to(device)\n", " optimizer.zero_grad()\n", " outputs = model(images)\n", " _, predicted = torch.max(outputs, 1)\n", " loss = criterion(outputs, labels)\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " running_loss += loss.item()\n", " running_correct += torch.sum(predicted == labels.data)\n", " total += len(labels)\n", " \n", " model.eval()\n", " running_test_loss = 0.0\n", " running_test_correct = 0\n", " test_total = 0\n", " \n", " with torch.no_grad():\n", " for batch in train_loader:\n", " images, labels = batch[\"image\"], batch[\"label\"]\n", " images, labels = images.to(device), labels.to(device)\n", " outputs = model(images)\n", " _, predicted = torch.max(outputs, 1)\n", " loss = criterion(outputs, labels)\n", " test_total += len(labels)\n", "\n", " running_test_loss += loss.item()\n", " running_test_correct += torch.sum(predicted == labels.data)\n", " \n", " log = {\n", " \"epoch\": epoch +1,\n", " \"train_loss\": running_loss / len(train_loader),\n", " \"train_acc\": running_correct / total,\n", " \"test_loss\": running_test_loss / len(test_loader),\n", " \"test_acc\": running_test_correct / test_total\n", " }\n", " print(log)\n", " wandb.log(log)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mhiroga\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.3" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in c:\\Users\\hiroga\\Documents\\GitHub\\til\\computer-science\\machine-learning\\_src\\pokemon-palworld\\notebooks\\wandb\\run-20240222_085816-qrvtmdob" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run glowing-rooster-39 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/hiroga/pokemon-palworld" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/hiroga/pokemon-palworld/runs/qrvtmdob" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", " warnings.warn(\n", "c:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n", " warnings.warn(msg)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'epoch': 1, 'train_loss': 0.05833915864536894, 'train_acc': tensor(0.9789, device='cuda:0'), 'test_loss': 0.11876845688253414, 'test_acc': tensor(0.9912, device='cuda:0')}\n", "{'epoch': 2, 'train_loss': 0.03478731791966929, 'train_acc': tensor(0.9890, device='cuda:0'), 'test_loss': 0.10162574984133244, 'test_acc': tensor(0.9923, device='cuda:0')}\n", "{'epoch': 3, 'train_loss': 0.03210982568926384, 'train_acc': tensor(0.9901, device='cuda:0'), 'test_loss': 0.09957705784894753, 'test_acc': tensor(0.9918, device='cuda:0')}\n", "{'epoch': 4, 'train_loss': 0.029581266190529667, 'train_acc': tensor(0.9904, device='cuda:0'), 'test_loss': 0.11058470348379652, 'test_acc': tensor(0.9923, device='cuda:0')}\n", "{'epoch': 5, 'train_loss': 0.029816618186993993, 'train_acc': tensor(0.9903, device='cuda:0'), 'test_loss': 0.08752415181166058, 'test_acc': tensor(0.9933, device='cuda:0')}\n", "{'epoch': 1, 'train_loss': 0.020498616995582294, 'train_acc': tensor(0.9938, device='cuda:0'), 'test_loss': 0.05274752890984421, 'test_acc': tensor(0.9964, device='cuda:0')}\n", "{'epoch': 2, 'train_loss': 0.013660616503107782, 'train_acc': tensor(0.9960, device='cuda:0'), 'test_loss': 0.03897114609966218, 'test_acc': tensor(0.9975, device='cuda:0')}\n", "{'epoch': 3, 'train_loss': 0.011321998002304863, 'train_acc': tensor(0.9969, device='cuda:0'), 'test_loss': 0.03173145028017817, 'test_acc': tensor(0.9981, device='cuda:0')}\n", "{'epoch': 4, 'train_loss': 0.009109769900116575, 'train_acc': tensor(0.9977, device='cuda:0'), 'test_loss': 0.025280542672569118, 'test_acc': tensor(0.9987, device='cuda:0')}\n", "{'epoch': 5, 'train_loss': 0.00814158867722251, 'train_acc': tensor(0.9977, device='cuda:0'), 'test_loss': 0.021063973103984942, 'test_acc': tensor(0.9989, device='cuda:0')}\n" ] } ], "source": [ "# Fine Tuning from ResNet18\n", "import torchvision.models as models\n", "import wandb\n", "from datetime import datetime\n", "from safetensors.torch import save_file\n", "\n", "model_name = \"ResNet18_FineTuned\"\n", "last_layer_learning_rate = 0.01\n", "last_layer_momentum = 0.9\n", "last_layer_epoches = 5\n", "full_layer_learning_rate = 0.001\n", "full_layer_momentum = 0.001\n", "full_layer_epoches = 5\n", "\n", "wandb.init(\n", " project=\"pokemon-palworld\",\n", " config={\n", " \"model_name\": model_name,\n", " \"labels\": label,\n", " \"last_layer_learning_rate\": last_layer_learning_rate,\n", " \"last_layer_momentum\": last_layer_momentum,\n", " \"last_layer_epochs\": last_layer_epoches,\n", " \"full_layer_learning_rate\": full_layer_learning_rate,\n", " \"full_layer_momentum\": full_layer_momentum,\n", " \"full_layer_epochs\": full_layer_epoches,\n", " \"architecture\": \"CNN\",\n", " \"dataset\": \"pokemon-palworld\",\n", " \"train_size\": len(train_dataloader.dataset),\n", " \"test_size\": len(test_dataloader.dataset),\n", " \"batch_size\": batch_size,\n", " }\n", ")\n", "\n", "model = models.resnet18(pretrained=True)\n", "for param in model.parameters():\n", " param.requires_grad = False\n", "model.fc = torch.nn.Linear(model.fc.in_features, len(label))\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model.to(device)\n", "\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", "# Fine-tune the last layer for a few epochs\n", "optimizer = torch.optim.SGD(model.fc.parameters(), lr=last_layer_learning_rate, momentum=last_layer_momentum)\n", "train(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=last_layer_epoches, device=device)\n", "\n", "# Unfreeze all the layers and fine-tune the entire network for a few more epochs\n", "for param in model.parameters():\n", " param.requires_grad = True\n", "optimizer = torch.optim.SGD(model.parameters(), lr=full_layer_learning_rate, momentum=full_layer_momentum)\n", "train(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=full_layer_epoches, device=device)\n", "\n", "save_file(model.state_dict(), f\"../models/snapshots/{model_name}_epoch{last_layer_epoches}_{full_layer_epoches}_{datetime.now().strftime('%Y%m%d%H%M%S')}.safetensors\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Finishing last run (ID:qrvtmdob) before initializing another..." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

Run history:


epoch▁▃▅▆█▁▃▅▆█
test_acc▁▂▂▂▃▆▇▇██
test_loss█▇▇▇▆▃▂▂▁▁
train_acc▁▅▅▅▅▇▇███
train_loss█▅▄▄▄▃▂▁▁▁

Run summary:


epoch5
test_acc0.9989
test_loss0.02106
train_acc0.99773
train_loss0.00814

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run glowing-rooster-39 at: https://wandb.ai/hiroga/pokemon-palworld/runs/qrvtmdob
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: .\\wandb\\run-20240222_085816-qrvtmdob\\logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Successfully finished last run (ID:qrvtmdob). Initializing new run:
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.3" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in c:\\Users\\hiroga\\Documents\\GitHub\\til\\computer-science\\machine-learning\\_src\\pokemon-palworld\\notebooks\\wandb\\run-20240222_114545-7hknvmoi" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run dazzling-paper-40 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/hiroga/pokemon-palworld" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/hiroga/pokemon-palworld/runs/7hknvmoi" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "ValueError", "evalue": "Expected input batch_size (98) to match target batch_size (128).", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[10], line 37\u001b[0m\n\u001b[0;32m 34\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(model\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39mlearning_rate)\n\u001b[0;32m 35\u001b[0m criterion \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mCrossEntropyLoss()\n\u001b[1;32m---> 37\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 39\u001b[0m save_file(model\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../models/snapshots/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_epoch\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.safetensors\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "Cell \u001b[1;32mIn[7], line 17\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(model, optimizer, criterion, train_loader, test_loader, num_epochs, device)\u001b[0m\n\u001b[0;32m 15\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(images)\n\u001b[0;32m 16\u001b[0m _, predicted \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmax(outputs, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 17\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mcriterion\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 19\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m 20\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n", "File \u001b[1;32mc:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32mc:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32mc:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\torch\\nn\\modules\\loss.py:1179\u001b[0m, in \u001b[0;36mCrossEntropyLoss.forward\u001b[1;34m(self, input, target)\u001b[0m\n\u001b[0;32m 1178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor, target: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m-> 1179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1180\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreduction\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1181\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32mc:\\Users\\hiroga\\miniconda3\\envs\\pokemon-palworld-v2\\Lib\\site-packages\\torch\\nn\\functional.py:3059\u001b[0m, in \u001b[0;36mcross_entropy\u001b[1;34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[0m\n\u001b[0;32m 3057\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 3058\u001b[0m reduction \u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[1;32m-> 3059\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_entropy_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[1;31mValueError\u001b[0m: Expected input batch_size (98) to match target batch_size (128)." ] } ], "source": [ "# SimpleCNN\n", "import sys\n", "\n", "sys.path.append('..')\n", "\n", "import torch\n", "import wandb\n", "from safetensors.torch import save_file\n", "\n", "from src.SimpleCNN import SimpleCNN\n", "\n", "model_name = \"SimpleCNN\"\n", "learning_rate = 0.001\n", "epochs = 5\n", "image_size = 256\n", "\n", "wandb.init(\n", " project=\"pokemon-palworld\",\n", " config={\n", " \"model_name\": model_name,\n", " \"learning_rate\": learning_rate,\n", " \"architecture\": \"CNN\",\n", " \"dataset\": \"pokemon-palworld\",\n", " \"epochs\": epochs,\n", " \"image_size\": image_size,\n", " \"train_size\": len(train_dataloader.dataset),\n", " \"test_size\": len(test_dataloader.dataset),\n", " \"batch_size\": batch_size,\n", " }\n", ")\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model = SimpleCNN(image_size=image_size).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", "train(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=epochs, device=device)\n", "\n", "save_file(model.state_dict(), f\"../models/snapshots/{model_name}_epoch{epochs}.safetensors\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "til-machine-learning", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 2 }