Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms | |
# 画像データをテンソルに変換 | |
transform = transforms.Compose([transforms.ToTensor()]) | |
# MNISTのデータセットをダウンロード | |
train_dataset = datasets.MNIST(root="./data", | |
train=True, | |
download=True, | |
transform=transform) | |
test_dataset = datasets.MNIST(root="./data", | |
train=False, | |
download=True, | |
transform=transform) | |
# データローダの作成 | |
train_loader = DataLoader(train_dataset, # データセット | |
batch_size=100, # バッチサイズ | |
shuffle=True) # シャッフルするかどうか | |
test_loader = DataLoader(test_dataset, # データセット | |
batch_size=100, # バッチサイズ | |
shuffle=True) # シャッフルするかどうか | |
class MNISTModel(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.model = nn.Sequential( | |
nn.Linear(28 * 28, 512), | |
nn.ReLU(), | |
nn.Linear(512, 256), | |
nn.ReLU(), | |
nn.Linear(256, 10), # 0-9の数字のいずれかであるから、10クラス分類問題となる。よって出力は10次元 | |
nn.LogSoftmax(dim=1), | |
) | |
def forward(self, x: torch.Tensor): | |
x = x.view(-1, 28 * 28) # 画像データを1次元に変換 | |
x = self.model(x) | |
return x | |
from tqdm import tqdm | |
model = MNISTModel() | |
criterion = nn.NLLLoss() | |
optimizer = torch.optim.Adam(model.parameters(), lr=0.003) | |
for epoch in range(10): | |
total_loss = 0 | |
for images, labels in tqdm(train_loader): | |
optimizer.zero_grad() | |
output = model(images) | |
loss = criterion(output, labels) | |
# ここでやっているのは、逆伝播の計算とパラメータの更新 | |
loss.backward() | |
optimizer.step() | |
total_loss += loss.item() | |
print(f"epoch: {epoch + 1}, loss: {total_loss}") | |
correct = 0 # 正解数 | |
total = 0 # テストデータの総数 | |
model.eval() # モデルを評価モードに変更 | |
with torch.no_grad(): | |
for images, labels in tqdm(test_loader): | |
output = model(images) | |
_, predicted = torch.max(output, 1) | |
total += labels.size(0) | |
correct += (predicted == labels).sum().item() | |
print(f"Accuracy: {100 * correct / total}%") | |
# モデルの保存 | |
torch.save(model.state_dict(), "mnist_model.pth") | |