Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Base classes for the datasets that also provide non-audio metadata, | |
e.g. description, text transcription etc. | |
""" | |
from dataclasses import dataclass | |
import logging | |
import math | |
import re | |
import typing as tp | |
import torch | |
from .audio_dataset import AudioDataset, AudioMeta | |
from ..environment import AudioCraftEnvironment | |
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes | |
# Module-level logger, named after this module's import path (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite the paths stored in `meta` through the environment's dataset mappers.

    `meta.path` is always remapped; `meta.info_path.zip_path` is remapped only when
    `meta.info_path` is set. The object is modified in place and returned.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info = meta.info_path
    if info is not None:
        info.zip_path = remap(info.zip_path)
    return meta
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Remap the paths of every entry in `meta` via `_clusterify_meta`.

    Entries are mutated in place; the returned list is a new list holding the
    same objects, in the same order.
    """
    return list(map(_clusterify_meta, meta))
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Dummy SegmentInfo with empty attributes.

    The InfoAudioDataset is expected to return metadata that inherits
    from SegmentWithAttributes class and can return conditioning attributes.
    This basically guarantees all datasets will be compatible with current
    solver that contain conditioners requiring this.
    """
    # Populated when using cached batch for training a LM.
    # Declared under ``@dataclass`` so it is a real field (visible to ``fields()``
    # and accepted by the generated ``__init__``); without the decorator the
    # annotation would only create a class attribute.
    # NOTE(review): assumes SegmentWithAttributes is itself a dataclass, so this
    # defaulted field is appended after the inherited ones — confirm.
    audio_tokens: tp.Optional[torch.Tensor] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return an empty ``ConditioningAttributes``; this segment carries no conditioning."""
        return ConditioningAttributes()
class InfoAudioDataset(AudioDataset):
    """AudioDataset variant that wraps per-item metadata into an `AudioInfo`.

    The metadata list is passed through `clusterify_all_meta` before being handed,
    together with any keyword arguments, to the parent constructor. See
    `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        remapped = clusterify_all_meta(meta)
        super().__init__(remapped, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        """Return the waveform alone, or ``(waveform, AudioInfo)`` when ``self.return_info`` is set."""
        if self.return_info:
            # The parent returns (wav, segment_info); rebuild the info as an AudioInfo
            # so callers get an object implementing `to_condition_attributes`.
            wav, segment_info = super().__getitem__(index)
            return wav, AudioInfo(**segment_info.to_dict())
        sample = super().__getitem__(index)
        assert isinstance(sample, torch.Tensor)
        return sample
def get_keyword_or_keyword_list(
    value: tp.Union[str, tp.List[str], None],
) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a single keyword or possibly a list of keywords.

    Lists are normalized with `get_keyword_list`; any other value (a single
    string, None, ...) is normalized with `get_keyword`.

    Args:
        value: A keyword, a list of keywords, or None.
    Returns:
        The normalized keyword or keyword list, or None when nothing valid remains.
    """
    # The annotation previously claimed ``Optional[str]`` although lists are
    # explicitly handled here; it now reflects the accepted inputs.
    if isinstance(value, list):
        return get_keyword_list(value)
    return get_keyword(value)
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single free-form string value.

    Args:
        value: Raw value; anything that is not a ``str`` (including None) is
            treated as missing.
    Returns:
        ``value`` with surrounding whitespace removed, or None when it is not a
        string, is empty or whitespace-only, or is the literal ``'None'``.
    """
    if not isinstance(value, str):
        return None
    stripped = value.strip()
    # Test emptiness *after* stripping so whitespace-only input maps to None
    # instead of "". The literal 'None' is presumably a stringified missing
    # value from upstream metadata — confirm against producers.
    if not stripped or stripped == 'None':
        return None
    return stripped
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword.

    Args:
        value: Raw keyword; anything that is not a ``str`` (including None) is
            treated as missing.
    Returns:
        The keyword stripped of surrounding whitespace and lower-cased, or None
        when it is not a string, is empty or whitespace-only, or is the literal
        ``'None'``.
    """
    if not isinstance(value, str):
        return None
    stripped = value.strip()
    # Check emptiness after stripping: previously "  " yielded "" which then
    # survived `get_keyword_list`'s ``is not None`` filter as an empty keyword.
    if not stripped or stripped == 'None':
        return None
    return stripped.lower()
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Normalize a collection of keywords into a list of cleaned keywords.

    Input handling:
      * a string is split on every comma or whitespace character;
      * a float NaN is treated as no keywords at all;
      * a list is used as-is;
      * anything else is logged at debug level and wrapped as ``[str(values)]``.

    Every item is normalized with `get_keyword`, and items mapping to None are dropped.

    Returns:
        The surviving keywords, or None when the list would be empty.
    """
    if isinstance(values, str):
        values = [piece.strip() for piece in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        values = []

    if not isinstance(values, list):
        logger.debug(f"Unexpected keyword list {values}")
        values = [str(values)]

    cleaned = [kw for kw in map(get_keyword, values) if kw is not None]
    return cleaned or None