resnet18_mnist / README.md
gpcarl123's picture
Update README.md
d65af8d
metadata
# Hugging Face model-card metadata for the resnet18_mnist checkpoint.
language: [en]
library_name: timm
datasets: [mnist]
metrics: [accuracy]
# Reported evaluation result: accuracy on MNIST, as a fraction (0.9936 = 99.36%).
model-index:
- name: resnet18_mnist
  results:
  - task:
      type: image-classification
    dataset:
      name: MNIST
      type: mnist
    metrics:
    - type: accuracy
      value: 0.9936

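Usage

The snippet below rebuilds the ResNet-18 architecture with a single-channel input convolution, loads this checkpoint's weights from the Hugging Face Hub, and runs the model on the first batch of the MNIST test set. Images are scaled to [0, 1] by `ToTensor` and then normalized with the MNIST mean (0.1307) and standard deviation (0.3081) before the forward pass.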

import timm
import torch
import torchvision
from torch.utils import data
from torchvision import datasets, transforms
MNIST_PATH = './datasets/mnist'

# Rebuild the architecture the checkpoint was saved from: a timm ResNet-18 with
# a 10-way classification head whose first convolution is replaced by a
# 1-input-channel version, so it accepts single-channel MNIST images.
net = timm.create_model("resnet18", pretrained=False, num_classes=10)
net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.load_state_dict(
  torch.hub.load_state_dict_from_url(
    "https://huggingface.co/gpcarl123/resnet18_mnist/resolve/main/resnet18_mnist.pth",
    map_location="cpu",
    file_name="resnet18_mnist.pth",
  )
)
# Put BatchNorm layers in inference mode so they use the stored running
# statistics. In the default training mode they would normalize with the
# statistics of each (tiny) batch and overwrite the running buffers.
net.eval()

# ToTensor yields [0, 1] tensors. Normalization with the MNIST mean/std is
# applied to each batch right before the forward pass rather than inside the
# dataset transform.
preprocessor = torchvision.transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([transforms.ToTensor()])
test_set = datasets.MNIST(root=MNIST_PATH, train=False, download=True, transform=transform)
test_loader = data.DataLoader(test_set, batch_size=5, shuffle=False, num_workers=2)

# With num_workers > 0 the DataLoader starts worker processes. Under the
# "spawn" start method (default on Windows and macOS) each worker re-imports
# the main module, so iteration must only happen in the main process.
if __name__ == "__main__":
  with torch.no_grad():  # inference only: no autograd graph needed
    # `images` (not `data`) so the torch.utils.data module is not shadowed.
    for images, target in test_loader:
      logits = net(preprocessor(images))
      print(logits)
      print(logits.argmax(dim=1))  # predicted digit for each image
      print(target)
      break