Spaces:
Runtime error
Runtime error
File size: 2,665 Bytes
016a471 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 画像データをテンソルに変換
transform = transforms.Compose([transforms.ToTensor()])
# MNISTのデータセットをダウンロード
train_dataset = datasets.MNIST(root="./data",
train=True,
download=True,
transform=transform)
test_dataset = datasets.MNIST(root="./data",
train=False,
download=True,
transform=transform)
# データローダの作成
train_loader = DataLoader(train_dataset, # データセット
batch_size=100, # バッチサイズ
shuffle=True) # シャッフルするかどうか
test_loader = DataLoader(test_dataset, # データセット
batch_size=100, # バッチサイズ
shuffle=True) # シャッフルするかどうか
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10), # 0-9の数字のいずれかであるから、10クラス分類問題となる。よって出力は10次元
nn.LogSoftmax(dim=1),
)
def forward(self, x: torch.Tensor):
x = x.view(-1, 28 * 28) # 画像データを1次元に変換
x = self.model(x)
return x
from tqdm import tqdm
model = MNISTModel()
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
for epoch in range(10):
total_loss = 0
for images, labels in tqdm(train_loader):
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
# ここでやっているのは、逆伝播の計算とパラメータの更新
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"epoch: {epoch + 1}, loss: {total_loss}")
correct = 0 # 正解数
total = 0 # テストデータの総数
model.eval() # モデルを評価モードに変更
with torch.no_grad():
for images, labels in tqdm(test_loader):
output = model(images)
_, predicted = torch.max(output, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total}%")
# モデルの保存
torch.save(model.state_dict(), "mnist_model.pth")
|