import logging
import os
from abc import abstractmethod
from typing import Dict, Iterable, Optional

import numpy as np
import torch
from datasets import load_dataset

from .datatypes import LangPairSample, MultimodalSample

logger = logging.getLogger(__name__)


class SpeechTokenizer:
    """Interface for a model that encodes a raw waveform into a tensor of discrete speech units."""

    @abstractmethod
    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        ...
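

# A minimal sketch of one way to satisfy the interface above: wrap any callable
# that maps (waveform, sample_rate) to a 1-D tensor of unit ids. The
# `unit_extractor` argument is a hypothetical stand-in for a real
# unit-extraction model, not something provided by this module.
class CallableSpeechTokenizer(SpeechTokenizer):
    def __init__(self, unit_extractor) -> None:
        self.unit_extractor = unit_extractor

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        # Delegate unit extraction; the sample rate is passed through unchanged.
        return self.unit_extractor(wav, sample_rate)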


class Speech2SpeechFleursDatasetBuilder:
    """Assembles a speech-to-speech dataset from google/fleurs on the Hugging Face Hub."""

    HF_FLEURS_DATASET_NAME = "google/fleurs"

    def __init__(
        self,
        source_lang: str,
        target_lang: str,
        split: str = "test",
        skip_source_audio: bool = True,
        skip_target_audio: bool = True,
        audio_dtype: torch.dtype = torch.float32,
        dataset_cache_dir: Optional[str] = None,
        speech_tokenizer: Optional[SpeechTokenizer] = None,
    ):
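        # `source_lang` and `target_lang` are passed to load_dataset as FLEURS
        # config names; samples from the two configs are paired by id in __iter__.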
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.split = split
        self.dataset_cache_dir = dataset_cache_dir
        self.audio_dtype = audio_dtype
        self.skip_source_audio = skip_source_audio
        self.skip_target_audio = skip_target_audio
        self.speech_tokenizer = speech_tokenizer

    def _prepare_sample(
        self,
        sample_id: int,
        lang: str,
        text: str,
        audio_local_path: Optional[str] = None,
        waveform_npy: Optional[np.ndarray] = None,
        sampling_rate: Optional[int] = None,
    ) -> MultimodalSample:
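        """Wrap one FLEURS row into a MultimodalSample, optionally keeping the
        waveform and encoding it into discrete speech units."""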
        # Audio is skipped when it was not requested for this side of the pair,
        # or when the row carries no waveform at all.
        should_skip_audio = (
            (lang == self.target_lang and self.skip_target_audio)
            or (lang == self.source_lang and self.skip_source_audio)
            or waveform_npy is None
        )
        if not should_skip_audio:
            waveform = torch.from_numpy(waveform_npy).to(self.audio_dtype)
        else:
            waveform = None
        if self.speech_tokenizer is not None and not should_skip_audio:
            assert waveform is not None
            assert sampling_rate is not None
            # Flatten the encoded units into a plain list of ids.
            units_tensor = self.speech_tokenizer.encode(waveform, sampling_rate).reshape(-1)
            units = units_tensor.tolist()
        else:
            units = None
        return MultimodalSample(
            id=sample_id,
            lang=lang,
            text=text.strip(),
            audio_local_path=audio_local_path,
            waveform=waveform,
            sampling_rate=sampling_rate,
            units=units,
        )

    def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
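        """Yield MultimodalSamples for a single FLEURS language config, in dataset order."""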
        ds = load_dataset(
            self.HF_FLEURS_DATASET_NAME,
            lang,
            split=self.split,
            cache_dir=self.dataset_cache_dir,
            streaming=False,
        )
        for item in ds:
            # Resolve the on-disk location of the audio file from the dataset's
            # path fields.
            audio_path = os.path.join(
                os.path.dirname(item["path"]), item["audio"]["path"]
            )
            (sample_id, audio_local_path, waveform, sampling_rate, text) = (
                item["id"],
                audio_path,
                item["audio"]["array"],
                item["audio"]["sampling_rate"],
                item["transcription"],
            )
            yield self._prepare_sample(
                sample_id=sample_id,
                audio_local_path=audio_local_path,
                waveform_npy=waveform,
                sampling_rate=sampling_rate,
                text=text,
                lang=lang,
            )

    def __iter__(self) -> Iterable[LangPairSample]:
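        """Pair source- and target-language samples sharing an id and yield them as LangPairSamples."""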
        logger.info(f"Loading {self.target_lang} samples")
        # Index target samples by id so that source samples can be matched below.
        target_samples: Dict[int, MultimodalSample] = {}
        for idx, sample in enumerate(
            self.iterate_lang_audio_samples(lang=self.target_lang)
        ):
            if idx and idx % 100 == 0:
                logger.info(f"..loaded {idx} target samples")
            target_samples[sample.id] = sample

        logger.info(f"Loading {self.source_lang} samples")
        for idx, sample in enumerate(
            self.iterate_lang_audio_samples(lang=self.source_lang)
        ):
            if idx and idx % 100 == 0:
                logger.info(f"..loaded {idx} source samples")
            # Yield a pair only when the target side has a sample with the same id.
            if sample.id in target_samples:
                yield LangPairSample(source=sample, target=target_samples[sample.id])
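

# A minimal usage sketch (hypothetical, not part of the module): build
# French -> English text pairs from the FLEURS test split. The config names
# "fr_fr" and "en_us" follow the google/fleurs naming on the Hub.
if __name__ == "__main__":
    builder = Speech2SpeechFleursDatasetBuilder(
        source_lang="fr_fr",
        target_lang="en_us",
        split="test",
        skip_source_audio=True,  # keep transcriptions only
        skip_target_audio=True,
    )
    for pair in builder:
        print(pair.source.text, "->", pair.target.text)
        break  # show a single pair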