import torch

from Modules.EmbeddingModel.GST import GSTStyleEncoder
from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder


class StyleEmbedding(torch.nn.Module):
    """
    The style embedding should provide information about the speaker and their speaking style.
    The feedback signal for the module comes from the TTS objective, so it doesn't have a dedicated train loop.
    The train loop does, however, supply supervision in the form of a Barlow Twins objective
    (a minimal sketch of such an objective is given below the class).
    See the git history for other approaches to style embedding, such as a Swin transformer
    and a simple LSTM baseline. GST turned out to be the best.
    """

    def __init__(self, embedding_dim=16, style_tts_encoder=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_gst = not style_tts_encoder
        if style_tts_encoder:
            self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim)
        else:
            self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim)

    def forward(self,
                batch_of_feature_sequences,
                batch_of_feature_sequence_lengths):
        """
        Args:
            batch_of_feature_sequences: b is the batch axis, 128 features per timestep
                                        and l time-steps, which may include padding
                                        for most elements in the batch (b, l, 128)
            batch_of_feature_sequence_lengths: indicate for every element in the batch
                                               what the true length is, since they are
                                               all padded to the length of the longest
                                               element in the batch (b, 1)

        Returns:
            batch of n dimensional embeddings (b, n)
        """
        minimum_sequence_length = 512
        specs = list()
        for index, spec_length in enumerate(batch_of_feature_sequence_lengths):
            spec = batch_of_feature_sequences[index][:spec_length]
            # double the length at least once, then check
            spec = spec.repeat((2, 1))
            current_spec_length = len(spec)
            while current_spec_length < minimum_sequence_length:
                # make it longer
                spec = spec.repeat((2, 1))
                current_spec_length = len(spec)
            specs.append(spec[:minimum_sequence_length])

        spec_batch = torch.stack(specs, dim=0)
        return self.style_encoder(speech=spec_batch)
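

# The class docstring above refers to a Barlow Twins objective that the train loop
# supplies as extra supervision. The actual loss lives in the training code, not in
# this file; the function below is only a minimal, self-contained sketch of such an
# objective (the name and the off-diagonal weight are illustrative assumptions).
def barlow_twins_style_loss_sketch(embeddings_view_a, embeddings_view_b, off_diagonal_weight=0.005):
    """
    Sketch of a Barlow Twins objective, assuming two (b, n) batches of style embeddings
    computed from two views of the same utterances: the cross-correlation matrix of the
    batch-normalized embeddings is pushed towards the identity matrix (invariance on the
    diagonal, redundancy reduction off the diagonal).
    """
    # normalize each embedding dimension over the batch
    a = (embeddings_view_a - embeddings_view_a.mean(dim=0)) / (embeddings_view_a.std(dim=0) + 1e-9)
    b = (embeddings_view_b - embeddings_view_b.mean(dim=0)) / (embeddings_view_b.std(dim=0) + 1e-9)
    # (n, n) cross-correlation matrix between the two views
    cross_correlation = (a.T @ b) / a.size(0)
    diagonal = torch.diagonal(cross_correlation)
    invariance_term = (diagonal - 1).pow(2).sum()
    redundancy_term = (cross_correlation - torch.diag_embed(diagonal)).pow(2).sum()
    return invariance_term + off_diagonal_weight * redundancy_term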


if __name__ == '__main__':
    style_emb = StyleEmbedding(style_tts_encoder=False)
    print(f"GST parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")

    seq_length = 398
    print(style_emb(torch.randn(5, seq_length, 512),
                    torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)

    style_emb = StyleEmbedding(style_tts_encoder=True)
    print(f"StyleTTS encoder parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")

    seq_length = 398
    print(style_emb(torch.randn(5, seq_length, 512),
                    torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
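
    # Sketch of how the (b, embedding_dim) output is typically consumed downstream,
    # assuming a hypothetical text encoder output `encoder_outputs` of shape
    # (b, l_text, d) that gets conditioned on the utterance embedding:
    #
    #     style = style_emb(specs, spec_lengths)                         # (b, 16)
    #     style = style.unsqueeze(1).expand(-1, text_length, -1)         # (b, l_text, 16)
    #     conditioned = torch.cat([encoder_outputs, style], dim=2)       # (b, l_text, d + 16)
    #
    # The exact integration point depends on the surrounding TTS architecture.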