import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch.utils.checkpoint import checkpoint

from models.arch_util import AttentionBlock
from models.xtransformers import ContinuousTransformerWrapper, Encoder


def exists(val):
    """Return True when *val* is not None."""
    return val is not None


def masked_mean(t, mask):
    """Average *t* along dim 1, counting only positions where *mask* is True.

    Args:
        t: Tensor to average.
        mask: Boolean tensor usable with ``t.masked_fill``; True marks the
            elements that contribute to the mean.

    Returns:
        ``t`` with masked-out elements zeroed, summed over dim 1 and divided
        by the number of kept elements along dim 1. Where nothing is kept the
        result is 0 rather than NaN.
    """
    t = t.masked_fill(~mask, 0.)
    # Clamp so a fully masked slice divides 0 by 1 instead of 0 by 0.
    kept = mask.sum(dim=1).clamp(min=1)
    return t.sum(dim=1) / kept


class CollapsingTransformer(nn.Module):
    """Transformer encoder that collapses a sequence into a single vector.

    The input sequence is run through the transformer, projected from
    ``model_dim`` to ``output_dims`` channels by ``pre_combiner`` and then
    averaged over the sequence dimension. During training, elements are
    randomly excluded from that average according to ``mask_percentage``.
    """

    def __init__(self, model_dim, output_dims, heads, dropout, depth,
                 mask_percentage=0, **encoder_kwargs):
        """Build the encoder and the projection applied before pooling.

        Args:
            model_dim: Feature width of the transformer.
            output_dims: Feature width of the pooled output.
            heads: Attention head count for the encoder and AttentionBlock.
            dropout: Dropout rate for the feed-forward and attention layers.
            depth: Number of encoder layers.
            mask_percentage: Threshold compared against uniform random
                numbers in training; elements whose draw is not above it are
                left out of the mean.
            **encoder_kwargs: Extra keyword arguments forwarded to ``Encoder``.
        """
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(
            max_seq_len=-1,
            use_pos_emb=False,
            attn_layers=Encoder(
                dim=model_dim,
                depth=depth,
                heads=heads,
                ff_dropout=dropout,
                ff_mult=1,
                attn_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                rotary_pos_emb=True,
                **encoder_kwargs,
            ))
        self.pre_combiner = nn.Sequential(
            nn.Conv1d(model_dim, output_dims, 1),
            AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
            nn.Conv1d(output_dims, output_dims, 1),
        )
        self.mask_percentage = mask_percentage

    def forward(self, x, **transformer_kwargs):
        """Encode *x* and return its masked mean over the sequence dimension.

        Args:
            x: Input passed to the transformer; the code below treats the
                transformer output as (batch, sequence, model_dim).
            **transformer_kwargs: Forwarded to the transformer call.

        Returns:
            Tensor reduced over dim 1 of the projected sequence.
        """
        h = self.transformer(x, **transformer_kwargs)
        # Conv1d wants channels first; switch back to channels last afterwards.
        h = h.permute(0, 2, 1)
        h = checkpoint(self.pre_combiner, h).permute(0, 2, 1)
        if self.training:
            # Element-wise random mask: keep where the draw exceeds the threshold.
            mask = torch.rand_like(h.float()) > self.mask_percentage
        else:
            mask = torch.ones_like(h, dtype=torch.bool)
        return masked_mean(h, mask)


class ConvFormatEmbedding(nn.Module):
    """``nn.Embedding`` whose output is laid out channels-first."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``nn.Embedding``."""
        super().__init__()
        self.emb = nn.Embedding(*args, **kwargs)

    def forward(self, x):
        """Embed *x* and swap the last two dimensions of the result."""
        y = self.emb(x)
        return y.permute(0, 2, 1)


class CVVP(nn.Module):
    """Contrastive model scoring a conditioning mel against a speech input.

    Both inputs are encoded to unit-length latents by separate
    ``CollapsingTransformer`` branches; their scaled dot products are either
    returned directly or turned into a symmetric cross-entropy loss.
    """

    def __init__(
            self,
            model_dim=512,
            transformer_heads=8,
            dropout=.1,
            conditioning_enc_depth=8,
            cond_mask_percentage=0,
            mel_channels=80,
            mel_codes=None,
            speech_enc_depth=8,
            speech_mask_percentage=0,
            latent_multiplier=1,
    ):
        """Build both encoder branches.

        Args:
            model_dim: Feature width of both transformers.
            transformer_heads: Attention head count for both transformers.
            dropout: Dropout rate for both transformers.
            conditioning_enc_depth: Layer count of the conditioning encoder.
            cond_mask_percentage: ``mask_percentage`` of the conditioning
                encoder.
            mel_channels: Channel count of the mel inputs.
            mel_codes: When None the speech input is embedded with a Conv1d
                over ``mel_channels`` channels; otherwise it is the vocabulary
                size of an embedding table for integer codes.
            speech_enc_depth: Layer count of the speech encoder.
            speech_mask_percentage: ``mask_percentage`` of the speech encoder.
            latent_multiplier: The latent width is
                ``latent_multiplier * model_dim``.
        """
        super().__init__()
        latent_dim = latent_multiplier * model_dim
        self.temperature = nn.Parameter(torch.tensor(1.))

        # Two stride-2 convolutions shorten the conditioning sequence.
        self.cond_emb = nn.Sequential(
            nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
            nn.Conv1d(model_dim // 2, model_dim, kernel_size=3, stride=2, padding=1),
        )
        # The encoder must emit latent_dim features (as the speech branch
        # does) so that its output fits to_conditioning_latent when
        # latent_multiplier != 1.
        self.conditioning_transformer = CollapsingTransformer(
            model_dim, latent_dim, transformer_heads, dropout,
            conditioning_enc_depth, cond_mask_percentage)
        self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)

        if mel_codes is None:
            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
        else:
            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
        self.speech_transformer = CollapsingTransformer(
            model_dim, latent_dim, transformer_heads, dropout,
            speech_enc_depth, speech_mask_percentage)
        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)

    def get_grad_norm_parameter_groups(self):
        """Return the parameters of each transformer branch, keyed by name."""
        return {
            'conditioning': list(self.conditioning_transformer.parameters()),
            'speech': list(self.speech_transformer.parameters()),
        }

    def forward(
            self,
            mel_cond,
            mel_input,
            return_loss=False
    ):
        """Score conditioning inputs against speech inputs.

        Args:
            mel_cond: Conditioning input for the Conv1d stack, i.e.
                (batch, mel_channels, time).
            mel_input: Speech input; (batch, mel_channels, time) when the
                model was built with ``mel_codes=None``, otherwise integer
                codes for the embedding table.
            return_loss: Selects the return value, see below.

        Returns:
            If ``return_loss`` is False, one similarity per batch element
            (dot product of the normalized latents times ``exp(temperature)``).
            If True, the mean of the cross-entropy over rows and over columns
            of the full batch similarity matrix, with matching indices as the
            targets.
        """
        cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
        enc_cond = self.conditioning_transformer(cond_emb)
        cond_latents = self.to_conditioning_latent(enc_cond)

        speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
        enc_speech = self.speech_transformer(speech_emb)
        speech_latents = self.to_speech_latent(enc_speech)

        # Unit-length latents make the dot products cosine similarities.
        cond_latents = F.normalize(cond_latents, p=2, dim=-1)
        speech_latents = F.normalize(speech_latents, p=2, dim=-1)

        temp = self.temperature.exp()

        if not return_loss:
            return einsum('n d, n d -> n', cond_latents, speech_latents) * temp

        sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp
        # Entry (i, i) is the matching pair, so the target of row i is i.
        labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss


if __name__ == '__main__':
    # Smoke test: one forward pass in loss mode on random inputs.
    clvp = CVVP()
    clvp(torch.randn(2, 80, 100),
         torch.randn(2, 80, 95),
         return_loss=True)