# MotionCLR/motion_loader/model_motion_loaders.py
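"""Wrap a text-to-motion generation pipeline in PyTorch datasets/loaders so the
generated motions can be scored by the T2M evaluators (see eval.evaluator_modules)."""
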
import torch
from utils.word_vectorizer import WordVectorizer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from eval.evaluator_modules import *
from torch.utils.data._utils.collate import default_collate
class GeneratedDataset(Dataset):
"""
opt.dataset_name
opt.max_motion_length
opt.unit_length
"""
def __init__(
self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
):
assert mm_num_samples < len(dataset)
self.dataset = dataset
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
generated_motion = []
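        # Minimum motion length, expressed in unit_length steps (t2m uses a longer
        # minimum than kit).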
min_mov_length = 10 if opt.dataset_name == "t2m" else 6
# Pre-process all target captions
mm_generated_motions = []
if mm_num_samples > 0:
mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
mm_idxs = np.sort(mm_idxs)
all_caption = []
all_m_lens = []
all_data = []
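        # First pass: cache every batch and queue captions / target lengths
        # (repeated mm_num_repeats times for multi-modality samples) so that
        # generation can run in a single batched call below.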
with torch.no_grad():
for i, data in tqdm(enumerate(dataloader)):
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
all_data.append(data)
tokens = tokens[0].split("_")
mm_num_now = len(mm_generated_motions)
                is_mm = bool(
                    (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                )
repeat_times = mm_num_repeats if is_mm else 1
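                # Round the target length down to a multiple of unit_length, then
                # clamp it to [min_mov_length * unit_length, max_motion_length].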
m_lens = max(
torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
* opt.unit_length,
min_mov_length * opt.unit_length,
)
m_lens = min(m_lens, opt.max_motion_length)
if isinstance(m_lens, int):
m_lens = torch.LongTensor([m_lens]).to(opt.device)
else:
m_lens = m_lens.to(opt.device)
for t in range(repeat_times):
all_m_lens.append(m_lens)
all_caption.extend(caption)
if is_mm:
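                    # Placeholder entry: only the count matters in this pass; the
                    # real multi-modality records are built in the second pass.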
mm_generated_motions.append(0)
all_m_lens = torch.stack(all_m_lens)
# Generate all sequences
with torch.no_grad():
all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
self.eval_generate_time = t_eval
cur_idx = 0
mm_generated_motions = []
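        # Second pass: the dataloader is shuffled, so re-read each sample from the
        # cached all_data and pair it with its generated motion(s).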
with torch.no_grad():
for i, data_dummy in tqdm(enumerate(dataloader)):
data = all_data[i]
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
tokens = tokens[0].split("_")
mm_num_now = len(mm_generated_motions)
                is_mm = bool(
                    (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                )
repeat_times = mm_num_repeats if is_mm else 1
mm_motions = []
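                # all_pred_motions is laid out in the same order the captions were
                # queued in the first pass, with repeats stored consecutively.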
for t in range(repeat_times):
pred_motions = all_pred_motions[cur_idx]
cur_idx += 1
if t == 0:
sub_dict = {
"motion": pred_motions.cpu().numpy(),
"length": pred_motions.shape[0], # m_lens[0].item(), #
"caption": caption[0],
"cap_len": cap_lens[0].item(),
"tokens": tokens,
}
generated_motion.append(sub_dict)
if is_mm:
mm_motions.append(
{
"motion": pred_motions.cpu().numpy(),
"length": pred_motions.shape[
0
], # m_lens[0].item(), #m_lens[0].item()
}
)
if is_mm:
mm_generated_motions.append(
{
"caption": caption[0],
"tokens": tokens,
"cap_len": cap_lens[0].item(),
"mm_motions": mm_motions,
}
)
self.generated_motion = generated_motion
self.mm_generated_motion = mm_generated_motions
self.opt = opt
self.w_vectorizer = w_vectorizer
def __len__(self):
return len(self.generated_motion)
def __getitem__(self, item):
data = self.generated_motion[item]
motion, m_length, caption, tokens = (
data["motion"],
data["length"],
data["caption"],
data["tokens"],
)
sent_len = data["cap_len"]
        # Re-normalize the motion with the T2M evaluators' own mean/std, since they
        # expect motions in their normalization convention.
        normed_motion = motion
denormed_motion = self.dataset.inv_transform(normed_motion)
renormed_motion = (
denormed_motion - self.dataset.mean_for_eval
) / self.dataset.std_for_eval # according to T2M norms
motion = renormed_motion
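        # Look up the GloVe word embedding and POS one-hot vector for every token.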
pos_one_hots = []
word_embeddings = []
for token in tokens:
word_emb, pos_oh = self.w_vectorizer[token]
pos_one_hots.append(pos_oh[None, :])
word_embeddings.append(word_emb[None, :])
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
word_embeddings = np.concatenate(word_embeddings, axis=0)
length = len(motion)
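        # Zero-pad along the time axis so every item has shape
        # (max_motion_length, n_feats).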
if length < self.opt.max_motion_length:
motion = np.concatenate(
[
motion,
np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
],
axis=0,
)
return (
word_embeddings,
pos_one_hots,
caption,
sent_len,
motion,
m_length,
"_".join(tokens),
)
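

# Sort each batch by sentence length (tuple index 3) in descending order before
# collating, e.g. so downstream text encoders can pack the padded sequences.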
def collate_fn(batch):
batch.sort(key=lambda x: x[3], reverse=True)
return default_collate(batch)
class MMGeneratedDataset(Dataset):
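    """Dataset over the repeated generations collected for multi-modality evaluation.

    Each item bundles all motions generated for a single caption.
    """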
def __init__(self, opt, motion_dataset, w_vectorizer):
self.opt = opt
self.dataset = motion_dataset.mm_generated_motion
self.w_vectorizer = w_vectorizer
def __len__(self):
return len(self.dataset)
def __getitem__(self, item):
data = self.dataset[item]
mm_motions = data["mm_motions"]
m_lens = []
motions = []
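        # Pad every repeat to max_motion_length and stack them into one array.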
for mm_motion in mm_motions:
m_lens.append(mm_motion["length"])
motion = mm_motion["motion"]
if len(motion) < self.opt.max_motion_length:
motion = np.concatenate(
[
motion,
np.zeros(
(self.opt.max_motion_length - len(motion), motion.shape[1])
),
],
axis=0,
)
motion = motion[None, :]
motions.append(motion)
m_lens = np.array(m_lens, dtype=np.int32)
motions = np.concatenate(motions, axis=0)
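        # Sort the repeats by motion length, longest first.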
sort_indx = np.argsort(m_lens)[::-1].copy()
m_lens = m_lens[sort_indx]
motions = motions[sort_indx]
return motions, m_lens
def get_motion_loader(
opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
):
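    """Build the evaluation data loaders over generated motions.

    Returns the loader over single generations, the multi-modality loader, and the
    time spent inside `pipeline.generate`.
    """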
    # Currently the configurations of the two datasets are almost the same.
if opt.dataset_name == "t2m" or opt.dataset_name == "kit":
w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
else:
raise KeyError("Dataset not recognized!!")
dataset = GeneratedDataset(
opt,
pipeline,
ground_truth_dataset,
w_vectorizer,
mm_num_samples,
mm_num_repeats,
)
mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)
motion_loader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn,
drop_last=True,
num_workers=4,
)
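    # Each MM item already contains all repeats for one caption, so batch_size=1.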
mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
return motion_loader, mm_motion_loader, dataset.eval_generate_time
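

# Usage sketch (illustrative values; the surrounding evaluation script and the
# exact `opt` / `pipeline` objects are assumptions, not part of this file):
#
#   motion_loader, mm_motion_loader, gen_time = get_motion_loader(
#       opt,                              # namespace with dataset_name, glove_dir, unit_length, ...
#       batch_size=32,
#       pipeline=pipeline,                # must expose generate(captions, m_lens) -> (motions, time)
#       ground_truth_dataset=gt_dataset,  # must expose inv_transform, mean_for_eval, std_for_eval
#       mm_num_samples=100,               # example value
#       mm_num_repeats=30,                # example value
#   )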