RuntimeError

#1
by kashif HF staff - opened

When I load this model, I get:

```python
import timm

model = timm.create_model("hf-hub:edadaltocg/resnet18_cifar10", pretrained=True)
```

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-1dd767c3077f> in <cell line: 1>()
----> 1 model = timm.create_model("hf-hub:edadaltocg/resnet18_cifar10", pretrained=True)

5 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ResNet:
    size mismatch for conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).
```
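The error fires because the checkpoint's first convolution is 3×3 while timm's stock resnet18 builds a 7×7 one. A small helper like the one below can compare a model's `state_dict()` against a checkpoint before loading, listing every parameter whose shape differs plus keys present on only one side. This is a hypothetical sketch: `find_shape_mismatches` is not part of timm or PyTorch. It only assumes each value has a `.shape` attribute, which holds for `torch.Tensor`.

```python
from typing import Any, Dict, List, Mapping, Tuple


def find_shape_mismatches(
    model_state: Mapping[str, Any], checkpoint: Mapping[str, Any]
) -> Dict[str, List]:
    """Compare parameter shapes between a model state_dict and a checkpoint.

    Hypothetical helper: values only need a `.shape` attribute
    (e.g. torch.Tensor), so the function itself does not import torch.
    """
    mismatched: List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]] = []
    for key, param in checkpoint.items():
        if key in model_state:
            ckpt_shape = tuple(param.shape)
            model_shape = tuple(model_state[key].shape)
            if ckpt_shape != model_shape:
                # (key, shape in checkpoint, shape in current model)
                mismatched.append((key, ckpt_shape, model_shape))
    return {
        "mismatched": mismatched,
        # Keys the model expects but the checkpoint does not provide.
        "missing": sorted(set(model_state) - set(checkpoint)),
        # Keys in the checkpoint that the model has no slot for.
        "unexpected": sorted(set(checkpoint) - set(model_state)),
    }
```

With PyTorch you would call it as `find_shape_mismatches(model.state_dict(), state_dict)`, where `state_dict` is the loaded checkpoint; for this checkpoint, `conv1.weight` should show up with `(64, 3, 3, 3)` versus `(64, 3, 7, 7)`, matching the traceback.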

Hello @kashif,

I'm working on a fix for this.
In the meantime, you can load the model weights as follows:

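```python
import timm
import torch
import torch.nn as nn

# Build the standard ResNet-18 architecture without pretrained weights
model = timm.create_model("resnet18", pretrained=False)

# Override the layers that differ in the CIFAR-10 checkpoint:
# a 3x3, stride-1 first conv, no max-pool in the stem, and a 10-class head
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()  # type: ignore
model.fc = nn.Linear(512, 10)

# Download the checkpoint from the Hub and load it into the modified model
model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin",
        map_location="cpu",
        file_name="resnet18_cifar10.pth",
    )
)
```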
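These overrides turn timm's ImageNet-style ResNet-18 into the CIFAR-style variant the checkpoint was presumably trained as. Assuming 32×32 CIFAR-10 inputs, the usual output-size formula floor((H + 2p - k) / s) + 1 shows why the stem matters: the default stem (7×7 conv with stride 2 and padding 3, then a 3×3 max-pool with stride 2 and padding 1) shrinks the image to 8×8 before the first residual stage, while the 3×3 stride-1 stem keeps it at 32×32. The sketch below only does that arithmetic; the stem parameters are those of timm's default ResNet, not something stated in this thread, and the function names are illustrative.

```python
from typing import List


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_then_stages(size: int, cifar_stem: bool) -> List[int]:
    """Side length: input, after stem/layer1, after layer2, layer3, layer4."""
    sizes = [size]
    if cifar_stem:
        size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv1, stride 1
        # maxpool replaced by nn.Identity(): size unchanged
    else:
        size = conv_out(size, kernel=7, stride=2, padding=3)  # 7x7 conv1, stride 2
        size = conv_out(size, kernel=3, stride=2, padding=1)  # 3x3 max-pool, stride 2
    sizes.append(size)  # layer1 keeps this resolution
    # layer2, layer3 and layer4 each start with a 3x3, stride-2 conv
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes.append(size)
    return sizes


print(stem_then_stages(32, cifar_stem=False))  # [32, 8, 4, 2, 1]
print(stem_then_stages(32, cifar_stem=True))   # [32, 32, 16, 8, 4]
```

With the default stem, a 32×32 image reaches layer4 as a single 1×1 cell, which is a common reason CIFAR variants of ResNet drop the stride-2 stem.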

EDIT: removed the features_dim variable and replaced it with 512, following @j35t3r's comment.

Ok, cool! Thanks, looking forward to it!

The variable features_dim is not defined. I used 512, which I read from the model.

Or just pass num_classes and delete the model.fc line above:

```python
model = timm.create_model("resnet18", num_classes=10, pretrained=False)
```
