# NOTE: scraped page banner ("Spaces: Runtime error / Runtime error"); not part of the source file.
# ---------------------------------------------------------------------------- | |
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329) | |
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM | |
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4 | |
# | |
# Copyright (c) 2022 Microsoft | |
# Licensed under The MIT License [see LICENSE for details] | |
# ---------------------------------------------------------------------------- | |
from pathlib import Path | |
from typing import List, Dict, Optional, Any | |
from dataclasses import dataclass | |
import numpy as np | |
import torch | |
from fairseq.data.audio.speech_to_text_dataset import ( | |
SpeechToTextDataset, | |
SpeechToTextDatasetCreator, | |
S2TDataConfig, | |
_collate_frames, | |
get_features_or_waveform, | |
) | |
from fairseq.data import Dictionary, data_utils as fairseq_data_utils | |
@dataclass
class TextToUnitDatasetItem:
    """One dataset example: a required index/source pair plus optional extras.

    The ``@dataclass`` decorator generates the keyword-argument constructor
    that callers rely on (``TextToUnitDatasetItem(index=..., source=..., ...)``);
    without it the annotated fields below would not be accepted by ``__init__``.
    """

    # Position of the example inside the dataset.
    index: int
    # Primary sequence of the example, stored as a tensor.
    source: torch.Tensor
    # Optional secondary sequence; ``None`` when not available.
    target: Optional[torch.Tensor] = None
    # Optional integer speaker identifier.
    speaker_id: Optional[int] = None
    # Optional speaker embedding tensor.
    speaker_emb: Optional[torch.Tensor] = None
    # Optional per-token duration / pitch / energy tensors.
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None
class Text2UnitDataset(SpeechToTextDataset):
    """Dataset that pairs tokenized text with integer unit-label sequences.

    Naming note: in each item, ``source`` holds the unit labels and ``target``
    holds the dictionary-encoded text.  :meth:`collater` swaps the roles when
    batching: the text becomes ``net_input["src_tokens"]`` and the unit labels
    become the batch ``"target"``.
    """

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        unit_labels: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None,
    ):
        """Store the per-sample metadata and resolve speaker-embedding options.

        ``unit_labels`` is forwarded to the parent constructor in the
        positional slot right after ``cfg`` and is also kept on ``self`` for
        :meth:`__getitem__`.  ``durations``, ``pitches`` and ``energies`` are
        optional per-sample extras; each may be ``None`` to disable it.
        """
        super().__init__(
            split,
            is_train_split,
            cfg,
            unit_labels,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )
        self.durations = durations
        self.pitches = pitches
        self.energies = energies
        self.unit_labels = unit_labels
        self.feature_root = Path(cfg.audio_root)
        self.spk_emb_type = cfg.config.get("speaker_embedding_type", None)
        self.random_spk = cfg.config.get("random_speaker", False)

        # Always define these attributes so they exist even when no speaker
        # embedding type is configured.
        self.spk_emb_choices = []
        if self.spk_emb_type is not None:
            # Sorted so that random selection is reproducible for a given seed
            # regardless of directory listing order.
            self.spk_emb_choices = sorted(
                (self.feature_root / self.spk_emb_type).glob("*.npy")
            )
        self.spk_emb_num = len(self.spk_emb_choices)

    def __getitem__(self, index: int) -> TextToUnitDatasetItem:
        """Build the item at ``index``.

        ``source`` is the unit-label sequence; ``target`` is the encoded text
        (only when ``tgt_texts`` is set).  Speaker id, speaker embedding,
        duration, pitch and energy are filled in only when the corresponding
        option/data is configured, otherwise they stay ``None``.
        """
        source = torch.LongTensor(self.unit_labels[index])

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]

        speaker_emb = None
        if self.spk_emb_type is not None:
            if self.random_spk:
                # Pick any available embedding file at random.
                spk_emb_path = self.spk_emb_choices[
                    np.random.choice(self.spk_emb_num)
                ]
            else:
                # Use the embedding file named after this sample's id.
                spk_emb_path = (
                    self.feature_root / self.spk_emb_type / f"{self.ids[index]}.npy"
                )
            speaker_emb = get_features_or_waveform(spk_emb_path)
            speaker_emb = torch.from_numpy(speaker_emb).float()

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()

        return TextToUnitDatasetItem(
            index=index,
            source=source,
            target=target,
            speaker_id=speaker_id,
            speaker_emb=speaker_emb,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )

    def collater(self, samples: List[TextToUnitDatasetItem]) -> Dict[str, Any]:
        """Merge ``samples`` into a batch dict, sorted by text length (descending).

        Returns an empty dict for an empty sample list.  Otherwise the batch
        contains ``id``, ``net_input`` (``src_tokens``, ``src_lengths``,
        ``prev_output_tokens``), ``speaker``, ``target``, ``durations``,
        ``pitches``, ``energies``, ``target_lengths``, ``ntokens``,
        ``nsentences`` and ``src_texts``.

        Every sample must carry a non-``None`` ``target`` (encoded text),
        since its length drives the sort order.
        """
        if len(samples) == 0:
            return {}

        # Sort by encoded-text length; `order` re-indexes everything else.
        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
            0, order
        )

        # The unit labels (item.source) are the prediction target of the batch.
        target = fairseq_data_utils.collate_tokens(
            [s.source for s in samples],
            self.tgt_dict.pad(),
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        # The encoded text (item.target) is the model input of the batch.
        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )
        if self.spk_emb_type is not None:
            # Speaker embeddings take precedence over speaker ids when both
            # are configured.
            speaker = torch.stack(
                [s.speaker_emb for s in samples], dim=0
            ).index_select(0, order)

        # Shift the target right by one step and prepend a single BOS column.
        # The BOS *index* must be the fill value, not the number of columns:
        # using it as a size yields a zero-width prefix whenever BOS is 0.
        bsz = target.size(0)
        prev_output_tokens = torch.cat(
            (target.new_full((bsz, 1), self.tgt_dict.bos()), target[:, :-1]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]

        src_texts = [
            self.tgt_dict.string(samples[i].target) for i in order.tolist()
        ]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": target,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": target_lengths.sum().item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }
class Text2UnitDatasetCreator(SpeechToTextDatasetCreator):
    """Builds :class:`Text2UnitDataset` objects from lists of manifest rows."""

    # Optional / required manifest columns specific to text-to-unit data.
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"
    KEY_UNIT = "unit"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
    ) -> Text2UnitDataset:
        """Create a dataset from ``samples`` (one dict per manifest row).

        Required keys per row: ``cls.KEY_ID``, ``cls.KEY_UNIT``,
        ``cls.KEY_N_FRAMES`` and ``cls.KEY_TGT_TEXT``.  Source text, speaker
        and language keys fall back to class defaults.  Durations, pitches
        and energies are optional: if *any* row lacks one of them, that
        feature is disabled for the whole dataset (passed as ``None``).

        Unit and duration fields are whitespace-separated integers; pitch and
        energy fields are paths relative to ``cfg.audio_root``.
        """
        audio_root = Path(cfg.audio_root)

        ids = [s[cls.KEY_ID] for s in samples]
        # split() (no argument) tolerates repeated or trailing whitespace.
        unit_labels = [
            None if raw is None else [int(tok) for tok in raw.split()]
            for raw in (s[cls.KEY_UNIT] for s in samples)
        ]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        durations = [
            None if raw is None else [int(tok) for tok in raw.split()]
            for raw in (s.get(cls.KEY_DURATION, None) for s in samples)
        ]
        durations = None if any(dd is None for dd in durations) else durations

        pitches = [
            None if rel is None else (audio_root / rel).as_posix()
            for rel in (s.get(cls.KEY_PITCH, None) for s in samples)
        ]
        pitches = None if any(pp is None for pp in pitches) else pitches

        energies = [
            None if rel is None else (audio_root / rel).as_posix()
            for rel in (s.get(cls.KEY_ENERGY, None) for s in samples)
        ]
        energies = None if any(ee is None for ee in energies) else energies

        return Text2UnitDataset(
            split_name,
            is_train_split,
            cfg,
            unit_labels,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            durations,
            pitches,
            energies,
        )