# AudioSep/data/datamodules.py
from typing import Dict, List, Optional
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from data.audiotext_dataset import AudioTextDataset
class DataModule(pl.LightningDataModule):
def __init__(
self,
train_dataset: object,
batch_size: int,
num_workers: int
):
r"""Data module. To get one batch of data:
code-block:: python
data_module.setup()
for batch_data_dict in data_module.train_dataloader():
print(batch_data_dict.keys())
break
Args:
train_sampler: Sampler object
train_dataset: Dataset object
num_workers: int
distributed: bool
"""
super().__init__()
self._train_dataset = train_dataset
self.num_workers = num_workers
self.batch_size = batch_size
        self.collate_fn = collate_fn  # module-level collate_fn, defined below
def prepare_data(self):
# download, split, etc...
# only called on 1 GPU/TPU in distributed
pass
    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device."""
        # Make assignments here (val/train/test split);
        # called on every process in DDP.
        # No custom sampler is used here: with shuffle=True on multiple
        # devices, Lightning injects a DistributedSampler so that each
        # process draws its own shard of the mini-batch data.
        self.train_dataset = self._train_dataset
    def train_dataloader(self) -> DataLoader:
r"""Get train loader."""
train_loader = DataLoader(
dataset=self.train_dataset,
batch_size=self.batch_size,
collate_fn=self.collate_fn,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=False,
shuffle=True
)
return train_loader
def val_dataloader(self):
# val_split = Dataset(...)
# return DataLoader(val_split)
pass
def test_dataloader(self):
# test_split = Dataset(...)
# return DataLoader(test_split)
pass
    def teardown(self, stage: Optional[str] = None):
# clean up after fit or test
# called on every process in DDP
pass
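
# --- Hedged wiring sketch (illustrative, not part of the original module).
# Shows how this DataModule is typically handed to a Lightning Trainer; the
# AudioTextDataset constructor arguments and the `model` object below are
# placeholders, not the real API:
#
#     train_dataset = AudioTextDataset(...)  # construct per the real signature
#     data_module = DataModule(
#         train_dataset=train_dataset,
#         batch_size=8,
#         num_workers=4,
#     )
#     trainer = pl.Trainer(max_epochs=1)
#     trainer.fit(model, datamodule=data_module)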

def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            },
            ...
        ]

    Returns:
        data_dict: e.g.,
            'audio_text': {
                'text': ['a sound of dog', ...],
                'waveform': (batch_size, 1, samples)
            }
    """
    # Keep only the audio-text examples.
    at_list_data_dict = [
        data_dict for data_dict in list_data_dict
        if data_dict['modality'] == 'audio_text'
    ]

    at_data_dict = {}

    if len(at_list_data_dict) > 0:
        for key in at_list_data_dict[0].keys():
            # Gather this field from every example in the mini-batch.
            at_data_dict[key] = [data_dict[key] for data_dict in at_list_data_dict]
            if key == 'waveform':
                # Stack (1, samples) tensors into (batch_size, 1, samples).
                at_data_dict[key] = torch.stack(at_data_dict[key])

    data_dict = {
        'audio_text': at_data_dict
    }

    return data_dict
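

# --- Hedged sanity check (illustrative, not part of the original module).
# Runs two fake 'audio_text' items through collate_fn. The waveform shape
# (1, 32000) is an arbitrary mono example, not a value taken from the repo.
if __name__ == "__main__":
    fake_items = [
        {'text': 'a sound of dog', 'waveform': torch.zeros(1, 32000), 'modality': 'audio_text'},
        {'text': 'a sound of rain', 'waveform': torch.zeros(1, 32000), 'modality': 'audio_text'},
    ]
    batch = collate_fn(fake_items)
    print(batch['audio_text']['text'])            # ['a sound of dog', 'a sound of rain']
    print(batch['audio_text']['waveform'].shape)  # torch.Size([2, 1, 32000])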