Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
def build_word_mask(x2word, y2word): | |
return (x2word[:, :, None] == y2word[:, None, :]).long() | |
def mel2ph_to_mel2word(mel2ph, ph2word): | |
mel2word = (ph2word - 1).gather(1, (mel2ph - 1).clamp(min=0)) + 1 | |
mel2word = mel2word * (mel2ph > 0).long() | |
return mel2word | |
def clip_mel2token_to_multiple(mel2token, frames_multiple): | |
max_frames = mel2token.shape[1] // frames_multiple * frames_multiple | |
mel2token = mel2token[:, :max_frames] | |
return mel2token | |
def expand_states(h, mel2token): | |
h = F.pad(h, [0, 0, 1, 0]) | |
mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]]) | |
h = torch.gather(h, 1, mel2token_) # [B, T, H] | |
return h | |