from typing import Any

import torch
import torch.nn as nn


class ContentEncoder(nn.Module):
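    """Routes each (content, task) pair to the matching task-specific encoder.

    All sub-encoders are optional; pass only those required by the tasks
    that will actually be encoded.
    """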

    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module | None = None,
        video_encoder: nn.Module | None = None,
        midi_encoder: nn.Module | None = None,
        phoneme_encoder: nn.Module | None = None,
        pitch_encoder: nn.Module | None = None,
        audio_encoder: nn.Module | None = None,
        speech_encoder: nn.Module | None = None,
        sketch_encoder: nn.Module | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.speech_encoder = speech_encoder
        self.sketch_encoder = sketch_encoder

    def encode_content(
        self, batch_content: list[Any], batch_task: list[str],
        device: str | torch.device,
    ):
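        """Encode a heterogeneous batch one sample at a time, then pad.

        Returns a dict with the padded content embeddings, their boolean
        mask, and the (possibly zero) length-aligned content embeddings.
        """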
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []

        # Placeholder for tasks that carry no length-aligned content.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)

        for content, task in zip(batch_content, batch_task):
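            # Dispatch on the task name; each branch produces
            # content_output_dict ({"output", "mask"}) and
            # la_content_output_dict ({"output"}).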
            if task in ("audio_super_resolution", "speech_enhancement"):
                content_dict = {
                    "waveform": torch.as_tensor(content).float(),
                    "waveform_lengths": torch.as_tensor(content.shape[0]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.audio_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task in ("text_to_audio", "text_to_music"):
                content_output_dict = self.text_encoder([content])
                la_content_output_dict = {"output": zero_la_content}
            elif task == "speech_to_audio":
                input_dict = {
                    "embed": content,
                    "embed_len": torch.tensor(
                        [content.shape[1]], dtype=torch.int
                    ).to(device),
                }
                content_output_dict = self.speech_encoder(input_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "direct_speech_to_audio":
                # Content is already an embedding sequence; just add a batch
                # dim and build an all-True mask instead of running an encoder.
                if content.dim() < 3:
                    content = content.unsqueeze(0)
                mask = torch.ones(
                    content.shape[:2], dtype=torch.bool, device=content.device
                )
                content_output_dict = {"output": content, "mask": mask}
                la_content_output_dict = {"output": zero_la_content}
            elif task == "sketch_to_audio":
                content_output_dict = self.sketch_encoder([content["caption"]])
                content_dict = {
                    "f0": torch.as_tensor(content["f0"]),
                    "energy": torch.as_tensor(content["energy"]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                la_content_output_dict = self.sketch_encoder.encode_sketch(
                    **content_dict
                )
            elif task == "video_to_audio":
                content_dict = {
                    "frames": torch.as_tensor(content).float(),
                    "frame_nums": torch.as_tensor(content.shape[0]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.video_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "singing_voice_synthesis":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                    "midi": torch.as_tensor(content["midi"]).long(),
                    "midi_duration":
                        torch.as_tensor(content["midi_duration"]).float(),
                    "is_slur": torch.as_tensor(content["is_slur"]).long(),
                }
                if "spk" in content:
                    spk_format = self.midi_encoder.spk_config.encoding_format
                    if spk_format == "id":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).long()
                    elif spk_format == "embedding":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).float()
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor([len(content["phoneme"])])
                content_output_dict = self.midi_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "text_to_speech":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                }
                if "spk" in content:
                    spk_format = self.phoneme_encoder.spk_config.encoding_format
                    if spk_format == "id":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).long()
                    elif spk_format == "embedding":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).float()
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor([len(content["phoneme"])])
                content_output_dict = self.phoneme_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "singing_acoustic_modeling":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor([len(content["phoneme"])])
                content_output_dict = self.pitch_encoder(**content_dict)

                # Pitch (f0 / voiced-unvoiced flags) supplies the frame-level,
                # length-aligned conditioning for this task.
                content_dict = {
                    "f0": torch.as_tensor(content["f0"]),
                    "uv": torch.as_tensor(content["uv"]),
                }
                for key in content_dict:
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                la_content_output_dict = self.pitch_encoder.encode_pitch(
                    **content_dict
                )
            else:
                raise ValueError(f"Unsupported task: {task}")

            batch_content_output.append(content_output_dict["output"][0])
            batch_content_mask.append(content_output_dict["mask"][0])
            batch_la_content_output.append(la_content_output_dict["output"][0])

        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
        }


class BatchedContentEncoder(ContentEncoder):
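    """Variant of ContentEncoder for pre-collated, single-task batches.

    Instead of looping over samples, each branch encodes the whole batch
    in one call.
    """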

    def encode_content(
        self, batch_content: list | dict, batch_task: list[str],
        device: str | torch.device,
    ):
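        """Encode a homogeneous batch in a single forward pass per encoder."""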
        # A pre-collated batch is homogeneous: all samples share one task.
        task = batch_task[0]
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
        if task in ("audio_super_resolution", "speech_enhancement"):
            content_dict = {
                "waveform":
                    batch_content["content"].unsqueeze(1).float().to(device),
                "waveform_lengths":
                    batch_content["content_lengths"].long().to(device),
            }
            content_output = self.audio_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "text_to_audio":
            content_output = self.text_encoder(batch_content)
            la_content_output = zero_la_content
        elif task == "video_to_audio":
            content_dict = {
                "frames": batch_content["content"].float().to(device),
                "frame_nums": batch_content["content_lengths"].long().to(device),
            }
            content_output = self.video_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "singing_voice_synthesis":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "midi": batch_content["midi"].long().to(device),
                "midi_duration": batch_content["midi_duration"].float().to(device),
                "is_slur": batch_content["is_slur"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().cpu(),
            }
            if "spk" in batch_content:
                spk_format = self.midi_encoder.spk_config.encoding_format
                if spk_format == "id":
                    content_dict["spk"] = batch_content["spk"].long().to(device)
                elif spk_format == "embedding":
                    content_dict["spk"] = batch_content["spk"].float().to(device)
            content_output = self.midi_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "text_to_speech":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().cpu(),
            }
            if "spk" in batch_content:
                spk_format = self.phoneme_encoder.spk_config.encoding_format
                if spk_format == "id":
                    content_dict["spk"] = batch_content["spk"].long().to(device)
                elif spk_format == "embedding":
                    content_dict["spk"] = batch_content["spk"].float().to(device)
            content_output = self.phoneme_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "singing_acoustic_modeling":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().to(device),
            }
            content_output = self.pitch_encoder(**content_dict)

            content_dict = {
                "f0": batch_content["f0"].float().to(device),
                "uv": batch_content["uv"].float().to(device),
            }
            # encode_pitch returns a dict; keep only the tensor so the return
            # value matches ContentEncoder.encode_content.
            la_content_output = self.pitch_encoder.encode_pitch(
                **content_dict
            )["output"]
        else:
            raise ValueError(f"Unsupported task: {task}")
return { |
|
|
"content": content_output["output"], |
|
|
"content_mask": content_output["mask"], |
|
|
"length_aligned_content": la_content_output, |
|
|
} |
|
|
|
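

# --- Minimal usage sketch (illustrative only) --------------------------------
# The toy text encoder below is hypothetical: it returns random embeddings
# in the {"output", "mask"} format the encoders above are expected to
# produce. A real text encoder would take its place.
if __name__ == "__main__":

    class _ToyTextEncoder(nn.Module):
        def __init__(self, embed_dim: int, seq_len: int = 8):
            super().__init__()
            self.embed_dim = embed_dim
            self.seq_len = seq_len

        def forward(self, captions: list[str]):
            batch = len(captions)
            return {
                "output": torch.randn(batch, self.seq_len, self.embed_dim),
                "mask": torch.ones(batch, self.seq_len, dtype=torch.bool),
            }

    encoder = ContentEncoder(embed_dim=16, text_encoder=_ToyTextEncoder(16))
    out = encoder.encode_content(["a dog barking"], ["text_to_audio"], "cpu")
    print(out["content"].shape)                 # torch.Size([1, 8, 16])
    print(out["content_mask"].shape)            # torch.Size([1, 8])
    print(out["length_aligned_content"].shape)  # torch.Size([1, 1, 16])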