tortoise-tts-v2 / tortoise /models /random_latent_generator.py
jbetker's picture
Support totally random voices (and make fixes to previous changes)
39ec1b0
raw
history blame contribute delete
No virus
1.65 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (
F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
)
* scale
)
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
class EqualLinear(nn.Module):
def __init__(
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
return out
class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
nn.Linear(channels, channels))
self.channels = channels
def forward(self, ref):
r = torch.randn(ref.shape[0], self.channels, device=ref.device)
y = self.layers(r)
return y
if __name__ == '__main__':
model = RandomLatentConverter(512)
model(torch.randn(5,512))