Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import einsum | |
from torch.utils.checkpoint import checkpoint | |
from models.arch_util import AttentionBlock | |
from models.xtransformers import ContinuousTransformerWrapper, Encoder | |
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as existing)."""
    if val is None:
        return False
    return True
def masked_mean(t, mask, dim=1):
    """Average ``t`` over ``dim``, counting only positions where ``mask`` is True.

    Args:
        t: Tensor of values to average.
        mask: Boolean tensor, expected to have the same shape as ``t``.
            False entries are excluded from both the sum and the count.
        dim: Dimension to reduce over. Defaults to 1, the original hard-coded axis.

    Returns:
        The sum of the kept entries of ``t`` along ``dim`` divided by the number
        of kept entries. A slice in which every entry is masked out yields 0
        rather than NaN.
    """
    t = t.masked_fill(~mask, 0.)
    # Clamp the kept-count so a fully-masked slice gives 0/1 = 0 instead of 0/0 = NaN.
    # Slices with at least one kept entry are unaffected.
    count = mask.sum(dim=dim).clamp(min=1)
    return t.sum(dim=dim) / count
class CollapsingTransformer(nn.Module):
    """Transformer encoder that collapses a sequence into one pooled vector per item.

    The input is encoded by a ``ContinuousTransformerWrapper`` around an
    ``Encoder`` configured with ``rotary_pos_emb``, ``use_rmsnorm`` and
    ``ff_glu`` enabled. The result is passed through a pointwise-conv /
    ``AttentionBlock`` / pointwise-conv stack that maps features to
    ``output_dims``, then mean-pooled with ``masked_mean``. In training mode
    each element is kept only if a uniform random draw exceeds
    ``mask_percentage``. In eval mode nothing is masked.

    Args:
        model_dim: Feature width of the transformer.
        output_dims: Feature width of the pooled output.
        heads: Attention heads for the transformer and for the AttentionBlock.
        dropout: Dropout used for both attention and feed-forward layers.
        depth: Number of transformer layers.
        mask_percentage: Drop threshold for the training-time random mask.
        **encoder_kwargs: Extra keyword arguments forwarded to ``Encoder``.
    """

    def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
        super().__init__()
        encoder = Encoder(
            dim=model_dim,
            depth=depth,
            heads=heads,
            ff_dropout=dropout,
            ff_mult=1,
            attn_dropout=dropout,
            use_rmsnorm=True,
            ff_glu=True,
            rotary_pos_emb=True,
            **encoder_kwargs,
        )
        self.transformer = ContinuousTransformerWrapper(
            max_seq_len=-1,
            use_pos_emb=False,
            attn_layers=encoder,
        )
        combiner_layers = [
            nn.Conv1d(model_dim, output_dims, 1),
            AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
            nn.Conv1d(output_dims, output_dims, 1),
        ]
        self.pre_combiner = nn.Sequential(*combiner_layers)
        self.mask_percentage = mask_percentage

    def forward(self, x, **transformer_kwargs):
        """Encode ``x`` and return a pooled vector per batch item.

        Assumes the transformer output is ``(batch, seq, model_dim)`` (the
        ``permute(0, 2, 1)`` calls require a 3-D tensor); TODO confirm against
        ``ContinuousTransformerWrapper``. Extra keyword arguments are forwarded
        to the transformer call.
        """
        encoded = self.transformer(x, **transformer_kwargs)
        # The conv-based pre_combiner wants features on dim 1; swap there and back.
        # Gradient checkpointing recomputes pre_combiner in backward to save memory.
        channels_first = encoded.permute(0, 2, 1)
        projected = checkpoint(self.pre_combiner, channels_first).permute(0, 2, 1)
        if not self.training:
            keep = torch.ones_like(projected.float()).bool()
        else:
            keep = torch.rand_like(projected.float()) > self.mask_percentage
        return masked_mean(projected, keep)
class ConvFormatEmbedding(nn.Module):
    """``nn.Embedding`` whose output is laid out channels-first, as Conv1d layers expect.

    All constructor arguments are forwarded unchanged to ``nn.Embedding``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.emb = nn.Embedding(*args, **kwargs)

    def forward(self, x):
        """Embed 2-D index tensor ``x`` ``(batch, seq)`` and return ``(batch, embedding_dim, seq)``."""
        return self.emb(x).permute(0, 2, 1)
class CVVP(nn.Module):
    """Contrastive model scoring how well a conditioning mel clip matches a speech clip.

    Each input is encoded to a single vector by a ``CollapsingTransformer``,
    projected by a bias-free linear layer into a shared latent space of width
    ``latent_multiplier * model_dim``, and L2-normalized. The pair is compared
    with a dot product scaled by ``exp(self.temperature)``.

    Args:
        model_dim: Feature width of both transformer encoders.
        transformer_heads: Attention heads used by both encoders.
        dropout: Dropout rate for both encoders.
        conditioning_enc_depth: Number of layers in the conditioning encoder.
        cond_mask_percentage: Training-time drop threshold for the conditioning
            encoder's masked mean pooling.
        mel_channels: Number of channels (dim 1) in the mel inputs.
        mel_codes: If given, ``mel_input`` is embedded as integer codes from a
            vocabulary of this size; otherwise it is passed through a Conv1d.
        speech_enc_depth: Number of layers in the speech encoder.
        speech_mask_percentage: Training-time drop threshold for the speech
            encoder's masked mean pooling.
        latent_multiplier: Latent width as a multiple of ``model_dim``.
    """

    def __init__(
            self,
            model_dim=512,
            transformer_heads=8,
            dropout=.1,
            conditioning_enc_depth=8,
            cond_mask_percentage=0,
            mel_channels=80,
            mel_codes=None,
            speech_enc_depth=8,
            speech_mask_percentage=0,
            latent_multiplier=1,
    ):
        super().__init__()
        latent_dim = latent_multiplier * model_dim
        # Learned log-scale for the similarity logits; exponentiated in forward()
        # so the effective scale is always positive.
        self.temperature = nn.Parameter(torch.tensor(1.))

        # Two stride-2 convolutions shorten the conditioning clip roughly 4x in time.
        self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
                                      nn.Conv1d(model_dim // 2, model_dim, kernel_size=3, stride=2, padding=1))
        # Output latent_dim (like the speech branch) so the width matches
        # to_conditioning_latent's input. Previously this was model_dim, which
        # only worked when latent_multiplier == 1.
        self.conditioning_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout,
                                                              conditioning_enc_depth, cond_mask_percentage)
        self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)

        if mel_codes is None:
            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
        else:
            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
        self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout,
                                                        speech_enc_depth, speech_mask_percentage)
        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)

    def get_grad_norm_parameter_groups(self):
        """Return the two encoders' parameters keyed by name.

        Presumably consumed by a trainer for per-group gradient-norm reporting;
        verify against the caller.
        """
        return {
            'conditioning': list(self.conditioning_transformer.parameters()),
            'speech': list(self.speech_transformer.parameters()),
        }

    def forward(
            self,
            mel_cond,
            mel_input,
            return_loss=False
    ):
        """Score or train on (conditioning, speech) pairs.

        Args:
            mel_cond: Conditioning mel tensor, channels-first ``(batch, mel_channels, time)``
                as required by the Conv1d front-end.
            mel_input: Speech input. Channels-first mel ``(batch, mel_channels, time)``,
                or integer codes ``(batch, time)`` when ``mel_codes`` was given.
            return_loss: If False, return one scaled similarity per aligned pair.
                If True, return the symmetric contrastive cross-entropy loss over
                all pairs in the batch.

        Returns:
            A ``(batch,)`` similarity tensor, or a scalar loss. In both modes the
            two inputs must have the same batch size.
        """
        cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
        enc_cond = self.conditioning_transformer(cond_emb)
        cond_latents = self.to_conditioning_latent(enc_cond)

        speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
        enc_speech = self.speech_transformer(speech_emb)
        speech_latents = self.to_speech_latent(enc_speech)

        cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
        temp = self.temperature.exp()

        if not return_loss:
            # Cosine similarity of each aligned pair (row i with row i).
            sim = einsum('n d, n d -> n', cond_latents, speech_latents) * temp
            return sim

        # Full pairwise similarity matrix; matching pairs lie on the diagonal.
        sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp
        labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss
if __name__ == '__main__':
    # Smoke test: build a default CVVP and compute the contrastive loss on random
    # inputs shaped (batch=2, channels=80, time), with different clip lengths.
    model = CVVP()
    cond_mels = torch.randn(2, 80, 100)
    speech_mels = torch.randn(2, 80, 95)
    model(cond_mels, speech_mels, return_loss=True)