MotionGPT / mGPT /utils /tensors.py
bill-jiang's picture
Init
4409449
raw
history blame
No virus
2.46 kB
import torch
def lengths_to_mask(lengths):
    """Build a boolean validity mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D integer tensor of shape ``(bs,)``; entry ``i`` is the
            number of valid steps of sequence ``i``.

    Returns:
        Bool tensor of shape ``(bs, max_len)`` on ``lengths.device`` where
        ``mask[i, t]`` is True iff ``t < lengths[i]``. ``max_len`` is the
        largest length in the batch, or 0 when ``lengths`` is empty.
    """
    # Reduce on-device: Python's builtin max() would iterate the tensor one
    # 0-d element at a time (and read back to the host per comparison on
    # CUDA), and it raises ValueError on an empty tensor.
    max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(max_len, device=lengths.device)
    # (1, max_len) < (bs, 1) broadcasts to (bs, max_len).
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
def collate_tensors(batch):
    """Zero-pad a non-empty list of same-rank tensors and stack them.

    Args:
        batch: list of tensors that all have the same number of dimensions
            (sizes may differ per dimension). Must be non-empty; the first
            element supplies the dtype and device of the result.

    Returns:
        Tensor of shape ``(len(batch), *max_size)``, where ``max_size[d]`` is
        the largest size of dimension ``d`` across ``batch``. Each input is
        placed at the origin of its slot; the remainder is zero.
    """
    dims = batch[0].dim()
    max_size = [max(b.size(d) for b in batch) for d in range(dims)]
    canvas = batch[0].new_zeros((len(batch), *max_size))
    for i, b in enumerate(batch):
        # A basic-slicing index yields a view into `canvas`, so copy_ writes
        # the sample in place. copy_ (not add_ onto zeros) preserves values
        # exactly (e.g. -0.0) and casts dtype instead of promoting.
        region = (i, *(slice(0, size) for size in b.shape))
        canvas[region].copy_(b)
    return canvas
def collate(batch):
    """Collate ``(data, label, ...)`` samples into a padded batch dict.

    Each sample is indexed as ``sample[0]`` (motion data tensor) and
    ``sample[1]`` (label). The sequence length of a sample is taken as
    ``len(sample[0][0][0])``, i.e. the size of the last axis of the data.

    Returns a dict with:
        x       - [bs, njoints, nfeats, lengths] zero-padded data
                  (nfeats is the per-joint representation)
        y       - [bs] labels
        mask    - [bs, lengths] True on valid (unpadded) frames
        lengths - [bs] per-sample sequence lengths
    """
    data_list, label_list, length_list = [], [], []
    for sample in batch:
        motion = sample[0]
        data_list.append(motion)
        label_list.append(sample[1])
        # Frame count = length of the innermost (time) axis.
        length_list.append(len(motion[0][0]))

    lengths = torch.as_tensor(length_list)
    return {
        "x": collate_tensors(data_list),
        "y": torch.as_tensor(label_list),
        "mask": lengths_to_mask(lengths),
        "lengths": lengths,
    }
# slow version with padding
def collate_data3d_slow(batch):
    """Collate a list of dict samples, zero-padding every tensor field.

    Every key of ``batch[0]`` is gathered across samples. Tensor fields are
    padded to a common size with ``collate_tensors``. The ``"paths"`` field
    is passed through as a plain list, matching ``collate_data3d``:
    ``collate_tensors`` calls ``.dim()`` and would fail on strings.

    Fields noted for this data:
        theta  - [bs, lengths, 85], theta shape (85,)
                 - (np.array([1., 0., 0.]), pose(72), shape(10)), axis=0)
        kp_2d  - [bs, lengths, njoints, nfeats], nfeats (x,y,weight)
        kp_3d  - [bs, lengths, njoints, nfeats], nfeats (x,y,z)
        w_smpl - [bs, lengths] zeros
        w_3d   - [bs, lengths] zeros
    """
    collated = {}
    for key in batch[0].keys():
        values = [sample[key] for sample in batch]
        if key == "paths":
            collated[key] = values
        else:
            collated[key] = collate_tensors(values)
    return collated
def collate_data3d(batch):
    """Collate a list of dict samples by stacking along a new batch axis.

    All tensor fields must already share one shape across samples (no
    padding is done). The ``"paths"`` field is kept as a plain list.

    Fields noted for this data:
        theta  - [bs, lengths, 85], theta shape (85,)
                 - (np.array([1., 0., 0.]), pose(72), shape(10)), axis=0)
        kp_2d  - [bs, lengths, njoints, nfeats], nfeats (x,y,weight)
        kp_3d  - [bs, lengths, njoints, nfeats], nfeats (x,y,z)
        w_smpl - [bs, lengths] zeros
        w_3d   - [bs, lengths] zeros
    """
    def gather(field):
        return [sample[field] for sample in batch]

    return {
        field: gather(field) if field == "paths" else torch.stack(gather(field), dim=0)
        for field in batch[0]
    }