import os

import torch
from monai.networks.nets import DenseNet121, DenseNet169, DenseNet201, DenseNet264

from backbones.unet3d import UNet3D
import utils.config


# Supported ``hparams.model_name`` values for the MONAI DenseNet family.
_DENSENETS = {
    "DenseNet121": DenseNet121,
    "DenseNet169": DenseNet169,
    "DenseNet201": DenseNet201,
    "DenseNet264": DenseNet264,
}

# Model-name suffix that requests pretrained torchvision ResNet weights.
_PRETRAINED_SUFFIX = "-pretrained"

# Key prefix that torch.nn.DataParallel adds to every parameter name.
_DATA_PARALLEL_PREFIX = "module."


def _freeze_layers_if_any(model, hparams):
    """Disable gradients for parameters matching ``hparams.frozen_layers``.

    A parameter is frozen when its qualified name (as reported by
    ``model.named_parameters()``) starts with any entry of
    ``hparams.frozen_layers``.

    Args:
        model: the torch module whose parameters may be frozen.
        hparams: config object exposing ``frozen_layers`` (an iterable of
            name prefixes; empty means nothing is frozen).

    Returns:
        The same ``model`` instance, modified in place.
    """
    if not hparams.frozen_layers:
        return model
    # str.startswith accepts a tuple and checks all prefixes in one call.
    prefixes = tuple(hparams.frozen_layers)
    for name, param in model.named_parameters():
        if name.startswith(prefixes):
            param.requires_grad = False
    return model


def _replace_inplace_operations(model):
    """Set ``inplace = False`` on every submodule that has an ``inplace`` attribute.

    This is done for Grad-CAM compatibility.

    Returns:
        The same ``model`` instance, modified in place.
    """
    for module in model.modules():
        if hasattr(module, "inplace"):
            module.inplace = False
    return model


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from each key.

    Only the prefix is stripped. Occurrences of ``"module."`` elsewhere in a key
    (e.g. ``"submodule.weight"``) are left intact.
    """
    n = len(_DATA_PARALLEL_PREFIX)
    return {
        (key[n:] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _load_models_genesis_weights(backbone):
    """Load the Models Genesis checkpoint into *backbone* (non-strictly).

    The checkpoint path is ``data_sl/<MODELS_GENESIS_PATH>``, read from
    ``utils.config.globals``. The file must contain a ``'state_dict'`` entry.
    """
    checkpoint_path = os.path.join('data_sl', utils.config.globals["MODELS_GENESIS_PATH"])
    # NOTE: torch.load unpickles the file. Only point MODELS_GENESIS_PATH at
    # trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = _strip_data_parallel_prefix(checkpoint['state_dict'])
    # strict=False: keys missing from either side are ignored. Tensor shape
    # mismatches on shared keys still raise.
    backbone.load_state_dict(state_dict, strict=False)


def _build_torchvision_resnet(model_name, out_channels):
    """Load a ResNet/ResNeXt from the torchvision v0.10.0 hub and resize its classifier.

    A trailing ``-pretrained`` on *model_name* (case-insensitive) requests
    pretrained weights. The suffix is removed and the name lower-cased before
    the hub lookup, because hub entrypoints are named e.g. ``resnet18``.

    Args:
        model_name: e.g. ``"resnet18"`` or ``"resnet50-pretrained"``.
        out_channels: number of outputs of the replaced ``fc`` layer.

    Returns:
        The loaded model with ``fc`` replaced by a freshly initialised Linear.
    """
    lowered = model_name.lower()
    pretrained = lowered.endswith(_PRETRAINED_SUFFIX)
    hub_name = lowered[:-len(_PRETRAINED_SUFFIX)] if pretrained else lowered
    backbone = torch.hub.load('pytorch/vision:v0.10.0', hub_name, pretrained=pretrained)
    # Assigning ``fc.out_features`` would only change metadata, not the weight
    # shape. Swap in a new Linear layer of the correct size instead.
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, out_channels)
    return backbone


def get_backbone(hparams):
    """Build the backbone network described by *hparams*.

    Supported ``hparams.model_name`` values:

    * ``DenseNet121`` / ``DenseNet169`` / ``DenseNet201`` / ``DenseNet264``:
      MONAI DenseNet with the last two transition poolings removed.
    * names starting with ``resne`` (case-insensitive), optionally suffixed
      with ``-pretrained``: torchvision hub ResNet/ResNeXt.
    * ``ModelsGenesis``: ``UNet3D`` initialised from the Models Genesis checkpoint.
    * ``UNet3D``: the same architecture without pretrained weights.

    After construction, in-place flags are disabled on all submodules and the
    layers listed in ``hparams.frozen_layers`` are frozen.

    Args:
        hparams: config object. Reads ``model_name``, ``mask``, ``input_dim``,
            ``coordinates``, ``num_classes``, ``loss``, ``dropout``,
            ``input_size`` and ``frozen_layers``.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: ``model_name`` starts with ``DenseNet`` but is not a known variant.
        NotImplementedError: ``model_name`` matches no supported family.
    """
    # 1 image channel, +1 when the mask is passed as an extra channel,
    # + input_dim coordinate channels when coordinates are enabled.
    in_channels = 1 + (hparams.mask == 'channel') + hparams.input_dim * hparams.coordinates
    # Ordinal regression uses one output fewer than the number of classes
    # (presumably K-1 threshold logits; confirm against the loss implementation).
    out_channels = hparams.num_classes - (hparams.loss == 'ordinal_regression')
    model_name = hparams.model_name

    if model_name.startswith('DenseNet'):
        net_selection = _DENSENETS.get(model_name)
        if net_selection is None:
            raise ValueError(f"Unknown DenseNet: {model_name}")
        backbone = net_selection(
            spatial_dims=hparams.input_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout_prob=hparams.dropout,
            # inplace has to be set to False to enable use of Grad-CAM
            act=("relu", {"inplace": False}),
        )
        # ensure activation maps are not shrunk too much
        backbone.features.transition2.pool = torch.nn.Identity()
        backbone.features.transition3.pool = torch.nn.Identity()
    elif model_name.lower().startswith("resne"):
        # If you use pre-trained models, please add "pretrained_resnet" to the
        # transforms hyperparameter.
        # NOTE(review): in_channels is not applied here. The torchvision stem
        # keeps its default input channel count.
        backbone = _build_torchvision_resnet(model_name, out_channels)
    elif model_name in ('ModelsGenesis', 'UNet3D'):
        # 'UNet3D' is the architecture of Models Genesis minus the pretraining.
        backbone = UNet3D(
            in_channels=in_channels,
            input_size=hparams.input_size,
            n_class=out_channels,
        )
        if model_name == 'ModelsGenesis':
            _load_models_genesis_weights(backbone)
    else:
        raise NotImplementedError(f"Unsupported model_name: {model_name}")

    backbone = _replace_inplace_operations(backbone)
    backbone = _freeze_layers_if_any(backbone, hparams)
    return backbone