Usage

import timm
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Root directory where torchvision downloads/caches MNIST.
MNIST_PATH = './datasets/mnist'

# Pretrained checkpoint published with this model card.
WEIGHTS_URL = "https://huggingface.co/gpcarl123/resnet18_mnist/resolve/main/resnet18_mnist.pth"


def build_model():
    """Create the 10-class ResNet-18 with a 1-channel stem and load the pretrained weights.

    Returns:
        The network on CPU, switched to evaluation mode.
    """
    net = timm.create_model("resnet18", pretrained=False, num_classes=10)
    # Replace the stock stem with a single-input-channel conv so the module
    # matches the checkpoint's `conv1` weight before `load_state_dict`.
    net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    net.load_state_dict(
        torch.hub.load_state_dict_from_url(
            WEIGHTS_URL,
            map_location="cpu",
            file_name="resnet18_mnist.pth",
        )
    )
    # Without eval(), BatchNorm layers normalize with the current batch's
    # statistics (here only 5 images) instead of the stored running stats,
    # which makes the printed logits wrong and batch-dependent.
    net.eval()
    return net


def main():
    """Run the model on the first 5 MNIST test images and print logits and labels."""
    net = build_model()

    # Normalization is applied separately, right before the model, so the
    # loader itself yields un-normalized ToTensor() images.
    preprocessor = torchvision.transforms.Normalize((0.1307,), (0.3081,))
    transform = transforms.Compose([transforms.ToTensor()])
    test_set = datasets.MNIST(root=MNIST_PATH, train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=5, shuffle=False, num_workers=2)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        images, target = next(iter(test_loader))
        print(net(preprocessor(images)))
        print(target)


# Required because DataLoader workers (num_workers > 0) re-import this module
# under the "spawn" start method. Without the guard, they would re-run the
# script and fail while bootstrapping.
if __name__ == "__main__":
    main()
Downloads last month
0
Inference API
Unable to determine this model’s pipeline type. Check the docs.

Dataset used to train gpcarl123/resnet18_mnist

Evaluation results