Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,807 Bytes
465d7e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import timm
import torch.nn as nn
def _swap_submodule(model, name, new_module):
    """Replace the submodule registered under dotted path ``name`` with ``new_module``."""
    parent_name, _, child_name = name.rpartition(".")
    # get_submodule("") returns ``model`` itself, which covers top-level children.
    parent = model.get_submodule(parent_name)
    setattr(parent, child_name, new_module)


def replace_last_layer(model, num_classes):
    """Swap the last ``nn.Linear`` or ``nn.Conv2d`` in ``model`` for a fresh layer.

    Modules are scanned in reverse ``named_modules()`` order. The first Linear
    or Conv2d encountered is replaced in place by a newly initialised layer
    with ``num_classes`` outputs:

    * a Linear keeps the original ``in_features``;
    * a Conv2d keeps the original ``in_channels``, ``kernel_size``, ``stride``,
      ``padding`` and ``dilation``, so it accepts the same input tensor as the
      layer it replaces.

    If no Linear/Conv2d is found, ``model`` is left unchanged.

    Args:
        model: ``nn.Module`` modified in place.
        num_classes: Number of output features / channels of the new layer.
    """
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, nn.Linear):
            new_layer = nn.Linear(module.in_features, num_classes)
        elif isinstance(module, nn.Conv2d):
            # The new conv occupies the old layer's slot and receives the old
            # layer's input, so it must consume ``in_channels``.
            new_layer = nn.Conv2d(
                module.in_channels,
                num_classes,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
            )
        else:
            continue
        # Freshly built layers are trainable by default. Make it explicit
        # because callers may have frozen the rest of the model.
        # ``requires_grad_`` updates the parameters; assigning a
        # ``requires_grad`` attribute on the Module would not.
        new_layer.requires_grad_(True)
        _swap_submodule(model, name, new_layer)
        return
def enable_first_layer_grad(model):
    """Make the parameters of the first ``nn.Conv2d`` in ``model`` trainable.

    "First" means first in ``model.modules()`` order. If the model has no
    Conv2d, nothing is changed.

    Args:
        model: ``nn.Module`` modified in place.
    """
    first_conv = next(
        (m for m in model.modules() if isinstance(m, nn.Conv2d)), None
    )
    if first_conv is not None:
        # ``requires_grad_`` flips the flag on every parameter of the module.
        # Assigning ``module.requires_grad = True`` would only set an unused
        # attribute and leave the weights frozen.
        first_conv.requires_grad_(True)
def create_model(key, in_chans=1, num_classes=1):
    """Build a timm model with most parameters frozen.

    Args:
        key: timm model name forwarded to ``timm.create_model``.
        in_chans: Number of input channels.
        num_classes: Number of output classes.

    Returns:
        The model, with every parameter frozen except whatever
        ``enable_first_layer_grad`` re-enables and the layer created by
        ``replace_last_layer``.
    """
    # NOTE(review): pretrained=False means the frozen weights are random
    # initialisations, not learned features; confirm this is intended.
    net = timm.create_model(
        key,
        pretrained=False,
        in_chans=in_chans,
        num_classes=num_classes,
    )
    # Freeze everything up front, then selectively re-enable training.
    net.requires_grad_(False)
    enable_first_layer_grad(net)
    replace_last_layer(net, num_classes)
    return net
|