|
from typing import Dict, List, Optional, NoReturn |
|
import torch |
|
import lightning.pytorch as pl |
|
from torch.utils.data import DataLoader |
|
from data.audiotext_dataset import AudioTextDataset |
|
|
|
|
|
class DataModule(pl.LightningDataModule):

    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int
    ) -> None:
        r"""Data module wrapping a single training dataset.

        To get one batch of data:

        .. code-block:: python

            data_module.setup()

            for batch_data_dict in data_module.train_dataloader():
                print(batch_data_dict.keys())
                break

        Args:
            train_dataset: Dataset object; each item is expected to be a dict
                with 'text', 'waveform', and 'modality' keys (see ``collate_fn``)
            batch_size: int, mini-batch size for the train loader
            num_workers: int, number of DataLoader worker processes
        """
        super().__init__()
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Module-level collate_fn groups samples by modality into batches.
        self.collate_fn = collate_fn

    def prepare_data(self) -> None:
        # Nothing to download or preprocess.
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device; assigns the training dataset."""
        # NOTE: annotated -> None (not NoReturn); NoReturn means the callable
        # never returns at all, which is not the case for Lightning hooks.
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader.

        Returns:
            DataLoader yielding batches produced by ``collate_fn``.
        """
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True
        )

        return train_loader

    def val_dataloader(self):
        # No validation set is configured for this module.
        pass

    def test_dataloader(self):
        # No test set is configured for this module.
        pass

    def teardown(self, stage: Optional[str] = None) -> None:
        # Accept the 'stage' argument Lightning's Trainer passes to teardown
        # hooks; the original signature (no 'stage') would raise TypeError
        # when called as teardown(stage=...). Default of None keeps direct
        # no-argument calls working.
        pass
|
|
|
|
|
def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Only items whose 'modality' is 'audio_text' are kept; all other items
    are dropped.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            }
            ...
        ]

    Returns:
        data_dict: e.g.
            {
                'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples)
                }
            }
        If no 'audio_text' items are present, 'audio_text' maps to an
        empty dict.
    """
    at_list_data_dict = [
        data_dict for data_dict in list_data_dict
        if data_dict['modality'] == 'audio_text'
    ]

    at_data_dict = {}

    if len(at_list_data_dict) > 0:
        for key in at_list_data_dict[0].keys():
            # Use a distinct loop variable; the original comprehension
            # shadowed the output dict name 'at_data_dict', which only
            # worked by accident of comprehension scoping.
            values = [data_dict[key] for data_dict in at_list_data_dict]
            if key == 'waveform':
                # Stack per-example (1, samples) tensors into
                # (batch_size, 1, samples); assumes all waveforms share the
                # same length — TODO confirm upstream padding/cropping.
                at_data_dict[key] = torch.stack(values)
            else:
                # Non-tensor fields (e.g. 'text', 'modality') stay as plain
                # lists. The original's explicit 'text' branch was a no-op
                # identity copy and has been removed.
                at_data_dict[key] = values

    data_dict = {
        'audio_text': at_data_dict
    }

    return data_dict