File size: 4,305 Bytes
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487ee6d
da48dbe
487ee6d
 
da48dbe
 
 
 
fb140f6
 
 
 
 
da48dbe
 
 
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
fb140f6
487ee6d
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
da48dbe
 
 
 
 
fb140f6
da48dbe
 
fb140f6
da48dbe
 
 
 
 
 
487ee6d
 
 
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
da48dbe
fb140f6
da48dbe
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import pytorch_lightning as pl
import torch
from termcolor import colored

from ..dataset.mesh_util import *
from ..net.geometry import orthogonal


class Format:
    """ANSI escape sequences used to decorate terminal output."""

    # "\033[4m" switches underline on; wrap text as start + text + end.
    start = "\033[4m"
    # "\033[0m" resets all terminal attributes.
    end = "\033[0m"


def init_loss():
    """Build a fresh table of loss terms.

    Returns:
        dict mapping each loss name to ``{"weight": <number>, "value": 0.0}``.
        A new set of dictionaries is created on every call, so callers may
        mutate the result freely.
    """
    # (name, weight) pairs, listed in the order they appear in the result.
    weighted_terms = (
        ("cloth", 1e3),         # Cloth: chamfer distance
        ("stiff", 1e5),         # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
        ("rigid", 1e5),         # Cloth: det(R) = 1
        ("edge", 0),            # Cloth: edge length
        ("nc", 0),              # Cloth: normal consistency
        ("lapla", 1e2),         # Cloth: laplacian smooth
        ("normal", 1e0),        # Body: Normal_pred - Normal_smpl
        ("silhouette", 1e0),    # Body: Silhouette_pred - Silhouette_smpl
        ("joint", 5e0),         # Joint: reprojected joints difference
    )

    return {
        name: {"weight": weight, "value": 0.0}
        for name, weight in weighted_terms
    }


class SubTrainer(pl.Trainer):
    """Trainer whose checkpoints leave out selected ``state_dict`` entries."""

    def save_checkpoint(self, filepath, weights_only=False):
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Every ``state_dict`` entry whose key contains one of the ignored
        substrings is removed before the checkpoint is written.

        Args:
            filepath: write-target file's path
            weights_only: saving model weights only
        """
        _checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)

        ignore_keys = ("normal_filter", "voxelization", "reconEngine")
        state_dict = _checkpoint["state_dict"]

        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated. `any` records each key at most once,
        # so a key matching several ignored substrings is not deleted twice
        # (which would raise KeyError on the second `del`).
        del_keys = [
            key for key in state_dict
            if any(ignore_key in key for ignore_key in ignore_keys)
        ]
        for key in del_keys:
            del state_dict[key]

        pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)


def query_func(opt, netG, features, points, proj_matrix=None):
    """Query ``netG`` at the given points and return its prediction.

    Args:
        opt: object exposing ``num_views``; the points are tiled that many
            times along the first axis.
        netG: object exposing ``query(...)`` and ``if_regressor``.
        features: forwarded unchanged to ``netG.query``.
        points: size of (bz, N, 3); the first dimension must be 1.
        proj_matrix: size of (bz, 4, 4). When given, the tiled points are
            passed through ``orthogonal`` before the query.

    Returns:
        Whatever ``netG.query`` returns, or its first element when that
        result is a list. The original note gives the size as (bz, 1, N);
        that depends on ``netG`` and is not enforced here.
    """
    assert len(points) == 1

    # Tile per view, then move the coordinate axis forward: [bz, 3, N].
    samples = points.repeat(opt.num_views, 1, 1).permute(0, 2, 1)

    # view specific query
    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    # Single identity calibration, matched to the samples' dtype/device.
    calib_tensor = torch.eye(4).float().unsqueeze(0).type_as(samples)

    preds = netG.query(
        features=features,
        points=samples,
        calibs=calib_tensor,
        regressor=netG.if_regressor,
    )

    return preds[0] if type(preds) is list else preds


def query_func_IF(batch, netG, points):
    """Run ``netG`` on ``batch`` after attaching the query points.

    ``batch`` is modified in place: ``"samples_geo"`` is set to ``points``
    and ``"calib"`` to a single 4x4 identity matrix matching ``points``'
    tensor type.

    Args:
        batch: mutable mapping handed to ``netG``.
        netG: callable taking ``batch`` and returning a tensor.
        points: size of (bz, N, 3)

    Returns:
        ``netG(batch)`` with a new axis inserted at position 1; per the
        original note, size of (bz, 1, N).
    """
    identity_calib = torch.eye(4).float().unsqueeze(0)

    batch["samples_geo"] = points
    batch["calib"] = identity_calib.type_as(points)

    return netG(batch).unsqueeze(1)


def batch_mean(res, key):
    """Average the values stored under ``key`` across a sequence of results.

    Args:
        res: iterable of mappings; each must contain ``key``. Values may be
            tensors or anything ``torch.as_tensor`` accepts (e.g. Python
            numbers). All values must share one shape so they can be stacked.
        key: entry to average.

    Returns:
        0-dim tensor holding the mean over every stacked element.

    Raises:
        RuntimeError: if ``res`` is empty or the values cannot be stacked.
    """
    stacked = torch.stack([
        x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key])
        for x in res
    ])

    # Tensor.mean() only accepts floating-point or complex input. Integer or
    # bool metrics (e.g. plain Python ints) would raise, so promote them to
    # the default float dtype. Float/complex inputs keep their dtype.
    if not (stacked.is_floating_point() or stacked.is_complex()):
        stacked = stacked.to(torch.get_default_dtype())

    return stacked.mean()


def accumulate(outputs, rot_num, split):
    """Average per-sample metrics over each dataset's slice of ``outputs``.

    Args:
        outputs: indexable sequence of mappings, metric name -> value exposing
            ``.item()``. The metric names are taken from ``outputs[0]``.
        rot_num: multiplier applied to the split bounds; each dataset covers
            ``outputs[start * rot_num : end * rot_num]``.
        split: mapping, dataset name -> indexable ``(start, end)`` pair.

    Returns:
        dict keyed ``"<dataset>/<metric>"`` holding the mean of that metric
        over the dataset's slice. The dict is also printed in green.
    """
    hparam_log_dict = {}

    metrics = outputs[0].keys()

    for dataset in split.keys():
        bounds = split[dataset]
        for metric in metrics:
            keyword = f"{dataset}/{metric}"
            # Start from zero, or from the stored value if the keyword was
            # already produced by an earlier dataset/metric pair.
            running = hparam_log_dict.get(keyword, 0)
            first, last = bounds[0], bounds[1]
            # Explicit left-to-right accumulation keeps float rounding stable.
            for idx in range(first * rot_num, last * rot_num):
                running += outputs[idx][metric].item()
            hparam_log_dict[keyword] = running / ((last - first) * rot_num)

    print(colored(hparam_log_dict, "green"))

    return hparam_log_dict