# Source: MusicGen2 Hugging Face Space — audiocraft/data/info_audio_dataset.py
# Last changed by reach-vb (HF staff) in "Stereo demo update (#60)", commit 5325fcc.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base classes for the datasets that also provide non-audio metadata,
e.g. description, text transcription etc.
"""
from dataclasses import dataclass
import logging
import math
import re
import typing as tp
import torch
from .audio_dataset import AudioDataset, AudioMeta
from ..environment import AudioCraftEnvironment
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
logger = logging.getLogger(__name__)
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite the paths stored in `meta` in place with the environment's dataset mappers.

    `meta.path` is always remapped. `meta.info_path.zip_path` is remapped only when
    `meta.info_path` is set. The same (mutated) object is returned for convenience.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info = meta.info_path
    if info is not None:
        info.zip_path = remap(info.zip_path)
    return meta
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Apply `_clusterify_meta` to every entry and return the results as a new list.

    Each entry is mutated in place by `_clusterify_meta`. The returned list holds the same objects.
    """
    return list(map(_clusterify_meta, meta))
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Segment metadata that carries no conditioning attributes.

    `InfoAudioDataset` must return metadata deriving from `SegmentWithAttributes`,
    i.e. objects able to produce conditioning attributes. This placeholder satisfies
    that contract with an empty attribute set. Every dataset therefore stays usable
    by solvers whose conditioners expect such metadata.
    """
    # Only filled in when a language model is trained from cached batches.
    audio_tokens: tp.Optional[torch.Tensor] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return a freshly constructed, empty `ConditioningAttributes`."""
        empty_attributes = ConditioningAttributes()
        return empty_attributes
class InfoAudioDataset(AudioDataset):
    """AudioDataset whose items, when `return_info` is set, pair the waveform with an `AudioInfo`.

    The metadata list is passed through `clusterify_all_meta` before reaching the parent
    constructor. All other initialization arguments are forwarded unchanged; see
    `audiocraft.data.audio_dataset.AudioDataset` for their meaning.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        patched_meta = clusterify_all_meta(meta)
        super().__init__(patched_meta, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        """Return the waveform alone, or `(waveform, AudioInfo)` when `return_info` is enabled."""
        if self.return_info:
            # The parent returns segment info; re-wrap its fields as an AudioInfo.
            wav, segment_info = super().__getitem__(index)
            return wav, AudioInfo(**segment_info.to_dict())
        wav = super().__getitem__(index)
        assert isinstance(wav, torch.Tensor)
        return wav
def get_keyword_or_keyword_list(
    value: tp.Union[str, tp.List[str], None],
) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a single keyword or possibly a list of keywords.

    A list is normalised with `get_keyword_list`; any other value is passed to
    `get_keyword`.

    Returns:
        A normalised keyword, a list of normalised keywords, or None.
    """
    if isinstance(value, list):
        return get_keyword_list(value)
    return get_keyword(value)
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single free-form string value, preserving its case.

    Returns:
        The stripped string. Returns None when `value` is not a `str`, is empty
        or whitespace-only, or is the literal text 'None'.
    """
    if not isinstance(value, str):
        # Covers None as well as any non-string placeholder value.
        return None
    value = value.strip()
    # Test emptiness *after* stripping, so that whitespace-only input is treated
    # as missing rather than returned as ''.
    if not value or value == 'None':
        return None
    return value
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword: strip surrounding whitespace and lower-case it.

    Returns:
        The normalised keyword. Returns None when `value` is not a `str`, is
        empty or whitespace-only, or is the literal text 'None'.
    """
    if not isinstance(value, str):
        # Covers None as well as any non-string placeholder value.
        return None
    value = value.strip()
    # Test emptiness *after* stripping, so that whitespace-only entries become None
    # (and are dropped by list processing) instead of yielding ''.
    if not value or value == 'None':
        return None
    return value.lower()
def get_keyword_list(values: tp.Any) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    Accepted inputs:
      * a string, split on every comma and whitespace character;
      * a list of keywords;
      * a float NaN, treated as an empty list;
      * anything else, which is logged at debug level and wrapped as
        `[str(values)]`.

    Each entry is normalised with `get_keyword`, and entries that normalise to
    None are dropped.

    Returns:
        The list of normalised keywords, or None if no keyword remains.
    """
    if isinstance(values, str):
        values = [v.strip() for v in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        values = []
    if not isinstance(values, list):
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug("Unexpected keyword list %s", values)
        values = [str(values)]
    kw_list = [kw for kw in map(get_keyword, values) if kw is not None]
    # An empty result is reported as None rather than [].
    return kw_list or None