File size: 6,591 Bytes
c5e5450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import clip
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R

import os
import sys
utils_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'utils'))
sys.path.append(utils_dir)
from utils.transforms import rigid_transform_3D, transform_points_numpy
from utils.constants import rest_pelvis


def _flatten_schedule(prog_ind):
    """Return ``(group_idx, step_idx)`` for every entry of the nested ``prog_ind``.

    The pairs are listed in iteration order: all entries of ``prog_ind[0]``,
    then all entries of ``prog_ind[1]``, and so on. Empty groups contribute
    nothing.
    """
    return [
        (group_idx, step_idx)
        for group_idx, group in enumerate(prog_ind)
        for step_idx in range(len(group))
    ]


def _heading_transforms(window, fixed_frame, batch_size):
    """Build one 4x4 homogeneous transform per batch element from ``window``.

    For each batch element the first 9 channels of frame ``-fixed_frame`` are
    reshaped to 3x3 (presumably three 3D pelvis points -- TODO confirm against
    the data layout) and registered against ``rest_pelvis`` with
    ``rigid_transform_3D``. Only the third ``'zxy'`` Euler angle (the rotation
    about the y axis) is kept, and the y component of the translation is
    zeroed.

    Args:
        window: numpy array indexed as ``[batch, frame, channel]``.
        fixed_frame: number of overlapping frames; used as a negative frame index.
        batch_size: number of batch elements in ``window``.

    Returns:
        numpy array of shape ``(batch_size, 4, 4)``.
    """
    anchors = window[:, -fixed_frame, :9].reshape(batch_size, 3, 3)
    transforms = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
    for row, anchor in enumerate(anchors):
        # np.matrix is kept on purpose: rigid_transform_3D is defined elsewhere
        # and may rely on matrix semantics -- verify before switching to ndarray.
        _, ret_R, ret_t = rigid_transform_3D(np.matrix(anchor), rest_pelvis, False)
        ret_t[1] = 0.0  # drop the y component of the translation
        rot_euler = R.from_matrix(ret_R).as_euler('zxy')
        # Keep only the rotation about y; the z and x angles are discarded.
        y_only = np.array([0, 0, rot_euler[2]])
        transforms[row, :3, :3] = R.from_euler('zxy', y_only).as_matrix()
        transforms[row, :3, 3] = ret_t.reshape(-1)
    return transforms


def test_model(models, diffuser, normalizer, configs, text_embedder, hint_text, prog_ind, joint_orig=None, All_one_model=True, **kwargs):
    """Autoregressively sample motion segments and stitch them into one sequence.

    One segment is sampled per entry of the nested ``prog_ind``. Each segment
    after the first is conditioned on the last ``fixed_frame`` frames of the
    previous one, has its first ``fixed_frame`` frames dropped, and is mapped
    through the transform derived from the previous segment before being
    appended. The reference motion in ``joint_orig`` is denormalised and
    stitched the same way.

    Args:
        models: mapping of model objects. With ``All_one_model`` it must hold
            ``'model'`` and may hold ``'disc_model'``; otherwise it is indexed
            by each task name and by ``task name + '_disc'``.
        diffuser: object exposing ``sample(model, ...)``; the last element of
            its return value is used as the final sample.
        normalizer: ``(normalize, denormalize)`` pair of callables.
        configs: mapping providing ``batch_size``, ``seq_len``, ``channels``,
            ``fixed_frame``, ``use_cfg``, ``cfg_alpha``, ``cg_alpha`` and
            ``cg_diffusion_steps``.
        text_embedder: object exposing ``encode_text`` (fed ``clip.tokenize``
            output).
        hint_text: text prompt(s); ``hint_emb[k]`` conditions every entry of
            ``prog_ind[k]``.
        prog_ind: nested sequence; each inner entry is one sampling step.
        joint_orig: tensor of reference motion, indexed by step; required
            despite the ``None`` default.
        All_one_model: use the single ``models['model']`` for every step when
            true; otherwise pick a model per step from ``kwargs['model_type']``.
        **kwargs: must contain ``model_type`` (nested like ``prog_ind``) when
            ``All_one_model`` is false.

    Returns:
        ``(samples_total, orig_samples_total)``: numpy arrays holding the
        generated and the reference segments concatenated along axis 1.

    Raises:
        ValueError: if ``joint_orig`` is ``None``, or if ``hint_text`` or the
            schedule derived from ``prog_ind`` is empty.
        AssertionError: if ``model_type`` and ``hint_text`` differ in length.
    """
    # Model selection.
    if All_one_model:
        model = models['model']
        try:
            disc_model = models['disc_model']
        except KeyError:
            # Only a missing key means "no discriminator"; anything else
            # should surface instead of being silently swallowed.
            print("disc_model is not provided!", flush=True)
            disc_model = None
        model_type = None
    else:
        model_type = kwargs['model_type']
        assert len(model_type) == len(hint_text), "model_type should have the same length as hint_text"

    if joint_orig is None:
        raise ValueError("joint_orig is required: it supplies the device, the seed frames and the reference motion")

    device = joint_orig.device
    normalize, denormalize = normalizer
    batch_size = configs['batch_size']
    seq_len = configs['seq_len']
    channels = configs['channels']
    fixed_frame = configs['fixed_frame']
    use_cfg = configs['use_cfg']
    cfg_alpha = configs['cfg_alpha']

    # Classifier-guidance settings.
    cg_alpha = configs['cg_alpha']
    cg_diffusion_steps = configs['cg_diffusion_steps']

    # Without prompts there is nothing to sample. Fail early with the same
    # exception type np.concatenate would raise on an empty list.
    if not hint_text:
        raise ValueError("hint_text is empty: nothing to generate")

    hint_token = clip.tokenize(hint_text).to(device)
    hint_emb = text_embedder.encode_text(hint_token).to(device=device, dtype=torch.float32)

    schedule = _flatten_schedule(prog_ind)
    if not schedule:
        raise ValueError("prog_ind contains no steps: nothing to generate")

    # Seed window for the first segment, taken from the reference motion.
    seed_window = joint_orig[0, :fixed_frame, ...].reshape(-1, fixed_frame, channels)
    samples_total = []
    orig_samples_total = []
    trans_mats = None
    trans_mats_orig = None

    ################################################################################
    # Autoregressive diffusion.
    for step, (group_idx, step_idx) in enumerate(schedule):
        prog_ind_i = torch.tensor(prog_ind[group_idx][step_idx]).unsqueeze(0).to(device)
        hint_emb_i = hint_emb[group_idx].unsqueeze(0)
        joint_orig_i = joint_orig[step].reshape(-1, seq_len, channels)

        if not All_one_model:
            task_model = model_type[group_idx][step_idx]
            model = models[task_model]
            disc_model = models[task_model + '_disc']

        samples = diffuser.sample(
            model,
            batch_size=batch_size,
            seq_len=seq_len,
            channels=channels,
            fixed_points=seed_window,
            text=hint_emb_i,
            prog_ind=prog_ind_i,
            joints_orig=joint_orig_i,
            use_cfg=use_cfg,
            cfg_alpha=cfg_alpha,
            disc_model=disc_model,
            cg_alpha=cg_alpha,
            cg_diffusion_steps=cg_diffusion_steps,
        )

        # Only the last diffusion timestep is kept.
        samples = denormalize(samples[-1]).detach().cpu().numpy()
        orig_samples = denormalize(joint_orig_i).detach().cpu().numpy()

        if step > 0:
            # Drop the frames that overlap the previous segment, then apply
            # the transform computed from the previous segment.
            samples = transform_points_numpy(samples[:, fixed_frame:, :], trans_mats)
            orig_samples = transform_points_numpy(orig_samples[:, fixed_frame:, :], trans_mats_orig)
        samples_total.append(samples)
        orig_samples_total.append(orig_samples)

        # The tail of this segment seeds the next one: map it through the
        # inverse of its own transform and normalise it for the model.
        tail = samples[:, -fixed_frame:, :]
        trans_mats = _heading_transforms(tail, fixed_frame, batch_size)
        seed_window = normalize(torch.tensor(
            transform_points_numpy(tail, np.linalg.inv(trans_mats)),
            device=device,
            dtype=torch.float32,
        ))

        # The reference motion only needs its transform for the next step.
        # The original also normalised a reference seed window that was never
        # read; that dead computation is dropped (assumes
        # transform_points_numpy and normalize are side-effect free -- TODO
        # confirm).
        trans_mats_orig = _heading_transforms(orig_samples[:, -fixed_frame:, :], fixed_frame, batch_size)

    samples_total = np.concatenate(samples_total, axis=1)
    orig_samples_total = np.concatenate(orig_samples_total, axis=1)

    return samples_total, orig_samples_total