# nnUNet_calvingfront_detection/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU.py
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
from nnunet.network_architecture.generic_UNet import Generic_UNet | |
from nnunet.network_architecture.initialization import InitWeights_He | |
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2 | |
from nnunet.utilities.nd_softmax import softmax_helper | |
from torch import nn | |
class nnUNetTrainerV2_ReLU(nnUNetTrainerV2):
    """nnUNetTrainerV2 variant whose network uses ``nn.ReLU`` as its nonlinearity."""

    def initialize_network(self):
        """Build the ``Generic_UNet`` for this trainer and store it on ``self.network``.

        3D convolution, dropout and instance-norm layers are used when
        ``self.threeD`` is truthy; the 2D counterparts are used otherwise.
        After construction the network is moved to the GPU when CUDA is
        available, and ``softmax_helper`` is assigned as its
        ``inference_apply_nonlin``.
        """
        # Choose the layer classes that match the data dimensionality.
        conv_cls, dropout_cls, norm_cls = (
            (nn.Conv3d, nn.Dropout3d, nn.InstanceNorm3d)
            if self.threeD
            else (nn.Conv2d, nn.Dropout2d, nn.InstanceNorm2d)
        )

        norm_settings = dict(eps=1e-5, affine=True)
        # p=0: the dropout layers are configured with zero drop probability.
        dropout_settings = dict(p=0, inplace=True)
        activation_cls = nn.ReLU
        activation_settings = dict(inplace=True)

        # NOTE(review): arguments are passed positionally, so their order must
        # match Generic_UNet.__init__ (defined elsewhere) -- consult that
        # signature for the meaning of the literal 2 and the boolean flags.
        self.network = Generic_UNet(
            self.num_input_channels,
            self.base_num_features,
            self.num_classes,
            len(self.net_num_pool_op_kernel_sizes),
            self.conv_per_stage,
            2,
            conv_cls,
            norm_cls,
            norm_settings,
            dropout_cls,
            dropout_settings,
            activation_cls,
            activation_settings,
            True,
            False,
            lambda x: x,  # identity function
            InitWeights_He(0),
            self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes,
            False,
            True,
            True,
        )

        if torch.cuda.is_available():
            self.network.cuda()

        self.network.inference_apply_nonlin = softmax_helper