M4Singer/utils/tts_utils.py
kevinwang676's picture
Duplicate from zlc99/M4Singer
26925fd
raw
history blame
897 Bytes
import torch
import torch.nn.functional as F
from collections import defaultdict
def make_positions(tensor, padding_idx):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (
torch.cumsum(mask, dim=1).type_as(mask) * mask
).long() + padding_idx
def fill_with_neg_inf2(t):
    """FP16-compatible fill with a large negative value.

    Uses -1e8 instead of a true -inf so the result stays finite when
    cast back to half precision, then restores ``t``'s original dtype.
    """
    as_float = t.float()  # same storage as t when t is already float32
    as_float.fill_(-1e8)
    return as_float.type_as(t)
def softmax(x, dim):
    """Softmax over ``dim`` computed in float32 for numerical stability."""
    result = F.softmax(x, dim=dim, dtype=torch.float32)
    return result