Source code for datasets.features.audio

from dataclasses import dataclass, field
from io import BytesIO
from typing import Any, ClassVar, Optional, Union

import pyarrow as pa

from ..table import array_cast
from ..utils.streaming_download_manager import xopen


@dataclass
class Audio:
    """Audio Feature to extract audio data from an audio file.

    Input: The Audio feature accepts as input:

    - A :obj:`str`: Absolute path to the audio file (i.e. random access is allowed).
    - A :obj:`dict` with the keys:

        - path: String with relative path of the audio file to the archive file.
        - bytes: Bytes content of the audio file.

      This is useful for archived files with sequential access.

    - A :obj:`dict` with the keys:

        - path: String with relative path of the audio file to the archive file.
        - array: Array containing the audio sample
        - sampling_rate: Integer corresponding to the sampling rate of the audio sample.

      This is useful for archived files with sequential access.

    Args:
        sampling_rate (:obj:`int`, optional): Target sampling rate. If `None`, the native sampling rate is used.
        mono (:obj:`bool`, default ``True``): Whether to convert the audio signal to mono by averaging samples across
            channels.
        decode (:obj:`bool`, default ``True``): Whether to decode the audio data. If `False`,
            returns the underlying dictionary in the format {"path": audio_path, "bytes": audio_bytes}.
    """

    sampling_rate: Optional[int] = None
    mono: bool = True
    decode: bool = True
    id: Optional[str] = None
    # Automatically constructed
    dtype: ClassVar[str] = "dict"
    pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
    _type: str = field(default="Audio", init=False, repr=False)

    def __call__(self):
        """Return the Arrow storage type used by this feature."""
        return self.pa_type

    def encode_example(self, value: Union[str, dict]) -> dict:
        """Encode example into a format for Arrow.

        Args:
            value (:obj:`str` or :obj:`dict`): Data passed as input to Audio feature.

        Returns:
            :obj:`dict` with the keys "bytes" and "path".

        Raises:
            ImportError: If 'soundfile' is not installed.
            ValueError: If ``value`` is a dict whose 'path' and 'bytes' are both missing or None.
        """
        try:
            import soundfile as sf  # soundfile is a dependency of librosa, needed to decode audio files.
        except ImportError as err:
            raise ImportError("To support encoding audio data, please install 'soundfile'.") from err
        if isinstance(value, str):
            return {"bytes": None, "path": value}
        elif isinstance(value, dict) and "array" in value:
            # An in-memory buffer has no file name, so soundfile cannot infer the
            # container from an extension: the format has to be given explicitly.
            buffer = BytesIO()
            sf.write(buffer, value["array"], value["sampling_rate"], format="wav")
            return {"bytes": buffer.getvalue(), "path": value.get("path")}
        elif value.get("bytes") is not None or value.get("path") is not None:
            return {"bytes": value.get("bytes"), "path": value.get("path")}
        else:
            raise ValueError(
                f"An audio sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
            )

    def decode_example(self, value: dict) -> dict:
        """Decode example audio file into audio data.

        Args:
            value (:obj:`dict`): a dictionary with keys:

                - path: String with relative audio file path.
                - bytes: Bytes of the audio file.

        Returns:
            dict with the keys "path", "array" and "sampling_rate".

        Raises:
            RuntimeError: If decoding is disabled for this feature (``decode=False``).
            ValueError: If both 'path' and 'bytes' are None.
        """
        if not self.decode:
            raise RuntimeError("Decoding is disabled for this feature. Please use Audio(decode=True) instead.")

        path = value["path"]
        file = BytesIO(value["bytes"]) if value["bytes"] is not None else None
        if path is None and file is None:
            raise ValueError(f"An audio sample should have one of 'path' or 'bytes' but both are None in {value}.")

        # The decoder is chosen from the path suffix; compare case-insensitively so
        # that e.g. "CLIP.MP3" is also routed to the mp3 decoder.
        if path is not None and path.lower().endswith("mp3"):
            array, sampling_rate = self._decode_mp3(file if file is not None else path)
        elif file is not None:
            array, sampling_rate = self._decode_non_mp3_file_like(file)
        else:
            array, sampling_rate = self._decode_non_mp3_path_like(path)
        return {"path": path, "array": array, "sampling_rate": sampling_rate}

    def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.StructArray:
        """Cast an Arrow array to the Audio arrow storage type.

        The Arrow types that can be converted to the Audio pyarrow storage type are:

        - pa.string() - it must contain the "path" data
        - pa.struct({"bytes": pa.binary()})
        - pa.struct({"path": pa.string()})
        - pa.struct({"bytes": pa.binary(), "path": pa.string()})  - order doesn't matter
        - a struct with an "array" field - every non-null entry is re-encoded with
          :meth:`encode_example`

        Args:
            storage (Union[pa.StringArray, pa.StructArray]): PyArrow array to cast.

        Returns:
            pa.StructArray: Array in the Audio arrow storage type, that is
                pa.struct({"bytes": pa.binary(), "path": pa.string()})
        """
        if pa.types.is_string(storage.type):
            bytes_array = pa.array([None] * len(storage), type=pa.binary())
            storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"])
        elif pa.types.is_struct(storage.type) and storage.type.get_all_field_indices("array"):
            # Null entries are passed through untouched (encode_example cannot handle None),
            # and the type is given explicitly so that empty or all-null columns still
            # produce the expected struct type instead of an inferred one.
            storage = pa.array(
                [self.encode_example(x) if x is not None else None for x in storage.to_pylist()],
                type=self.pa_type,
            )
        elif pa.types.is_struct(storage.type):
            # Fill whichever of the two fields is absent with a column of nulls.
            if storage.type.get_field_index("bytes") >= 0:
                bytes_array = storage.field("bytes")
            else:
                bytes_array = pa.array([None] * len(storage), type=pa.binary())
            if storage.type.get_field_index("path") >= 0:
                path_array = storage.field("path")
            else:
                path_array = pa.array([None] * len(storage), type=pa.string())
            storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"])
        return array_cast(storage, self.pa_type)

    def _decode_non_mp3_path_like(self, path):
        """Decode a non-mp3 audio file given by its path, using librosa.

        The file is opened with ``xopen`` and handed to ``librosa.load``, which applies
        the configured ``sampling_rate`` and ``mono`` settings.

        Returns:
            tuple: ``(array, sampling_rate)``
        """
        try:
            import librosa
        except ImportError as err:
            raise ImportError("To support decoding audio files, please install 'librosa'.") from err

        with xopen(path, "rb") as f:
            array, sampling_rate = librosa.load(f, sr=self.sampling_rate, mono=self.mono)
        return array, sampling_rate

    def _decode_non_mp3_file_like(self, file):
        """Decode a non-mp3 audio file given as a file-like object, using soundfile.

        Returns:
            tuple: ``(array, sampling_rate)``
        """
        try:
            import librosa
            import soundfile as sf
        except ImportError as err:
            raise ImportError("To support decoding audio files, please install 'librosa' and 'soundfile'.") from err

        array, sampling_rate = sf.read(file)
        # Transposed before being handed to librosa (presumably soundfile yields
        # (frames, channels) while librosa expects channels first - confirm).
        array = array.T
        if self.mono:
            array = librosa.to_mono(array)
        if self.sampling_rate and self.sampling_rate != sampling_rate:
            # Keyword arguments: recent librosa releases no longer accept the
            # sampling rates positionally, while older ones accept both forms.
            array = librosa.resample(
                array, orig_sr=sampling_rate, target_sr=self.sampling_rate, res_type="kaiser_best"
            )
            sampling_rate = self.sampling_rate
        return array, sampling_rate

    def _decode_mp3(self, path_or_file):
        """Decode an mp3 audio file (path or file-like object), using torchaudio.

        Returns:
            tuple: ``(array, sampling_rate)``
        """
        try:
            import torchaudio
            import torchaudio.transforms as T
        except ImportError as err:
            raise ImportError("To support decoding 'mp3' audio files, please install 'torchaudio'.") from err
        try:
            torchaudio.set_audio_backend("sox_io")
        except RuntimeError as err:
            raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err

        array, sampling_rate = torchaudio.load(path_or_file, format="mp3")
        if self.sampling_rate and self.sampling_rate != sampling_rate:
            # The resampler is cached on the instance and only rebuilt when the source
            # rate or the target rate differs from the one it was created for.
            if (
                not hasattr(self, "_resampler")
                or self._resampler.orig_freq != sampling_rate
                or self._resampler.new_freq != self.sampling_rate
            ):
                self._resampler = T.Resample(sampling_rate, self.sampling_rate)
            array = self._resampler(array)
            sampling_rate = self.sampling_rate
        array = array.numpy()
        if self.mono:
            array = array.mean(axis=0)
        return array, sampling_rate