File size: 8,847 Bytes
b807ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import torch
import torch.nn as nn
from common import constants
from models.smpl import SMPL
from smplx import SMPLX
import pickle as pkl
import numpy as np
from utils.mesh_utils import save_results_mesh
from utils.diff_renderer import Pytorch3D
import os
import cv2


class sem_loss_function(nn.Module):
    """Binary cross-entropy loss over semantic part-segmentation probabilities."""

    def __init__(self):
        super().__init__()
        self.ce = nn.BCELoss()

    def forward(self, y_true, y_pred):
        """Return mean BCE between predictions and targets.

        Note the argument order: ground truth first, prediction second
        (the underlying BCELoss receives them as (input, target)).
        """
        return self.ce(y_pred, y_true)


class class_loss_function(nn.Module):
    """Binary cross-entropy over per-vertex contact labels, restricted to the
    batch samples flagged as valid."""

    def __init__(self):
        super().__init__()
        self.ce_loss = nn.BCELoss()

    def forward(self, y_true, y_pred, valid_mask):
        """Return mean BCE over valid samples, or 0 when none are valid.

        Args:
            y_true: ground-truth labels, shape (bs, ...).
            y_pred: predicted probabilities, same shape as y_true.
            valid_mask: per-sample validity flags (1 = use this sample).
        """
        batch_size = y_true.shape[0]
        # NOTE(review): a batch of exactly one bypasses the validity mask —
        # presumably intentional upstream; confirm before changing.
        if batch_size != 1:
            keep = valid_mask == 1
            y_pred = y_pred[keep]
            y_true = y_true[keep]
        if len(y_pred) == 0:
            # Nothing valid to score: contribute a zero loss.
            return torch.tensor(0.0).to(y_pred.device)
        return self.ce_loss(y_pred, y_true)


class pixel_anchoring_function(nn.Module):
    """Anchors predicted per-vertex contact probabilities to image pixels.

    Transfers contact probabilities onto a posed SMPL/SMPL-X mesh, renders
    the painted mesh into the image with a differentiable renderer, and
    scores the rendered contact against ground-truth 2D contact polygons
    with a BCE loss.
    """
    def __init__(self, model_type, device='cuda'):
        """
        Args:
            model_type: 'smpl' or 'smplx'; selects the body model and, for
                'smplx', additionally loads a SMPL->SMPL-X vertex mapping.
            device: torch device string for the model and buffers.
        """
        super(pixel_anchoring_function, self).__init__()

        self.device = device

        self.model_type = model_type

        if self.model_type == 'smplx':
            # load mapping from smpl vertices to smplx vertices
            mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
            with open(mapping_pkl, 'rb') as f:
                smpl_to_smplx_mapping = pkl.load(f)
                smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
            # Used via bmm in forward(); assumed shape (n_smplx, n_smpl),
            # i.e. maps a 6890-dim SMPL vertex signal to 10475 SMPL-X
            # vertices — TODO confirm against the pickle contents.
            self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)


        # Setup the SMPL model
        if self.model_type == 'smpl':
            self.n_vertices = 6890
            self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
        if self.model_type == 'smplx':
            self.n_vertices = 10475
            self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
                                    num_betas=10,
                                    use_pca=False).to(self.device)
        # Triangle indices of the body mesh, shared by rendering and painting.
        self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)

        self.ce_loss = nn.BCELoss()

    def get_posed_mesh(self, body_params, debug=False):
        """Run the body model and return posed vertices and joints.

        Args:
            body_params: dict with 'betas' (shape coeffs), 'pose'
                (global orient in [:, :3] + body pose in [:, 3:]), and
                'transl' (root translation), all batched tensors.
            debug: if True, dump each posed mesh to temp_meshes/ as .obj.

        Returns:
            (vertices, joints) tensors from the body model output.
        """
        betas = body_params['betas']
        pose = body_params['pose']
        transl = body_params['transl']

        # extra smplx params: hands/face/expression are zeroed out — only
        # body pose and shape are driven by the input parameters.
        extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
                      'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
                      'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}

        smpl_output = self.body_model(betas=betas,
                                      body_pose=pose[:, 3:],
                                      global_orient=pose[:, :3],
                                      pose2rot=True,
                                      transl=transl,
                                      **extra_args)
        smpl_verts = smpl_output.vertices
        smpl_joints = smpl_output.joints

        if debug:
            for mesh_i in range(smpl_verts.shape[0]):
                out_dir = 'temp_meshes'
                os.makedirs(out_dir, exist_ok=True)
                out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
                save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
        return smpl_verts, smpl_joints


    def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False):
        """Render a batch of posed meshes with per-vertex/per-face colors.

        Args:
            smpl_verts: batched posed vertices (camera coordinates).
            cam_k: batched camera intrinsics matrices.
            img_scale_factor: per-sample (x, y) factors rescaling the focal
                lengths from the original image size to the render size.
            vertex_colors / face_textures: painted contact colors from
                paint_contact(); passed through to the renderer.
            debug: if True, dump RGB and mask renders to temp_images/.

        Returns:
            front_view: rendered batch; channels 0-2 are RGB, channel 3 is
                the silhouette mask (layout per Pytorch3D wrapper).
        """

        bs = smpl_verts.shape[0]

        # Incorporate resizing factor into the camera
        img_w = 256 # TODO: Remove hardcoding
        img_h = 256 # TODO: Remove hardcoding
        focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
        focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
        # convert to float for pytorch3d
        focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float()

        # concatenate focal length
        focal_length = torch.stack([focal_length_x, focal_length_y], dim=1)

        # Setup renderer (rebuilt every call because focal lengths are
        # per-batch; cameras are batched, hence is_cam_batch=True)
        renderer = Pytorch3D(img_h=img_h,
                                  img_w=img_w,
                                  focal_length=focal_length,
                                  smpl_faces=self.body_faces,
                                  texture_mode='deco',
                                  vertex_colors=vertex_colors,
                                  face_textures=face_textures,
                                  is_train=True,
                                  is_cam_batch=True)
        front_view = renderer(smpl_verts)
        if debug:
            # visualize the front view as images in a temp_image folder
            for i in range(bs):
                front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu()
                front_view_mask = front_view[i, 3, :, :].detach().cpu()
                out_dir = 'temp_images'
                os.makedirs(out_dir, exist_ok=True)
                out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png')
                out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png')
                cv2.imwrite(out_file_rgb, front_view_rgb.numpy()*255)
                cv2.imwrite(out_file_mask, front_view_mask.numpy()*255)

        return front_view

    def paint_contact(self, pred_contact):
        """
        Paints the contact vertices on the SMPL mesh.

        Each vertex color is a probability-weighted blend between black
        (no contact) and white (contact), so the result stays differentiable.

        Args:
            pred_contact: probabilities of contact vertices, shape (bs, n_vertices)

        Returns:
            pred_vert_rgb: per-vertex RGB colors, (bs, n_vertices, 3)
            pred_face_texture: per-face 1x1 texture, (bs, n_faces, 1, 1, 3)
        """
        bs = pred_contact.shape[0]

        # initialize black and white colors
        colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
        colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)

        # add another dimension to the contact probabilities for inverse probabilities
        pred_contact = torch.unsqueeze(pred_contact, 2)
        pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)

        # get pred_rgb colors: (1-p)*black + p*white via batched matmul
        pred_vert_rgb = torch.bmm(pred_contact, colors)
        pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :] # take the first vertex color
        pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
        pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
        return pred_vert_rgb, pred_face_texture

    def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
        """
        Takes predicted contact labels (probabilities), transfers them to the posed mesh and
        renders to the image. Loss is computed between the rendered contact and the ground truth
        polygons from HOT.

        Args:
            pred_contact: predicted contact labels (probabilities)
            body_params: SMPL parameters in camera coords
            cam_k: camera intrinsics
            img_scale_factor: per-sample focal rescaling factors (see render_batch)
            gt_contact_polygon: ground truth polygons from HOT
            valid_mask: per-sample flags; only samples with mask == 1 enter the loss

        Returns:
            (loss, front_view_rgb, front_view_mask) — BCE loss over valid
            samples, plus the rendered RGB (valid samples only) and mask.
        """
        # convert pred_contact to smplx
        bs = pred_contact.shape[0]
        if self.model_type == 'smplx':
            smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
            pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None])
            # NOTE(review): squeeze() drops ALL singleton dims — with bs == 1
            # this also removes the batch dim; confirm callers never pass bs=1.
            pred_contact = pred_contact.squeeze()

        # get the posed mesh
        smpl_verts, smpl_joints = self.get_posed_mesh(body_params)

        # paint the contact vertices on the mesh
        vertex_colors, face_textures = self.paint_contact(pred_contact)

        # render the mesh
        front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
        front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
        front_view_mask = front_view[:, 3, :, :]

        # compute segmentation loss between rendered contact mask and ground truth contact mask
        front_view_rgb = front_view_rgb[valid_mask == 1]
        gt_contact_polygon = gt_contact_polygon[valid_mask == 1]
        loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
        return loss, front_view_rgb, front_view_mask