# Source: MotionGPT / mGPT/data/humanml/dataset_t2m_eval.py
# Hugging Face file-viewer metadata: uploaded by bill-jiang, commit "Init" (4409449), 2.98 kB.
import random
import numpy as np
from .dataset_t2m import Text2MotionDataset
class Text2MotionDatasetEval(Text2MotionDataset):
    """Text-to-motion dataset variant that also returns word-level features.

    On top of whatever ``Text2MotionDataset`` sets up, each item carries the
    word embeddings and part-of-speech one-hots produced by ``w_vectorizer``
    for one randomly chosen caption, plus a fixed-size list of captions for
    the sample.

    NOTE(review): ``self.pointer``, ``self.name_list``, ``self.data_dict``,
    ``self.mean``, ``self.std`` and ``self.unit_length`` are read here but
    never assigned in this class; presumably the parent initializer provides
    them -- confirm against ``Text2MotionDataset``.
    """

    def __init__(
        self,
        data_root,
        split,
        mean,
        std,
        w_vectorizer,
        max_motion_length=196,
        min_motion_length=40,
        unit_length=4,
        fps=20,
        tmpFile=True,
        tiny=False,
        debug=False,
        **kwargs,
    ):
        """Initialise the parent dataset and keep the word vectorizer.

        Every argument except ``w_vectorizer`` is forwarded unchanged (and
        positionally) to ``Text2MotionDataset.__init__``.

        Args:
            w_vectorizer: Object indexed by a token string; each lookup is
                expected to yield a ``(word_embedding, pos_one_hot)`` pair
                (see ``__getitem__``).
        """
        super().__init__(
            data_root, split, mean, std,
            max_motion_length, min_motion_length, unit_length, fps,
            tmpFile, tiny, debug,
            **kwargs,
        )
        self.w_vectorizer = w_vectorizer

    def __getitem__(self, item):
        """Build one evaluation sample.

        Args:
            item: Offset added to ``self.pointer`` to index ``self.name_list``.

        Returns:
            An 8-tuple ``(caption, motion, m_length, word_embeddings,
            pos_one_hots, sent_len, joined_tokens, all_captions)`` where

            * ``caption`` is the raw caption of a randomly chosen text entry,
            * ``motion`` is a randomly positioned window of ``m_length``
              frames, normalised with ``self.mean`` / ``self.std``,
            * ``m_length`` is the sample length rounded down to a multiple of
              ``self.unit_length`` (sometimes one unit shorter),
            * ``word_embeddings`` / ``pos_one_hots`` stack one row per token
              (22 tokens: sos + up to 20 words + eos + ``unk`` padding),
            * ``sent_len`` counts sos/eos and the real words, not the padding,
            * ``joined_tokens`` is the padded token list joined with ``"_"``,
            * ``all_captions`` holds exactly three captions rebuilt from the
              token words whenever the sample has at least one text entry.
        """
        sample = self.data_dict[self.name_list[self.pointer + item]]
        motion = sample["motion"]
        m_length = sample["length"]
        text_list = sample["text"]

        # Rebuild each caption from the word part of its "word/TAG" tokens.
        all_captions = []
        for text_dic in text_list:
            words = [token.split('/')[0] for token in text_dic['tokens']]
            all_captions.append(' '.join(words))

        # Force the caption list to three entries: truncate when longer,
        # repeat existing captions when shorter.
        n_captions = len(all_captions)
        if n_captions > 3:
            all_captions = all_captions[:3]
        elif n_captions == 2:
            all_captions = all_captions + [all_captions[0]]
        elif n_captions == 1:
            all_captions = all_captions * 3

        # One text entry is drawn at random to provide the caption and tokens.
        chosen = random.choice(text_list)
        caption = chosen["caption"]
        tokens = chosen["tokens"]

        # Keep at most max_text_len words, wrap them in sos/eos, then pad
        # with "unk" up to max_text_len + 2. The slice is a no-op for short
        # sentences and the padding count is zero for cropped ones.
        max_text_len = 20
        tokens = ["sos/OTHER"] + tokens[:max_text_len] + ["eos/OTHER"]
        sent_len = len(tokens)
        tokens = tokens + ["unk/OTHER"] * (max_text_len + 2 - sent_len)

        # Each lookup yields (word_embedding, pos_one_hot); stack them
        # row-wise in token order.
        looked_up = [self.w_vectorizer[token] for token in tokens]
        pos_one_hots = np.concatenate(
            [pos_oh[None, :] for _, pos_oh in looked_up], axis=0)
        word_embeddings = np.concatenate(
            [word_emb[None, :] for word_emb, _ in looked_up], axis=0)

        # Round the length down to whole units; for small unit lengths,
        # roughly one draw in three drops one extra unit ("double").
        if self.unit_length < 10:
            coin2 = np.random.choice(["single", "single", "double"])
        else:
            coin2 = "single"
        n_units = m_length // self.unit_length
        if coin2 == "double":
            n_units = n_units - 1
        m_length = n_units * self.unit_length

        # random.randint is inclusive at both ends, so the window may end
        # exactly on the last frame.
        start = random.randint(0, len(motion) - m_length)
        motion = motion[start:start + m_length]

        # Z-normalisation with the dataset statistics.
        motion = (motion - self.mean) / self.std

        joined_tokens = "_".join(tokens)
        return (caption, motion, m_length, word_embeddings, pos_one_hots,
                sent_len, joined_tokens, all_captions)