# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import logging
import os
from abc import abstractmethod
from typing import Dict, Iterable, Optional

import numpy as np
import torch
from datasets import load_dataset

from .datatypes import LangPairSample, MultimodalSample

logger = logging.getLogger(__name__)


class SpeechTokenizer:
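    """Interface for tokenizers that turn a waveform into discrete speech units."""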
@abstractmethod
def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
...
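

# A hypothetical toy implementation of the SpeechTokenizer protocol above, added
# here only as a sketch for smoke-testing the pipeline (real unit tokenizers are
# model-based and are not part of this file):
class NaiveAmplitudeTokenizer(SpeechTokenizer):
    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        # Quantize each amplitude in [-1.0, 1.0] into one of 100 integer units.
        return ((wav.clamp(-1.0, 1.0) + 1.0) * 49.5).long()
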
class Speech2SpeechFleursDatasetBuilder:
"""Assembles speech2speech dataset from google/fleurs on HuggingFace"""
HF_FLEURS_DATASET_NAME = "google/fleurs"
def __init__(
self,
source_lang: str,
target_lang: str,
split: str = "test",
skip_source_audio: bool = True,
skip_target_audio: bool = True,
audio_dtype: torch.dtype = torch.float32,
dataset_cache_dir: Optional[str] = None,
speech_tokenizer: Optional[SpeechTokenizer] = None,
):
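        # All audio handling is opt-in: with both skip flags left at their
        # default of True, only ids and transcriptions are materialized, which
        # keeps iteration cheap when waveforms and units are not needed.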
self.source_lang = source_lang
self.target_lang = target_lang
self.split = split
self.dataset_cache_dir = dataset_cache_dir
self.audio_dtype = audio_dtype
self.skip_source_audio = skip_source_audio
self.skip_target_audio = skip_target_audio
        self.speech_tokenizer = speech_tokenizer

    def _prepare_sample(
self,
sample_id: int,
lang: str,
text: str,
audio_local_path: Optional[str] = None,
waveform_npy: Optional[np.ndarray] = None,
sampling_rate: Optional[int] = None,
) -> MultimodalSample:
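        """Builds a MultimodalSample; audio and units are decoded only when not skipped."""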
        should_skip_audio = (
            (lang == self.target_lang and self.skip_target_audio)
            or (lang == self.source_lang and self.skip_source_audio)
            or waveform_npy is None
        )
if not should_skip_audio:
waveform = torch.from_numpy(waveform_npy).to(self.audio_dtype)
else:
waveform = None
if self.speech_tokenizer is not None and not should_skip_audio:
assert waveform is not None
assert sampling_rate is not None
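            # reshape(-1) flattens the tokenizer output to a 1-D unit sequence,
            # regardless of the shape the encoder returned.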
units_tensor = self.speech_tokenizer.encode(
waveform, sampling_rate
).reshape(-1)
units = units_tensor.tolist()
else:
units = None
return MultimodalSample(
id=sample_id,
lang=lang,
text=text.strip(),
audio_local_path=audio_local_path,
waveform=waveform,
sampling_rate=sampling_rate,
units=units,
        )

    def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
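        """Yields one MultimodalSample per row of the requested FLEURS split."""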
ds = load_dataset(
self.HF_FLEURS_DATASET_NAME,
lang,
split=self.split,
cache_dir=self.dataset_cache_dir,
streaming=False,
)
for item in ds:
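            # The absolute audio path is rebuilt from the directory of the
            # item-level "path" field plus the file name in item["audio"]["path"].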
audio_path = os.path.join(
os.path.dirname(item["path"]), item["audio"]["path"]
)
(sample_id, audio_local_path, waveform, sampling_rate, text) = (
item["id"],
audio_path,
item["audio"]["array"],
item["audio"]["sampling_rate"],
item["transcription"],
)
yield self._prepare_sample(
sample_id=sample_id,
audio_local_path=audio_local_path,
waveform_npy=waveform,
sampling_rate=sampling_rate,
text=text,
lang=lang,
            )

    def __iter__(self) -> Iterable[LangPairSample]:
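        """Yields LangPairSamples by joining source and target samples on id.

        Target samples are fully loaded into memory first; source samples whose
        id has no target counterpart are silently dropped.
        """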
logger.info(f"Loading {self.target_lang} samples")
target_samples: Dict[int, MultimodalSample] = {}
for idx, sample in enumerate(
self.iterate_lang_audio_samples(lang=self.target_lang)
):
if idx and idx % 100 == 0:
logger.info(f"..loaded {idx} target samples")
target_samples[sample.id] = sample
logger.info(f"Loading {self.source_lang} samples")
for idx, sample in enumerate(
self.iterate_lang_audio_samples(lang=self.source_lang)
):
if idx and idx % 100 == 0:
logger.info(f"..loaded {idx} source samples")
if sample.id in target_samples:
yield LangPairSample(source=sample, target=target_samples[sample.id])
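

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions:
    # "en_us" and "fr_fr" are valid google/fleurs config names, and the
    # text-only defaults (skip_*_audio=True) suffice for a quick inspection.
    logging.basicConfig(level=logging.INFO)
    builder = Speech2SpeechFleursDatasetBuilder(
        source_lang="en_us",
        target_lang="fr_fr",
        split="test",
    )
    for i, pair in enumerate(builder):
        print(pair.source.id, pair.source.text, "->", pair.target.text)
        if i >= 2:  # peek at the first few pairs only
            break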