# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from .resnet import BasicBlock, Bottleneck


class HRModule(BaseModule):
    """High-Resolution Module for HRNet.

    Every branch is a stack of ``num_blocks[i]`` residual blocks
    (``BasicBlock``/``Bottleneck``) applied to that branch's feature map; the
    branch outputs are then fused (exchanged) across branches.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (type): Residual block class used in every branch.
        num_blocks (Sequence[int]): Number of blocks in each branch.
        in_channels (list[int]): Input channels of each branch. The list is
            kept by reference and overwritten in place with each branch's
            output channels (``num_channels[i] * blocks.expansion``).
        num_channels (Sequence[int]): Base channels of each branch.
        multiscale_output (bool): If True, produce one fused output per
            branch; otherwise only the fused output of branch 0.
            Default: True.
        with_cp (bool): Forwarded to every block. Default: False.
        conv_cfg (dict, optional): Config for conv layers. Default: None.
        norm_cfg (dict): Config for norm layers. Default: dict(type='BN').
        block_init_cfg (dict, optional): ``init_cfg`` given to every block.
            Default: None.
        init_cfg (dict | list[dict], optional): Initialization config.
            Default: None.

    Raises:
        ValueError: If ``num_blocks``, ``num_channels`` or ``in_channels``
            does not have ``num_branches`` entries.
    """
    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 block_init_cfg=None,
                 init_cfg=None):
        super(HRModule, self).__init__(init_cfg)
        self.block_init_cfg = block_init_cfg
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Check that every per-branch config has ``num_branches`` entries.

        Raises:
            ValueError: On the first length mismatch found (checked in the
                order num_blocks, num_channels, in_channels).
        """
        per_branch_cfgs = (
            ('NUM_BLOCKS', num_blocks),
            ('NUM_CHANNELS', num_channels),
            ('NUM_INCHANNELS', in_channels),
        )
        for cfg_name, cfg_values in per_branch_cfgs:
            if num_branches != len(cfg_values):
                raise ValueError(f'NUM_BRANCHES({num_branches}) '
                                 f'!= {cfg_name}({len(cfg_values)})')

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build the residual block stack of one branch.

        A 1x1 conv + norm projection shortcut is added to the first block when
        the stride or channel count changes. ``self.in_channels`` is updated
        in place to the branch's output channels.

        Returns:
            Sequential: ``num_blocks[branch_index]`` blocks.
        """
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(self.conv_cfg,
                                 self.in_channels[branch_index],
                                 num_channels[branch_index] * block.expansion,
                                 kernel_size=1,
                                 stride=stride,
                                 bias=False),
                build_norm_layer(self.norm_cfg, num_channels[branch_index] *
                                 block.expansion)[1])

        layers = []
        layers.append(
            block(self.in_channels[branch_index],
                  num_channels[branch_index],
                  stride,
                  downsample=downsample,
                  with_cp=self.with_cp,
                  norm_cfg=self.norm_cfg,
                  conv_cfg=self.conv_cfg,
                  init_cfg=self.block_init_cfg))
        self.in_channels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for _ in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.in_channels[branch_index],
                      num_channels[branch_index],
                      with_cp=self.with_cp,
                      norm_cfg=self.norm_cfg,
                      conv_cfg=self.conv_cfg,
                      init_cfg=self.block_init_cfg))

        return Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all branches, in branch order, as a ``ModuleList``."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the cross-branch fusion layers.

        ``fuse_layers[i][j]`` maps branch ``j`` onto output branch ``i``:
        ``None`` when ``i == j``; for ``j > i`` a 1x1 conv + norm followed by
        nearest upsampling by ``2**(j - i)``; for ``j < i`` a chain of
        ``i - j`` stride-2 3x3 conv + norm layers (ReLU after all but the
        last).

        Returns:
            nn.ModuleList | None: ``None`` when there is a single branch.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(self.conv_cfg,
                                             in_channels[j],
                                             in_channels[i],
                                             kernel_size=1,
                                             stride=1,
                                             padding=0,
                                             bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(scale_factor=2**(j - i),
                                        mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            # Last step: switch to the target channel count
                            # and leave the ReLU to the fused sum.
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(self.conv_cfg,
                                                     in_channels[j],
                                                     in_channels[i],
                                                     kernel_size=3,
                                                     stride=2,
                                                     padding=1,
                                                     bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(self.conv_cfg,
                                                     in_channels[j],
                                                     in_channels[j],
                                                     kernel_size=3,
                                                     stride=2,
                                                     padding=1,
                                                     bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): One feature map per branch. The list itself is
                not modified.

        Returns:
            list[Tensor]: One fused feature map per output branch (a single
            element when ``multiscale_output`` is False or there is only one
            branch).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Collect branch outputs in a new list instead of overwriting the
        # caller's ``x`` in place.
        branch_outs = [
            self.branches[i](x[i]) for i in range(self.num_branches)
        ]

        x_fuse = []
        for i, fuse_layer in enumerate(self.fuse_layers):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += branch_outs[j]
                else:
                    y += fuse_layer[j](branch_outs[j])
            x_fuse.append(self.relu(y))
        return x_fuse


class PoseHighResolutionNet(BaseModule):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    arXiv: <https://arxiv.org/abs/1904.04514>`_.

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
            There must be 4 stages, the configuration for each stage must have
            5 keys:
                - num_modules(int): The number of HRModule in this stage.
                - num_branches(int): The number of branches in the HRModule.
                - block(str): The type of convolution block
                    ('BASIC' or 'BOTTLENECK').
                - num_blocks(tuple): The number of blocks in each branch.
                    The length must be equal to num_branches.
                - num_channels(tuple): The number of channels in each branch.
                    The length must be equal to num_branches.
            ``extra`` must also contain these top-level keys:
                - final_conv_kernel(int): Kernel size of ``final_layer``.
                - downsample(bool): If True, the four stage-4 outputs are
                    fused at the resolution of the last branch; otherwise at
                    the resolution of the first branch.
                - use_conv(bool): Resample with learned conv layers instead
                    of bilinear interpolation.
                - return_list(bool): If True, ``forward`` returns the list
                    of stage-4 outputs without fusing them.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        num_joints(int): the number of output for the final layer. Default: 24.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Only applied when
            neither ``pretrained`` nor ``init_cfg`` is given. Default: False.
        multiscale_output (bool): Whether to output multi-level features
            produced by multiple branches. If False, only the first level
            feature will be output. Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=True,
                 with_cp=False,
                 num_joints=24,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super(PoseHighResolutionNet, self).__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        # Assert configurations of 4 stages are in extra
        assert 'stage1' in extra and 'stage2' in extra \
               and 'stage3' in extra and 'stage4' in extra
        # Assert whether the length of `num_blocks` and `num_channels` are
        # equal to `num_branches`
        for i in range(4):
            cfg = extra[f'stage{i + 1}']
            assert len(cfg['num_blocks']) == cfg['num_branches'] and \
                   len(cfg['num_channels']) == cfg['num_branches']

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual
        # Zero-init of the last norm layer of each residual block is only
        # used with the default initialization (no ``pretrained`` and no
        # user ``init_cfg``). The former ``not hasattr(self, 'init_cfg')``
        # test could not succeed, because ``init_cfg`` is already stored on
        # the instance by ``BaseModule.__init__``/the branches above, which
        # made ``zero_init_residual`` a no-op.
        self._zero_init_blocks = bool(zero_init_residual
                                      and pretrained is None
                                      and init_cfg is None)

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(self.conv_cfg,
                                      in_channels,
                                      64,
                                      kernel_size=3,
                                      stride=2,
                                      padding=1,
                                      bias=False)

        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(self.conv_cfg,
                                      64,
                                      64,
                                      kernel_size=3,
                                      stride=2,
                                      padding=1,
                                      bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
        # NOTE: ``final_layer`` is not used by ``forward``; presumably kept
        # so that existing checkpoints still load.
        self.final_layer = build_conv_layer(
            cfg=self.conv_cfg,
            in_channels=pre_stage_channels[0],
            out_channels=num_joints,
            kernel_size=extra['final_conv_kernel'],
            stride=1,
            padding=1 if extra['final_conv_kernel'] == 3 else 0)
        if extra['downsample'] and extra['use_conv']:
            self.downsample_stage_1 = self._make_downsample_layer(
                3, num_channel=self.stage2_cfg['num_channels'][0])
            self.downsample_stage_2 = self._make_downsample_layer(
                2, num_channel=self.stage2_cfg['num_channels'][-1])
            self.downsample_stage_3 = self._make_downsample_layer(
                1, num_channel=self.stage3_cfg['num_channels'][-1])
        elif not extra['downsample'] and extra['use_conv']:
            self.upsample_stage_2 = self._make_upsample_layer(
                1, num_channel=self.stage2_cfg['num_channels'][-1])
            self.upsample_stage_3 = self._make_upsample_layer(
                2, num_channel=self.stage3_cfg['num_channels'][-1])
            self.upsample_stage_4 = self._make_upsample_layer(
                3, num_channel=self.stage4_cfg['num_channels'][-1])

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _get_block_init_cfg(self, block):
        """Return the per-block ``init_cfg`` implementing zero_init_residual.

        Returns a ``Constant(val=0)`` init overriding the block's last norm
        layer (``norm2`` for ``BasicBlock``, ``norm3`` for ``Bottleneck``)
        when zero-init is enabled, otherwise ``None``.
        """
        if not getattr(self, '_zero_init_blocks', False):
            return None
        if block is BasicBlock:
            return dict(type='Constant', val=0, override=dict(name='norm2'))
        if block is Bottleneck:
            return dict(type='Constant', val=0, override=dict(name='norm3'))
        return None

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Build the layers feeding one stage from the previous stage.

        For an existing branch ``i``: ``None`` if the channel count is
        unchanged, else a stride-1 3x3 conv + norm + ReLU. For each new
        branch: a chain of stride-2 3x3 conv + norm + ReLU layers applied to
        the previous stage's last branch.

        Returns:
            nn.ModuleList: One entry (module or ``None``) per current branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(self.conv_cfg,
                                             num_channels_pre_layer[i],
                                             num_channels_cur_layer[i],
                                             kernel_size=3,
                                             stride=1,
                                             padding=1,
                                             bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(self.conv_cfg,
                                             in_channels,
                                             out_channels,
                                             kernel_size=3,
                                             stride=2,
                                             padding=1,
                                             bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Build a stack of ``blocks`` residual blocks (used for stage 1).

        A 1x1 conv + norm projection shortcut is added to the first block when
        the stride or channel count changes.

        Returns:
            Sequential: The residual blocks.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(self.conv_cfg,
                                 inplanes,
                                 planes * block.expansion,
                                 kernel_size=1,
                                 stride=stride,
                                 bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        block_init_cfg = self._get_block_init_cfg(block)
        layers = [
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                init_cfg=block_init_cfg,
            )
        ]
        inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(inplanes,
                      planes,
                      with_cp=self.with_cp,
                      norm_cfg=self.norm_cfg,
                      conv_cfg=self.conv_cfg,
                      init_cfg=block_init_cfg))

        return Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Build one stage as a sequence of ``HRModule``s.

        Args:
            layer_config (dict): Stage config (see class docstring).
            in_channels (list[int]): Input channels per branch; updated in
                place by the HRModules.
            multiscale_output (bool): Passed to the last module only; earlier
                modules always output every branch.

        Returns:
            tuple[Sequential, list[int]]: The stage and ``in_channels``.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        block_init_cfg = self._get_block_init_cfg(block)
        hr_modules = []
        for i in range(num_modules):
            # multiscale_output is only used for the last module
            module_multiscale_output = bool(multiscale_output) \
                or i != num_modules - 1

            hr_modules.append(
                HRModule(num_branches,
                         block,
                         num_blocks,
                         in_channels,
                         num_channels,
                         module_multiscale_output,
                         with_cp=self.with_cp,
                         norm_cfg=self.norm_cfg,
                         conv_cfg=self.conv_cfg,
                         block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules), in_channels

    def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
        """Build ``num_layers`` x (2x bilinear upsample, conv, norm, ReLU).

        The conv uses ``kernel_size // 2`` padding so it preserves the
        spatial size for odd kernel sizes (``padding=1`` for the default 3).
        """
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Upsample(scale_factor=2,
                            mode='bilinear',
                            align_corners=True))
            layers.append(
                build_conv_layer(
                    cfg=self.conv_cfg,
                    in_channels=num_channel,
                    out_channels=num_channel,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=False,
                ))
            layers.append(build_norm_layer(self.norm_cfg, num_channel)[1])
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
        """Build ``num_layers`` x (stride-2 conv, norm, ReLU).

        The conv uses ``kernel_size // 2`` padding (``padding=1`` for the
        default kernel size of 3).
        """
        layers = []
        for _ in range(num_layers):
            layers.append(
                build_conv_layer(
                    cfg=self.conv_cfg,
                    in_channels=num_channel,
                    out_channels=num_channel,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                    bias=False,
                ))
            layers.append(build_norm_layer(self.norm_cfg, num_channel)[1])
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    @staticmethod
    def _transition_inputs(transition, num_branches, prev_outputs):
        """Build the per-branch inputs of the next stage.

        Args:
            transition (nn.ModuleList): Output of ``_make_transition_layer``.
            num_branches (int): Number of branches of the next stage.
            prev_outputs (list[Tensor]): Outputs of the previous stage.

        Returns:
            list[Tensor]: One input per branch of the next stage.
        """
        inputs = []
        for i in range(num_branches):
            if transition[i] is None:
                inputs.append(prev_outputs[i])
            elif i < len(prev_outputs):
                # Existing branch whose channel count changes: project its
                # own feature map, not the last branch's.
                inputs.append(transition[i](prev_outputs[i]))
            else:
                # New branch: derived from the previous stage's last branch.
                inputs.append(transition[i](prev_outputs[-1]))
        return inputs

    @staticmethod
    def _interpolate_to(feats, ref):
        """Bilinearly resize every tensor in ``feats`` to ``ref``'s H and W."""
        size = (ref.size(2), ref.size(3))
        return [
            F.interpolate(feat, size=size, mode='bilinear',
                          align_corners=True) for feat in feats
        ]

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input images (presumably N, C, H, W — confirm).

        Returns:
            list[Tensor] | Tensor: The stage-4 outputs if
            ``extra['return_list']``, otherwise the four branch outputs
            resampled to a common resolution and concatenated along dim 1.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        y_list = self.stage2(
            self._transition_inputs(self.transition1,
                                    self.stage2_cfg['num_branches'], [x]))
        y_list = self.stage3(
            self._transition_inputs(self.transition2,
                                    self.stage3_cfg['num_branches'], y_list))
        y_list = self.stage4(
            self._transition_inputs(self.transition3,
                                    self.stage4_cfg['num_branches'], y_list))

        if self.extra['return_list']:
            return y_list

        if self.extra['downsample']:
            if self.extra['use_conv']:
                # Downsampling with strided convolutions
                x1 = self.downsample_stage_1(y_list[0])
                x2 = self.downsample_stage_2(y_list[1])
                x3 = self.downsample_stage_3(y_list[2])
            else:
                # Downsampling with interpolation
                x1, x2, x3 = self._interpolate_to(
                    [y_list[0], y_list[1], y_list[2]], y_list[3])
            return torch.cat([x1, x2, x3, y_list[3]], 1)

        if self.extra['use_conv']:
            # Upsampling with interpolations + convolutions
            x1 = self.upsample_stage_2(y_list[1])
            x2 = self.upsample_stage_3(y_list[2])
            x3 = self.upsample_stage_4(y_list[3])
        else:
            # Upsampling with interpolation
            x1, x2, x3 = self._interpolate_to(
                [y_list[1], y_list[2], y_list[3]], y_list[0])
        return torch.cat([y_list[0], x1, x2, x3], 1)

    def train(self, mode=True):
        """Set the module's training mode while keeping normalization layers
        frozen.

        When ``mode`` is True and ``norm_eval`` is set, every ``_BatchNorm``
        submodule is put back into eval mode so its running statistics are
        not updated.

        Returns:
            PoseHighResolutionNet: ``self``, as ``nn.Module.train`` does.
        """
        super(PoseHighResolutionNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self


class PoseHighResolutionNetExpose(PoseHighResolutionNet):
    """HRNet backbone for expose.

    Adds two subsampling heads and a stack of ``Bottleneck`` blocks on top of
    :class:`PoseHighResolutionNet`. ``forward`` brings branches 1 and 2 down
    to the resolution of branch 3 with stride-2 convs, concatenates them with
    branch 3, and returns a spatially average-pooled feature vector.

    Args:
        Same as :class:`PoseHighResolutionNet`.
    """
    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=True,
                 with_cp=False,
                 num_joints=24,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
                         with_cp, num_joints, zero_init_residual,
                         multiscale_output, pretrained, init_cfg)
        # subsample_3 doubles its input channels twice and subsample_2 once;
        # branch 3 is used as-is, giving this many concatenated channels.
        # NOTE(review): assumes stage2/stage3 ``num_channels[-1]`` match the
        # channels of stage-4 branches 1 and 2 — confirm for custom configs.
        in_dims = (2**2 * self.stage2_cfg['num_channels'][-1] +
                   2**1 * self.stage3_cfg['num_channels'][-1] +
                   self.stage4_cfg['num_channels'][-1])
        self.conv_layers = self._make_conv_layer(in_channels=in_dims,
                                                 num_layers=5)
        self.subsample_3 = self._make_subsample_layer(
            in_channels=self.stage2_cfg['num_channels'][-1], num_layers=2)
        self.subsample_2 = self._make_subsample_layer(
            in_channels=self.stage3_cfg['num_channels'][-1], num_layers=1)

    def _make_conv_layer(self,
                         in_channels=2048,
                         num_layers=3,
                         num_filters=2048,
                         stride=1):
        """Stack ``num_layers`` ``Bottleneck`` blocks.

        Every block outputs ``num_filters`` channels and has a 1x1 conv
        projection shortcut. ``stride`` is applied to the first block (and its
        shortcut) only; later blocks use stride 1.

        Returns:
            nn.Sequential: The blocks.
        """
        layers = []
        for i in range(num_layers):
            block_stride = stride if i == 0 else 1
            downsample = nn.Conv2d(in_channels,
                                   num_filters,
                                   stride=block_stride,
                                   kernel_size=1,
                                   bias=False)
            layers.append(
                Bottleneck(in_channels,
                           num_filters // 4,
                           block_stride,
                           downsample=downsample))
            in_channels = num_filters

        return nn.Sequential(*layers)

    def _make_subsample_layer(self, in_channels=96, num_layers=3, stride=2):
        """Build ``num_layers`` x (3x3 conv doubling channels, BN, ReLU).

        Each conv uses ``stride`` and padding 1.

        Returns:
            nn.Sequential: The layers; output has
            ``in_channels * 2**num_layers`` channels.
        """
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=2 * in_channels,
                          kernel_size=3,
                          stride=stride,
                          padding=1))
            in_channels = 2 * in_channels
            layers.append(nn.BatchNorm2d(in_channels, momentum=0.1))
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    @staticmethod
    def _branch_inputs(transition, num_branches, prev_outputs):
        """Build the per-branch inputs of the next stage.

        Existing branches with a transition are projected from their own
        feature map; new branches are derived from the last previous output.
        """
        inputs = []
        for i in range(num_branches):
            if transition[i] is None:
                inputs.append(prev_outputs[i])
            elif i < len(prev_outputs):
                inputs.append(transition[i](prev_outputs[i]))
            else:
                inputs.append(transition[i](prev_outputs[-1]))
        return inputs

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input images (presumably N, C, H, W — confirm).

        Returns:
            Tensor: Feature of shape (N, num_filters) produced by
            ``conv_layers`` and averaged over dims 2 and 3.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        y_list = self.stage2(
            self._branch_inputs(self.transition1,
                                self.stage2_cfg['num_branches'], [x]))
        y_list = self.stage3(
            self._branch_inputs(self.transition2,
                                self.stage3_cfg['num_branches'], y_list))
        x_list = self._branch_inputs(self.transition3,
                                     self.stage4_cfg['num_branches'], y_list)
        # NOTE(review): ``self.stage4`` (built by the parent) is not applied;
        # the heads consume the stage-4 *inputs*. Kept as-is because it
        # affects existing weights — confirm whether this is intended.
        x3 = self.subsample_3(x_list[1])
        x2 = self.subsample_2(x_list[2])
        x1 = x_list[3]
        xf = self.conv_layers(torch.cat([x3, x2, x1], dim=1))
        xf = xf.mean(dim=(2, 3))
        xf = xf.view(xf.size(0), -1)
        return xf