# 2022.07.19 - Changed for CLIFF
# Huawei Technologies Co., Ltd.
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2019, University of Pennsylvania, Max Planck Institute for Intelligent Systems
# This program is free software; you can redistribute it and/or modify it
# under the terms of the MIT license.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

# This script is borrowed and extended from SPIN

import torch
import torch.nn as nn
import numpy as np
import math
import os.path as osp

from common.imutils import rot6d_to_rotmat
from models.backbones.hrnet.cls_hrnet import HighResolutionNet
from models.backbones.hrnet.hrnet_config import cfg
from models.backbones.hrnet.hrnet_config import update_config


class CLIFF(nn.Module):
    """SMPL iterative regressor with an HRNet-W48 backbone."""

    def __init__(self, smpl_mean_params, img_feat_num=2048):
        super(CLIFF, self).__init__()

        # Build the HRNet-W48 image encoder from its YAML config.
        curr_dir = osp.dirname(osp.abspath(__file__))
        config_file = osp.join(curr_dir,
                               "../backbones/hrnet/models/cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml")
        update_config(cfg, config_file)
        self.encoder = HighResolutionNet(cfg)

        npose = 24 * 6   # 24 SMPL joint rotations in the 6D representation
        nshape = 10      # SMPL shape coefficients (betas)
        ncam = 3         # weak-perspective camera (scale, tx, ty)
        nbbox = 3        # bounding-box information (CLIFF's extra input)

        fc1_feat_num = 1024
        fc2_feat_num = 1024
        final_feat_num = fc2_feat_num
        reg_in_feat_num = img_feat_num + nbbox + npose + nshape + ncam
        # "CUDA error: an illegal memory access was encountered" occurs when
        # MobileNet V3 is used together with BatchNorm, so no BN is used here.
        self.fc1 = nn.Linear(reg_in_feat_num, fc1_feat_num)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(fc1_feat_num, fc2_feat_num)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(final_feat_num, npose)
        self.decshape = nn.Linear(final_feat_num, nshape)
        self.deccam = nn.Linear(final_feat_num, ncam)

        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Mean SMPL parameters used to initialize the iterative regressor.
        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, bbox, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.encoder(x)

        # Iterative error feedback: regress residual updates from the image
        # features, the bbox information, and the current parameter estimates.
        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, bbox, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        # Convert the 6D rotation representation to 3x3 rotation matrices.
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam
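

# Minimal usage sketch. The input size and bbox values below are illustrative
# assumptions, not fixed by this file: the model takes a cropped person image
# plus a 3-D bbox vector encoding the crop's center offset and scale in the
# full image, normalized by the focal length (per the CLIFF paper), and
# "data/smpl_mean_params.npz" is a hypothetical path to the mean SMPL
# parameter file.
if __name__ == "__main__":
    model = CLIFF("data/smpl_mean_params.npz").eval()

    img = torch.randn(2, 3, 256, 192)   # batch of cropped, normalized images
    bbox_info = torch.randn(2, 3)       # per-sample bbox encoding (placeholder values)

    with torch.no_grad():
        pred_rotmat, pred_shape, pred_cam = model(img, bbox_info)

    print(pred_rotmat.shape)  # torch.Size([2, 24, 3, 3])
    print(pred_shape.shape)   # torch.Size([2, 10])
    print(pred_cam.shape)     # torch.Size([2, 3])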