artst-demo-asr / SpeechT5 /SpeechLM /speechlm /data /text_to_unit_dataset.py
amupd's picture
SpeechT5 upload
62e9ca6
raw
history blame
10.9 kB
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
import numpy as np
import torch
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
S2TDataConfig,
_collate_frames,
get_features_or_waveform,
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils
@dataclass
class TextToUnitDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
speaker_emb: Optional[torch.Tensor] = None
duration: Optional[torch.Tensor] = None
pitch: Optional[torch.Tensor] = None
energy: Optional[torch.Tensor] = None
class Text2UnitDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
unit_labels: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None,
):
super(Text2UnitDataset, self).__init__(
split,
is_train_split,
cfg,
unit_labels,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.durations = durations
self.pitches = pitches
self.energies = energies
self.unit_labels = unit_labels
self.feature_root = Path(cfg.audio_root)
self.spk_emb_type = cfg.config.get("speaker_embedding_type", None)
self.random_spk = cfg.config.get("random_speaker", False)
if self.spk_emb_type is not None:
self.spk_emb_choices = [i for i in (self.feature_root / self.spk_emb_type).glob("*.npy")]
self.spk_emb_num = len(self.spk_emb_choices)
def __getitem__(self, index: int) -> TextToUnitDatasetItem:
# s2t_item = super().__getitem__(index)
source = torch.LongTensor(self.unit_labels[index])
target = None
if self.tgt_texts is not None:
tokenized = self.get_tokenized_tgt_text(index)
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=self.append_eos
).long()
if self.cfg.prepend_tgt_lang_tag:
lang_tag_idx = self.get_lang_tag_idx(
self.tgt_langs[index], self.tgt_dict
)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
speaker_id = None
if self.speaker_to_id is not None:
speaker_id = self.speaker_to_id[self.speakers[index]]
speaker_emb = None
if self.spk_emb_type is not None:
if self.random_spk:
spk_emb_path = self.spk_emb_choices[np.random.choice(self.spk_emb_num)]
else:
spk_emb_path = self.feature_root / self.spk_emb_type / f"{self.ids[index]}.npy"
speaker_emb = get_features_or_waveform(spk_emb_path)
speaker_emb = torch.from_numpy(speaker_emb).float()
duration, pitch, energy = None, None, None
if self.durations is not None:
duration = torch.tensor(
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
)
if self.pitches is not None:
pitch = get_features_or_waveform(self.pitches[index])
pitch = torch.from_numpy(
np.concatenate((pitch, [0])) # pad 0 for EOS
).float()
if self.energies is not None:
energy = get_features_or_waveform(self.energies[index])
energy = torch.from_numpy(
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToUnitDatasetItem(
index=index,
source=source,
target=target,
speaker_id=speaker_id,
speaker_emb=speaker_emb,
duration=duration,
pitch=pitch,
energy=energy,
)
def collater(self, samples: List[TextToUnitDatasetItem]) -> Dict[str, Any]:
if len(samples) == 0:
return {}
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
0, order
)
traget = fairseq_data_utils.collate_tokens(
[s.source for s in samples],
self.tgt_dict.pad(),
).index_select(0, order)
target_lengths = torch.tensor(
[s.source.shape[0] for s in samples], dtype=torch.long
).index_select(0, order)
src_tokens = fairseq_data_utils.collate_tokens(
[s.target for s in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
).index_select(0, order)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
if self.spk_emb_type is not None:
speaker = torch.stack([s.speaker_emb for s in samples], dim=0).index_select(0, order)
bsz, _ = traget.size()
prev_output_tokens = torch.cat(
(traget.new_zeros((bsz, self.tgt_dict.bos())), traget[:, :-1]), dim=1
)
durations, pitches, energies = None, None, None
if self.durations is not None:
durations = fairseq_data_utils.collate_tokens(
[s.duration for s in samples], 0
).index_select(0, order)
assert src_tokens.shape[1] == durations.shape[1]
if self.pitches is not None:
pitches = _collate_frames([s.pitch for s in samples], True)
pitches = pitches.index_select(0, order)
assert src_tokens.shape[1] == pitches.shape[1]
if self.energies is not None:
energies = _collate_frames([s.energy for s in samples], True)
energies = energies.index_select(0, order)
assert src_tokens.shape[1] == energies.shape[1]
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
return {
"id": id_,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens,
},
"speaker": speaker,
"target": traget,
"durations": durations,
"pitches": pitches,
"energies": energies,
"target_lengths": target_lengths,
"ntokens": sum(target_lengths).item(),
"nsentences": len(samples),
"src_texts": src_texts,
}
class Text2UnitDatasetCreator(SpeechToTextDatasetCreator):
KEY_DURATION = "duration"
KEY_PITCH = "pitch"
KEY_ENERGY = "energy"
KEY_UNIT = "unit"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
) -> Text2UnitDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
# audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
unit_labels = [s[cls.KEY_UNIT] for s in samples]
unit_labels = [
None if dd is None else [int(d) for d in dd.split(" ")] for dd in unit_labels
]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
]
durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
]
pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [
None if ee is None else (audio_root / ee).as_posix() for ee in energies
]
energies = None if any(ee is None for ee in energies) else energies
return Text2UnitDataset(
split_name,
is_train_split,
cfg,
unit_labels,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
durations,
pitches,
energies,
)