from typing import Any

import torch
import torch.nn as nn


class ContentEncoder(nn.Module):
    """Routes each item's conditioning content to a task-specific encoder
    and pads the per-item outputs into a batch."""

    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module | None = None,
        video_encoder: nn.Module | None = None,
        midi_encoder: nn.Module | None = None,
        phoneme_encoder: nn.Module | None = None,
        pitch_encoder: nn.Module | None = None,
        audio_encoder: nn.Module | None = None,
        speech_encoder: nn.Module | None = None,
        sketch_encoder: nn.Module | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.speech_encoder = speech_encoder
        self.sketch_encoder = sketch_encoder

    def encode_content(
        self,
        batch_content: list[Any],
        batch_task: list[str],
        device: str | torch.device,
    ):
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []
        # Placeholder for tasks that provide no length-aligned content.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)

        for content, task in zip(batch_content, batch_task):
            if task in ("audio_super_resolution", "speech_enhancement"):
                content_dict = {
                    "waveform": torch.as_tensor(content).float(),
                    "waveform_lengths": torch.as_tensor(content.shape[0]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.audio_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}

            elif task in ("text_to_audio", "text_to_music"):
                content_output_dict = self.text_encoder([content])
                la_content_output_dict = {"output": zero_la_content}

            elif task == "speech_to_audio":
                input_dict = {
                    "embed": content,
                    "embed_len": torch.tensor(
                        [content.shape[1]], dtype=torch.int, device=device
                    ),
                }
                content_output_dict = self.speech_encoder(input_dict)
                la_content_output_dict = {"output": zero_la_content}

            elif task == "direct_speech_to_audio":
                # HuBERT features: content [1, T, dim], mask [1, T].
                if content.dim() < 3:
                    content = content.unsqueeze(0)
                mask = torch.ones(
                    content.shape[:2], dtype=torch.bool, device=content.device
                )
                content_output_dict = {"output": content, "mask": mask}
                la_content_output_dict = {"output": zero_la_content}

            elif task == "sketch_to_audio":
                content_output_dict = self.sketch_encoder([content["caption"]])
                content_dict = {
                    "f0": torch.as_tensor(content["f0"]),
                    "energy": torch.as_tensor(content["energy"]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                la_content_output_dict = self.sketch_encoder.encode_sketch(
                    **content_dict
                )

            elif task == "video_to_audio":
                content_dict = {
                    "frames": torch.as_tensor(content).float(),
                    "frame_nums": torch.as_tensor(content.shape[0]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.video_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}

            elif task == "singing_voice_synthesis":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                    "midi": torch.as_tensor(content["midi"]).long(),
                    "midi_duration": torch.as_tensor(
                        content["midi_duration"]
                    ).float(),
                    "is_slur": torch.as_tensor(content["is_slur"]).long(),
                }
                if "spk" in content:
                    # Speaker conditioning: integer ids or float embeddings.
                    if self.midi_encoder.spk_config.encoding_format == "id":
                        content_dict["spk"] = torch.as_tensor(
                            content["spk"]
                        ).long()
                    elif (
                        self.midi_encoder.spk_config.encoding_format
                        == "embedding"
                    ):
                        content_dict["spk"] = torch.as_tensor(
                            content["spk"]
                        ).float()
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor(
                    [len(content["phoneme"])]
                )
                content_output_dict = self.midi_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}

            elif task == "text_to_speech":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                }
                if "spk" in content:
                    if self.phoneme_encoder.spk_config.encoding_format == "id":
                        content_dict["spk"] = torch.as_tensor(
                            content["spk"]
                        ).long()
                    elif (
                        self.phoneme_encoder.spk_config.encoding_format
                        == "embedding"
                    ):
                        content_dict["spk"] = torch.as_tensor(
                            content["spk"]
                        ).float()
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor(
                    [len(content["phoneme"])]
                )
                content_output_dict = self.phoneme_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}

            elif task == "singing_acoustic_modeling":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor(
                    [len(content["phoneme"])]
                )
                content_output_dict = self.pitch_encoder(**content_dict)
                # Pitch (f0) and voiced/unvoiced flags provide the
                # length-aligned content for this task.
                content_dict = {
                    "f0": torch.as_tensor(content["f0"]),
                    "uv": torch.as_tensor(content["uv"]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                la_content_output_dict = self.pitch_encoder.encode_pitch(
                    **content_dict
                )

            else:
                raise ValueError(f"Unsupported task: {task}")

            batch_content_output.append(content_output_dict["output"][0])
            batch_content_mask.append(content_output_dict["mask"][0])
            batch_la_content_output.append(la_content_output_dict["output"][0])

        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
        }


class BatchedContentEncoder(ContentEncoder):
    """Variant that assumes every item in the batch shares the same task,
    so the task encoder runs once on the whole pre-collated batch."""

    def encode_content(
        self,
        batch_content: list | dict,
        batch_task: list[str],
        device: str | torch.device,
    ):
        task = batch_task[0]
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)

        if task in ("audio_super_resolution", "speech_enhancement"):
            content_dict = {
                "waveform": batch_content["content"]
                .unsqueeze(1)
                .float()
                .to(device),
                "waveform_lengths": batch_content["content_lengths"]
                .long()
                .to(device),
            }
            content_output = self.audio_encoder(**content_dict)
            la_content_output = zero_la_content

        elif task == "text_to_audio":
            content_output = self.text_encoder(batch_content)
            la_content_output = zero_la_content

        elif task == "video_to_audio":
            content_dict = {
                "frames": batch_content["content"].float().to(device),
                "frame_nums": batch_content["content_lengths"].long().to(device),
            }
            content_output = self.video_encoder(**content_dict)
            la_content_output = zero_la_content

        elif task == "singing_voice_synthesis":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "midi": batch_content["midi"].long().to(device),
                "midi_duration": batch_content["midi_duration"]
                .float()
                .to(device),
                "is_slur": batch_content["is_slur"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().cpu(),
            }
            if "spk" in batch_content:
                if self.midi_encoder.spk_config.encoding_format == "id":
                    content_dict["spk"] = batch_content["spk"].long().to(device)
                elif (
                    self.midi_encoder.spk_config.encoding_format == "embedding"
                ):
                    content_dict["spk"] = batch_content["spk"].float().to(device)
            content_output = self.midi_encoder(**content_dict)
            la_content_output = zero_la_content

        elif task == "text_to_speech":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().cpu(),
            }
            if "spk" in batch_content:
                if self.phoneme_encoder.spk_config.encoding_format == "id":
                    content_dict["spk"] = batch_content["spk"].long().to(device)
                elif (
                    self.phoneme_encoder.spk_config.encoding_format
                    == "embedding"
                ):
                    content_dict["spk"] = (
                        batch_content["spk"].float().to(device)
                    )
            content_output = self.phoneme_encoder(**content_dict)
            la_content_output = zero_la_content

        elif task == "singing_acoustic_modeling":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().to(device),
            }
            content_output = self.pitch_encoder(**content_dict)
            content_dict = {
                "f0": batch_content["f0"].float().to(device),
                "uv": batch_content["uv"].float().to(device),
            }
            la_content_output = self.pitch_encoder.encode_pitch(**content_dict)

        else:
            raise ValueError(f"Unsupported task: {task}")

        return {
            "content": content_output["output"],
            "content_mask": content_output["mask"],
            "length_aligned_content": la_content_output,
        }