akhaliq's picture
akhaliq HF staff
add files
c80917c
raw
history blame
760 Bytes
import torch
def repeat_tensors(n, x):
    """Repeat each batch entry of ``x`` ``n`` times along dimension 0.

    A tensor of shape ``(B, *rest)`` is returned with shape ``(B * n, *rest)``.
    Every row is duplicated ``n`` times consecutively, so input row ``b`` fills
    output rows ``b * n`` through ``b * n + n - 1``.

    Lists and tuples are handled recursively and always come back as a list.
    Any other value (``None``, dicts, ...) is returned as-is.
    """
    if torch.is_tensor(x):
        with_group_axis = x.unsqueeze(1)                 # (B, 1, *rest)
        batch, _, *rest = with_group_axis.shape
        tiled = with_group_axis.expand(batch, n, *rest)  # (B, n, *rest) broadcast view, no copy yet
        return tiled.reshape(batch * n, *rest)           # (B * n, *rest)
    if type(x) in (list, tuple):
        # Exact type match: list/tuple subclasses (e.g. namedtuples) are
        # returned untouched rather than converted.
        return [repeat_tensors(n, element) for element in x]
    return x
def split_tensors(n, x):
    """Split every tensor in ``x`` into ``n`` interleaved groups along dim 0.

    A tensor of shape ``(B * n, *rest)`` is viewed as ``(B, n, *rest)``, which
    is the row layout produced by ``repeat_tensors``. It is then unbound along
    the group axis, giving a tuple of ``n`` tensors of shape ``(B, *rest)``.
    The ``i``-th tensor holds input rows ``i, i + n, i + 2n, ...``.

    Lists and tuples are processed recursively; the result is a list with one
    entry per element. ``None`` becomes a list of ``n`` ``None`` values, so it
    can be indexed per group just like a split tensor. Any other value is
    returned unchanged.

    Raises:
        ValueError: if a tensor's leading dimension is not divisible by ``n``.
    """
    if torch.is_tensor(x):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would leave only an opaque reshape error.
        if x.shape[0] % n != 0:
            raise ValueError(
                f"split_tensors: leading dimension {x.shape[0]} is not "
                f"divisible by n={n}"
            )
        return x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    if type(x) is list or type(x) is tuple:
        # Exact type checks (not isinstance) so list/tuple subclasses such as
        # namedtuples are passed through untouched.
        return [split_tensors(n, item) for item in x]
    if x is None:
        return [None] * n
    return x