zyingt's picture
Upload 685 files
0d80816
raw
history blame
860 Bytes
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def slice_segments(x, ids_str, segment_size=200):
    """Cut one fixed-width window per batch item out of ``x``.

    Args:
        x: Tensor with at least 3 dims; windows are cut along dim 2
            (presumably the time/frame axis — confirm against callers).
        ids_str: Per-item start offsets; entry ``row`` is used for ``x[row]``.
        segment_size: Width of each window along dim 2.

    Returns:
        A tensor shaped like ``x[:, :, :segment_size]`` (same dtype/device)
        whose entry ``row`` holds
        ``x[row, :, ids_str[row]:ids_str[row] + segment_size]``.
        Callers must keep ``start + segment_size`` within ``x.size(2)``;
        a shorter slice will not match the preallocated output.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    batch_size = x.size(0)
    for row in range(batch_size):
        start = ids_str[row]
        stop = start + segment_size
        segments[row] = x[row, :, start:stop]
    return segments
def rand_ids_segments(lengths, segment_size=200):
    """Draw a random start offset for a fixed-size window in each sequence.

    Args:
        lengths: Tensor of per-item sequence lengths (assumed 1-D, shape
            ``(B,)``; its first dim is taken as the batch size).
        segment_size: Width of the window that will be cut at each offset.

    Returns:
        LongTensor of shape ``(B,)`` on ``lengths.device``. Entry ``i`` is the
        floor of a uniform sample in ``[0, lengths[i] - segment_size)``, and is
        0 whenever ``lengths[i] <= segment_size``.
    """
    b = lengths.shape[0]
    # Clamp so sequences shorter than the window start at 0. Without this the
    # offset goes negative, and the later slice wraps to the end of the padded
    # tensor or comes out empty.
    ids_str_max = (lengths - segment_size).clamp(min=0)
    # Sample on the default CPU generator and then move the result, as before,
    # so seeded runs keep producing the same offsets.
    ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(
        dtype=torch.long
    )
    return ids_str
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the next multiple of ``2 ** num_downsamplings_in_unet``.

    The name suggests this keeps a sequence length divisible through each
    2x downsampling stage of a U-Net (confirm against callers).

    Args:
        length: Number to round up (int, float, or a scalar tensor).
        num_downsamplings_in_unet: Exponent of the power-of-two multiple.

    Returns:
        ``length`` unchanged if it is already a multiple, otherwise the
        smallest multiple greater than ``length``.
    """
    factor = 2**num_downsamplings_in_unet
    # Python/torch ``%`` takes the sign of the divisor, so ``remainder`` lies
    # in [0, factor) and ``factor - remainder`` is the distance to the next
    # multiple. Using the closed form (rather than stepping by 1) terminates
    # for non-integral input and never mutates ``length`` in place.
    # ``length += 1`` on a torch tensor would modify the caller's tensor.
    remainder = length % factor
    if remainder == 0:
        return length
    return length + (factor - remainder)