Model and weights do not fit together

#3
by j35t3r - opened
import timm
model = timm.create_model("hf_hub:edadaltocg/resnet18_svhn", pretrained=True)
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).
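
To see which parameters disagree before calling `load_state_dict`, the shapes in the checkpoint can be compared against the shapes in the model. The sketch below is only an illustration: `find_shape_mismatches` is a hypothetical helper, not part of timm or PyTorch, and the stand-in objects simply mirror the two shapes quoted in the error message so the snippet runs without downloading anything.

```python
from types import SimpleNamespace


def find_shape_mismatches(model_state, checkpoint_state):
    """Return (name, checkpoint_shape, model_shape) for shared keys whose shapes differ.

    Both arguments are mappings from parameter name to an object with a
    `.shape` attribute (for example, the result of `model.state_dict()`).
    """
    mismatches = []
    for name, param in checkpoint_state.items():
        if name not in model_state:
            continue  # missing or unexpected keys are a separate problem
        ckpt_shape = tuple(param.shape)
        model_shape = tuple(model_state[name].shape)
        if ckpt_shape != model_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# Stand-ins that mirror the shapes from the error message above.
demo_checkpoint = {"conv1.weight": SimpleNamespace(shape=(64, 3, 3, 3))}
demo_model = {"conv1.weight": SimpleNamespace(shape=(64, 3, 7, 7))}

report = find_shape_mismatches(demo_model, demo_checkpoint)
for name, ckpt_shape, model_shape in report:
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With the real objects, the same call would be `find_shape_mismatches(model.state_dict(), checkpoint)`, where `checkpoint` is the state dict loaded from `pytorch_model.bin`.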

This seems to work for me:

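        import timm
        import torch
        import torch.nn as nn

        model = timm.create_model("hf_hub:edadaltocg/resnet18_svhn", pretrained=False)
        # Override conv1 to match the checkpoint's 3x3 kernel, drop the max-pool,
        # and use a 10-class head
        model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        model.maxpool = nn.Identity()  # type: ignore
        model.fc = nn.Linear(512, 10)

        if args.mode == 'spatial':  # args comes from my own script
            model.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    "https://huggingface.co/edadaltocg/resnet18_svhn/resolve/main/pytorch_model.bin",
                    map_location="cpu",
                    # file_name is only the local cache name; named after SVHN to avoid
                    # picking up a cached CIFAR-10 checkpoint by mistake
                    file_name="resnet18_svhn.pth",
                )
            )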

More testing needed

Please load the model as suggested in the model card:

import detectors
import timm

model = timm.create_model("resnet18_svhn", pretrained=True)

The conv1 layer in timm's default resnet18 uses a 7x7 kernel, which is designed for larger images. To handle smaller images, I replaced it with a layer that has a 3x3 kernel, as you can see in the error message. That's why loading the weights into timm's default resnet18 fails; you need to import detectors so that the correct model architecture is available.
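
As a rough illustration of why the 7x7 stem suits larger images, the arithmetic below tracks the spatial size of the feature map after the stem. It assumes the standard ResNet stem (7x7 convolution with stride 2 and padding 3, followed by a 3x3 max-pool with stride 2 and padding 1) and compares it with a 3x3, stride-1 convolution with no max-pool, as in the workaround posted earlier in this thread. The function names are made up for this sketch.

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def default_stem(size):
    # Standard ResNet stem: 7x7 conv (stride 2, padding 3),
    # then 3x3 max-pool (stride 2, padding 1).
    size = conv_out_size(size, kernel=7, stride=2, padding=3)
    return conv_out_size(size, kernel=3, stride=2, padding=1)


def small_image_stem(size):
    # 3x3 conv (stride 1, padding 1) and no max-pool,
    # matching the overrides in the workaround above.
    return conv_out_size(size, kernel=3, stride=1, padding=1)


for side in (224, 32):
    print(side, default_stem(side), small_image_stem(side))
# 224 56 224
# 32 8 32
```

A 32x32 SVHN image is already down to 8x8 after the default stem, before any residual block runs, whereas the 3x3 stem keeps the full 32x32 resolution.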

Please see https://github.com/edadaltocg/detectors/blob/3555ebada096694a69d7f030691ca233ec288190/src/detectors/models/resnet.py#L116 for the modifications made in the model architecture.
