Usage
import timm
import torch
import torch.utils.data
import torchvision
# Minimal inference example: build a ResNet-18, load the published
# `resnet18_mnist.pth` weights, and print the logits and labels for the first
# batch of the MNIST test split.
MNIST_PATH = './datasets/mnist'

# Create a 10-class ResNet-18 and replace the stem convolution with a
# single-input-channel one so it accepts 1-channel tensors.
net = timm.create_model("resnet18", pretrained=False, num_classes=10)
net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.load_state_dict(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/gpcarl123/resnet18_mnist/resolve/main/resnet18_mnist.pth",
        map_location="cpu",
        file_name="resnet18_mnist.pth",
    )
)
# Switch to evaluation mode so BatchNorm uses the running statistics from the
# checkpoint instead of the statistics of the tiny 5-sample batch below.
net.eval()

# Normalization is applied separately from the dataset transform, so the
# loader yields un-normalized tensors.
# NOTE(review): (0.1307, 0.3081) is presumably the mean/std used at training
# time -- confirm against the training recipe.
preprocessor = torchvision.transforms.Normalize((0.1307,), (0.3081,))
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
test_set = torchvision.datasets.MNIST(
    root=MNIST_PATH, train=False, download=True, transform=transform
)
# NOTE(review): with num_workers > 0 on platforms that spawn worker processes,
# PyTorch recommends running this under `if __name__ == "__main__":`.
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=5, shuffle=False, num_workers=2
)

# Gradients are not needed for inference.
with torch.no_grad():
    # `images` (not `data`) avoids clashing with the conventional
    # `torch.utils.data` module alias.
    for images, target in test_loader:
        print(net(preprocessor(images)))
        print(target)
        break  # only the first batch is shown