# xray-reg / models/model_loader.py
# Uploaded by SuperSecureHuman ("Upload 59 files", commit 465d7e4, verified).
import timm
import torch.nn as nn
def _set_submodule(model, name, new_module):
    """Install ``new_module`` at the dotted attribute path ``name`` inside ``model``."""
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)


def replace_last_layer(model, num_classes):
    """Replace the last ``nn.Linear`` / ``nn.Conv2d`` in ``model`` with a fresh layer.

    "Last" means the last matching entry in ``model.named_modules()``
    (registration order). This is a heuristic: it is not guaranteed to be
    the last layer executed in ``forward``.

    The replacement keeps the old layer's input width (``in_features`` /
    ``in_channels``) and, for convolutions, its kernel size, stride,
    padding, dilation and padding mode. Only the output width changes, to
    ``num_classes``. It is created on the old layer's device and dtype, and
    its parameters are trainable.

    Args:
        model: The ``nn.Module`` to modify in place.
        num_classes: Output features (Linear) or output channels (Conv2d)
            of the new layer.

    Returns:
        None. ``model`` is mutated in place. If no Linear/Conv2d submodule
        exists, ``model`` is left unchanged.
    """
    for name, module in reversed(list(model.named_modules())):
        if not name:
            # The root module has no parent to re-assign it on.
            continue

        factory_kwargs = {"device": module.weight.device, "dtype": module.weight.dtype} if isinstance(
            module, (nn.Linear, nn.Conv2d)
        ) else None

        if isinstance(module, nn.Linear):
            new_layer = nn.Linear(module.in_features, num_classes, **factory_kwargs)
        elif isinstance(module, nn.Conv2d):
            # The new conv consumes the same input as the layer it replaces,
            # so its input width is the old layer's *in_channels*.
            new_layer = nn.Conv2d(
                module.in_channels,
                num_classes,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                padding_mode=module.padding_mode,
                **factory_kwargs,
            )
        else:
            continue

        # Fresh parameters already default to requires_grad=True; made explicit
        # because callers freeze the rest of the model beforehand.
        new_layer.requires_grad_(True)
        _set_submodule(model, name, new_layer)
        return
def enable_first_layer_grad(model):
    """Make the first ``nn.Conv2d`` found in ``model`` trainable.

    "First" means the first ``nn.Conv2d`` yielded by ``model.modules()``
    (registration order). All of that layer's parameters get
    ``requires_grad=True``. Other layers are left untouched.

    Args:
        model: The ``nn.Module`` to modify in place.

    Returns:
        None. If ``model`` contains no ``nn.Conv2d``, nothing changes.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # ``requires_grad_`` updates every parameter of the layer. Plain
            # assignment (``module.requires_grad = True``) would only attach
            # an unused attribute to the Module and leave it frozen.
            module.requires_grad_(True)
            break
def create_model(key, in_chans=1, num_classes=1):
    """Build a timm model, freeze it, then adjust its first and last layers.

    Args:
        key: timm model name passed straight to ``timm.create_model``.
        in_chans: Number of input channels (default 1).
        num_classes: Output size of the classifier head (default 1).

    Returns:
        The constructed ``nn.Module``.

    NOTE(review): weights are built with ``pretrained=False`` and then
    frozen, so unless weights are loaded elsewhere, the frozen layers stay
    at their random initialisation — confirm this is intended.
    """
    net = timm.create_model(
        key,
        pretrained=False,
        in_chans=in_chans,
        num_classes=num_classes,
    )

    # Freeze every parameter up front. ``enable_first_layer_grad`` targets the
    # first Conv2d, and ``replace_last_layer`` swaps in a newly constructed
    # head, whose parameters PyTorch creates with requires_grad=True.
    net.requires_grad_(False)
    enable_first_layer_grad(net)
    replace_last_layer(net, num_classes)

    return net