Spaces:
Runtime error
Runtime error
import torch | |
def squeeze(x, nonpadding=None, n_sqz=2):
    """Fold groups of ``n_sqz`` consecutive time steps into the channel axis.

    ``x`` of shape ``(b, c, T)`` is first truncated to
    ``t = (T // n_sqz) * n_sqz`` frames, then reshaped to
    ``(b, c * n_sqz, t // n_sqz)``.  Output channel ``k * c + ch`` holds
    input channel ``ch`` at time steps ``k, k + n_sqz, k + 2 * n_sqz, ...``.

    Args:
        x: Tensor of shape ``(b, c, T)``.
        nonpadding: Optional mask over time.  Presumably ``(b, 1, T)`` with
            the same time length as ``x`` (TODO confirm against callers).
            It is downsampled by keeping the last frame of every group.
        n_sqz: Number of consecutive time steps folded together.

    Returns:
        ``(x_sqz * mask, mask)``.  ``x_sqz`` has shape
        ``(b, c * n_sqz, t // n_sqz)``.  ``mask`` is the downsampled
        ``nonpadding``.  When no mask is given, ``mask`` is a
        ``(b, 1, t // n_sqz)`` tensor of ones with ``x``'s device and dtype.
    """
    b, c, t = x.size()
    # Drop trailing frames that do not fill a complete group.
    t = (t // n_sqz) * n_sqz
    x = x[:, :, :t]
    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
    # (b, c, t', n) -> (b, n, c, t') -> (b, n*c, t'): the offset inside a group
    # becomes the major channel index.
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
    if nonpadding is not None:
        # Keep positions n_sqz-1, 2*n_sqz-1, ... (the last frame of each group).
        nonpadding = nonpadding[:, :, n_sqz - 1::n_sqz]
    else:
        # Allocate directly on the target device/dtype instead of creating on
        # the CPU first and copying.
        nonpadding = torch.ones(b, 1, t // n_sqz, device=x.device, dtype=x.dtype)
    return x_sqz * nonpadding, nonpadding
def unsqueeze(x, nonpadding=None, n_sqz=2):
    """Unfold channel groups back into the time axis (inverse of ``squeeze``).

    ``x`` of shape ``(b, c, t)`` becomes ``(b, c // n_sqz, t * n_sqz)``.
    Output time step ``i * n_sqz + k`` of channel ``ch`` is taken from input
    channel ``k * (c // n_sqz) + ch`` at time step ``i``.  ``c`` must be
    divisible by ``n_sqz``; otherwise the ``view`` call raises.

    Args:
        x: Tensor of shape ``(b, c, t)``.
        nonpadding: Optional mask over time.  Presumably ``(b, 1, t)`` (TODO
            confirm against callers).  Each step is repeated ``n_sqz`` times
            to match the upsampled time axis.
        n_sqz: Number of time steps each input step expands into.

    Returns:
        ``(x_unsqz * mask, mask)``.  ``x_unsqz`` has shape
        ``(b, c // n_sqz, t * n_sqz)``.  ``mask`` is the upsampled
        ``nonpadding``.  When no mask is given, ``mask`` is a
        ``(b, 1, t * n_sqz)`` tensor of ones with ``x``'s device and dtype.
    """
    b, c, t = x.size()
    # (b, c, t) -> (b, n, c/n, t) -> (b, c/n, t, n) -> (b, c/n, t*n): the
    # major channel index becomes the offset inside each time group.
    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
    if nonpadding is not None:
        # Repeat each mask step n_sqz times along time.
        nonpadding = nonpadding.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
    else:
        # Allocate directly on the target device/dtype instead of creating on
        # the CPU first and copying.
        nonpadding = torch.ones(b, 1, t * n_sqz, device=x.device, dtype=x.dtype)
    return x_unsqz * nonpadding, nonpadding