# motionReFit / src/inference/inference.py
# Uploaded by Yzy00518: "Upload src/inference/inference.py with huggingface_hub"
# Commit c5e5450
import clip
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
import os
import sys
utils_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'utils'))
sys.path.append(utils_dir)
from utils.transforms import rigid_transform_3D, transform_points_numpy
from utils.constants import rest_pelvis
def test_model(models, diffuser, normalizer, configs, text_embedder, hint_text, prog_ind, joint_orig=None, All_one_model=True, **kwargs):
# set up
if All_one_model:
model= models['model']
try:
disc_model = models['disc_model']
except:
print("disc_model is not provided!", flush=True)
disc_model = None
else:
assert len(kwargs['model_type']) == len(hint_text), "model_type should have the same length as hint_text"
device = joint_orig.device
normalize, denormalize = normalizer
text_embedder = text_embedder
batch_size = configs['batch_size']
seq_len = configs['seq_len']
channels = configs['channels']
fixed_frame = configs['fixed_frame']
use_cfg = configs['use_cfg']
cfg_alpha = configs['cfg_alpha']
# for classifier guidance
cg_alpha = configs['cg_alpha']
cg_diffusion_steps = configs['cg_diffusion_steps']
# select the prog_ind and hint embedding
def get_prog_hint(i, prog_ind, hint_emb, model_type=None):
get_hint_idx = i
remains = 0
task_i = None
for j in range(len(prog_ind)+1):
if(get_hint_idx>=0):
get_hint_idx -= len(prog_ind[j])
else:
remains = get_hint_idx + len(prog_ind[j-1])
get_hint_idx = j-1
break
prog_ind_i = torch.tensor(prog_ind[get_hint_idx][remains]).unsqueeze(0).to(device)
if model_type is not None:
task_i = model_type[get_hint_idx][remains]
else:
task_i = None
hint_emb_i = hint_emb[get_hint_idx].unsqueeze(0)
return prog_ind_i, hint_emb_i, task_i
epochs_num = 0
begining_frame = joint_orig[0,:fixed_frame,...].reshape(-1, fixed_frame, channels)
samples_total = []
orig_samples_total = []
if hint_text:
hint_token = clip.tokenize(hint_text).to(device)
hint_emb = text_embedder.encode_text(hint_token).to(device=device, dtype=torch.float32)
for i in range(len(prog_ind)):
epochs_num += len(prog_ind[i])
################################################################################
# autogregresive diffusion
trans_mats = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
trans_mats_orig = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
for i in range(epochs_num):
if All_one_model:
prog_ind_i, hint_emb_i, _ = get_prog_hint(i, prog_ind, hint_emb)
else:
prog_ind_i, hint_emb_i, task_model = get_prog_hint(i, prog_ind, hint_emb, kwargs['model_type'])
joint_orig_i = joint_orig[i].reshape(-1, seq_len, channels)
if not All_one_model:
model = models[task_model]
disc_model = models[task_model+'_disc']
samples = diffuser.sample(model,
batch_size=batch_size,
seq_len=seq_len,
channels=channels,
fixed_points=begining_frame,
text=hint_emb_i,
prog_ind=prog_ind_i,
joints_orig=joint_orig_i,
use_cfg=use_cfg,
cfg_alpha=cfg_alpha,
disc_model=disc_model,
cg_alpha = cg_alpha,
cg_diffusion_steps = cg_diffusion_steps,
)
samples = samples[-1] # only consider the last timestep
samples = denormalize(samples)
samples = samples.detach().cpu().numpy()
# for original motion
orig_samples = denormalize(joint_orig_i).detach().cpu().numpy()
if i==0:
samples_total.append(samples)
orig_samples_total.append(orig_samples)
else:
samples = samples[:, fixed_frame:, :]
samples = transform_points_numpy(samples, trans_mats)
samples_total.append(samples)
orig_samples = orig_samples[:, fixed_frame:, :]
orig_samples = transform_points_numpy(orig_samples, trans_mats_orig)
orig_samples_total.append(orig_samples)
begining_frame = samples[:, -fixed_frame:, :]
pelvis_new = begining_frame[:, -fixed_frame, :9].reshape(batch_size, 3, 3)
trans_mats = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
for ip, pn in enumerate(pelvis_new):
_, ret_R, ret_t = rigid_transform_3D(np.matrix(pn), rest_pelvis, False)
ret_t[1] = 0.0
rot_euler = R.from_matrix(ret_R).as_euler('zxy')
shift_euler = np.array([0, 0, rot_euler[2]])
shift_rot_matrix2 = R.from_euler('zxy', shift_euler).as_matrix()
trans_mats[ip, :3, :3] = shift_rot_matrix2
trans_mats[ip, :3, 3] = ret_t.reshape(-1)
begining_frame = normalize(torch.tensor(transform_points_numpy(begining_frame, np.linalg.inv(trans_mats)), device=device, dtype=torch.float32))
begining_frame_orig = orig_samples[:, -fixed_frame:, :]
pelvis_new_orig = begining_frame_orig[:, -fixed_frame, :9].reshape(batch_size, 3, 3)
trans_mats_orig = np.repeat(np.eye(4)[np.newaxis, :, :], batch_size, axis=0)
for ip, pn in enumerate(pelvis_new_orig):
_, ret_R, ret_t = rigid_transform_3D(np.matrix(pn), rest_pelvis, False)
ret_t[1] = 0.0
rot_euler = R.from_matrix(ret_R).as_euler('zxy')
shift_euler = np.array([0, 0, rot_euler[2]])
shift_rot_matrix2 = R.from_euler('zxy', shift_euler).as_matrix()
trans_mats_orig[ip, :3, :3] = shift_rot_matrix2
trans_mats_orig[ip, :3, 3] = ret_t.reshape(-1)
begining_frame_orig = normalize(torch.tensor(transform_points_numpy(begining_frame_orig, np.linalg.inv(trans_mats_orig)), device=device, dtype=torch.float32))
samples_total = np.concatenate(samples_total, axis=1)
orig_samples_total = np.concatenate(orig_samples_total, axis=1)
return samples_total, orig_samples_total