Spaces:
Running
on
Zero
Running
on
Zero
import timm | |
import torch.nn as nn | |
def _build_replacement(module, num_classes):
    """Return a fresh layer that takes *module*'s input and emits *num_classes* outputs.

    Returns None when *module* is neither an ``nn.Linear`` nor an ``nn.Conv2d``.
    The new layer is created on the same device/dtype as *module*'s weight.
    """
    factory = {"device": module.weight.device, "dtype": module.weight.dtype} if isinstance(
        module, (nn.Linear, nn.Conv2d)
    ) else {}
    if isinstance(module, nn.Linear):
        return nn.Linear(module.in_features, num_classes, **factory)
    if isinstance(module, nn.Conv2d):
        # The replacement sits where *module* was, so it consumes *module*'s
        # input: its input width is module.in_channels, not module.out_channels.
        return nn.Conv2d(
            module.in_channels,
            num_classes,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            **factory,
        )
    return None


def _set_submodule(model, name, new_module):
    """Register *new_module* under the dotted path *name* inside *model*."""
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)


def replace_last_layer(model, num_classes):
    """Swap the last-registered ``nn.Linear``/``nn.Conv2d`` in *model* for a fresh layer.

    The search walks ``model.named_modules()`` in reverse, so "last" means last
    in registration order. NOTE(review): this is presumably the classifier head,
    but registration order need not match forward order — confirm per model.

    The replacement keeps the old layer's input size (and, for convolutions,
    its kernel size, stride, padding and dilation), outputs *num_classes*
    features/channels, and has trainable parameters. *model* is modified in
    place; nothing is returned. If no matching layer is found, *model* is left
    unchanged.

    Args:
        model: the ``nn.Module`` to modify.
        num_classes: output size of the new layer.
    """
    for name, module in reversed(list(model.named_modules())):
        if not name:
            # The root module cannot be replaced in place by its own parent.
            continue
        new_layer = _build_replacement(module, num_classes)
        if new_layer is None:
            continue
        # requires_grad_ flags the parameters themselves; assigning
        # `.requires_grad` on a Module would only create an unused attribute.
        new_layer.requires_grad_(True)
        _set_submodule(model, name, new_layer)
        break
def enable_first_layer_grad(model):
    """Make the parameters of the first ``nn.Conv2d`` in *model* trainable.

    "First" means first in ``model.modules()`` order (registration order);
    presumably this is the stem convolution — confirm for non-standard models.
    Does nothing if *model* contains no ``nn.Conv2d``. Modifies *model* in place.

    Args:
        model: the ``nn.Module`` whose first convolution should be unfrozen.
    """
    first_conv = next((m for m in model.modules() if isinstance(m, nn.Conv2d)), None)
    if first_conv is not None:
        # requires_grad_ flips the flag on every parameter of the module;
        # assigning `module.requires_grad = True` would not unfreeze anything.
        first_conv.requires_grad_(True)
def create_model(key, in_chans=1, num_classes=1, pretrained=False):
    """Build a timm model whose body is frozen except for its first conv and head.

    Steps: create the model via ``timm.create_model``, freeze every parameter,
    unfreeze the first ``nn.Conv2d`` (``enable_first_layer_grad``), then swap
    the last ``nn.Linear``/``nn.Conv2d`` for a fresh trainable layer
    (``replace_last_layer``).

    Args:
        key: timm model name passed to ``timm.create_model``.
        in_chans: number of input channels.
        num_classes: output size of the classification layer.
        pretrained: forwarded to ``timm.create_model``. Defaults to False,
            matching the previous hard-coded behavior.
            NOTE(review): with pretrained=False the frozen body keeps its
            random initialisation — confirm that is intended.

    Returns:
        The configured model.
    """
    model = timm.create_model(
        key, pretrained=pretrained, in_chans=in_chans, num_classes=num_classes
    )
    # Freeze all parameters, then selectively re-enable the stem and head.
    model.requires_grad_(False)
    enable_first_layer_grad(model)
    # NOTE(review): num_classes is already passed to timm above, so this
    # presumably re-initialises an already correctly-sized head.
    replace_last_layer(model, num_classes)
    return model