from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset


class Text2SemanticDataModule(LightningDataModule):
    """LightningDataModule that wires the text-to-semantic dataset into train/val/test dataloaders."""

    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    def prepare_data(self):
        # Nothing to download or preprocess here; the phoneme and semantic
        # files already exist on disk and are read directly in setup().
        pass

    def setup(self, stage=None, output_logs=False):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        # No separate dev split is built: validation reuses the training
        # dataset, and the dev_* paths are kept for a future dedicated split.
        self._dev_dataset = self._train_dataset

    def train_dataloader(self):
        # Cap the batch size at a quarter of the dataset (with a floor of 1)
        # so that small datasets still yield several batches per epoch.
        batch_size = max(min(self.config["train"]["batch_size"], len(self._train_dataset) // 4), 1)
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            persistent_workers=True,
            prefetch_factor=16,
        )

    def val_dataloader(self):
        # Validation runs one utterance at a time over the shared dev set,
        # raising the worker count to at least 12 to keep the loader fed.
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers, 12),
            persistent_workers=True,
            prefetch_factor=16,
        )

    def test_dataloader(self):
        # The test loader mirrors the dev set; no separate test split exists.
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
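

# Example usage (a minimal sketch, not executed by this module): the config
# values and file paths below are hypothetical and must match your own
# preprocessing outputs and the formats expected by Text2SemanticDataset.
#
#     config = {
#         "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
#         "train": {"batch_size": 8},
#     }
#     data_module = Text2SemanticDataModule(
#         config,
#         train_semantic_path="logs/exp/semantic_train.tsv",
#         train_phoneme_path="logs/exp/phoneme_train.txt",
#     )
#     data_module.setup()
#     train_loader = data_module.train_dataloader()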