|
|
|
|
|
|
|
|
|
|
|
from typing import List |
|
|
|
import torch |
|
from torch.nn import Module |
|
|
|
from fairseq2.typing import DataType, Device |
|
|
|
from fairseq2.assets import asset_store |
|
from fairseq2.data import ( |
|
Collater, |
|
SequenceData, |
|
VocabularyInfo, |
|
) |
|
from fairseq2.nn.padding import get_seqs_and_padding_mask |
|
|
|
from seamless_communication.inference import BatchedSpeechOutput |
|
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model |
|
|
|
|
|
class PretsselGenerator(Module):
    """Inference wrapper around a pretrained PRETSSEL vocoder model.

    The model is loaded once at construction time and put in eval mode.
    :meth:`predict` converts batches of discrete unit sequences into the
    tensors the model expects and returns the generated audio.
    """

    def __init__(
        self,
        pretssel_name_or_card: str,
        vocab_info: VocabularyInfo,
        device: Device,
        dtype: DataType = torch.float16,
    ):
        """
        :param pretssel_name_or_card:
            Name (or asset card) of the vocoder model. It is used both to
            load the model and to look up the ``sample_rate`` card field.
        :param vocab_info:
            Unit vocabulary description; ``pad_idx`` is used as the padding
            value for unit batches and ``eos_idx`` as the end-of-sequence
            unit appended to every sequence.
        :param device:
            Device on which the model and the EOS tensor are placed.
        :param dtype:
            Data type of the model. Ignored on CPU, where ``float32`` is
            always used.
        """
        super().__init__()

        # Compare on the device *type* rather than on device equality:
        # ``torch.device("cpu:0") != torch.device("cpu")``, so an equality
        # check would leave an indexed CPU device (or a plain "cpu" string)
        # running with the float16 default.
        if torch.device(device).type == "cpu":
            dtype = torch.float32

        self.device = device
        self.dtype = dtype

        self.pretssel_model = load_pretssel_vocoder_model(
            pretssel_name_or_card,
            device=device,
            dtype=dtype,
        )
        self.pretssel_model.eval()

        # The output sample rate is read from the model's asset card.
        vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)

        self.vocab_info = vocab_info
        # Units are padded with the vocabulary's pad index, durations with 0.
        self.unit_collate = Collater(pad_value=vocab_info.pad_idx)
        self.duration_collate = Collater(pad_value=0)
        # Kept as a tensor so it can be concatenated to unit sequences and
        # reused as the dtype/device reference for them.
        self.unit_eos_token = torch.tensor([vocab_info.eos_idx], device=device)

    @torch.inference_mode()
    def predict(
        self,
        units: List[List[int]],
        tgt_lang: str,
        prosody_encoder_input: SequenceData,
    ) -> BatchedSpeechOutput:
        """Generate audio for a batch of unit sequences.

        :param units:
            One list of unit ids per batch element.
        :param tgt_lang:
            Target language, forwarded unchanged to the model.
        :param prosody_encoder_input:
            Batched prosody input; it is split into sequences and padding
            mask before being passed to the model.

        :returns:
            A :class:`BatchedSpeechOutput` holding the *input* ``units``,
            the model output as ``audio_wavs`` and the sample rate read
            from the model card.
        """
        units_batch, durations = [], []
        for u in units:
            # Match the dtype and device of the EOS tensor.
            unit = torch.tensor(u).to(self.unit_eos_token)

            # Shift every unit id by 4.
            # NOTE(review): presumably this skips the vocabulary's control
            # symbols so ids line up with the embedding table - confirm.
            unit += 4
            unit = torch.cat([unit, self.unit_eos_token], dim=0)

            # Collapse runs of identical units; the run lengths become the
            # per-unit durations.
            unit, duration = torch.unique_consecutive(unit, return_counts=True)

            # The last run ends with the EOS appended above; zero its count.
            duration[-1] = 0

            units_batch.append(unit)
            # NOTE(review): run lengths are doubled - presumably a fixed
            # unit-to-frame ratio expected by the model; confirm.
            durations.append(duration * 2)

        speech_units = self.unit_collate(units_batch)
        durations = self.duration_collate(durations)["seqs"]

        units_tensor, unit_padding_mask = get_seqs_and_padding_mask(speech_units)
        prosody_input_seqs, prosody_padding_mask = get_seqs_and_padding_mask(
            prosody_encoder_input
        )

        audio_wavs = self.pretssel_model(
            units_tensor,
            tgt_lang,
            prosody_input_seqs,
            padding_mask=unit_padding_mask,
            prosody_padding_mask=prosody_padding_mask,
            durations=durations,
        )
        return BatchedSpeechOutput(
            units=units,
            audio_wavs=audio_wavs,
            sample_rate=self.output_sample_rate,
        )
|
|