# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py # reference: https://github.com/lifeiteng/vall-e from pytorch_lightning import LightningDataModule from AR.data.bucket_sampler import DistributedBucketSampler from AR.data.dataset import Text2SemanticDataset from torch.utils.data import DataLoader class Text2SemanticDataModule(LightningDataModule): def __init__( self, config, train_semantic_path, train_phoneme_path, dev_semantic_path=None, dev_phoneme_path=None, ): super().__init__() self.config = config self.train_semantic_path = train_semantic_path self.train_phoneme_path = train_phoneme_path self.dev_semantic_path = dev_semantic_path self.dev_phoneme_path = dev_phoneme_path self.num_workers = self.config["data"]["num_workers"] def prepare_data(self): pass def setup(self, stage=None, output_logs=False): self._train_dataset = Text2SemanticDataset( phoneme_path=self.train_phoneme_path, semantic_path=self.train_semantic_path, max_sec=self.config["data"]["max_sec"], pad_val=self.config["data"]["pad_val"], ) self._dev_dataset = self._train_dataset # self._dev_dataset = Text2SemanticDataset( # phoneme_path=self.dev_phoneme_path, # semantic_path=self.dev_semantic_path, # max_sample=self.config['data']['max_eval_sample'], # max_sec=self.config['data']['max_sec'], # pad_val=self.config['data']['pad_val']) def train_dataloader(self): batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"] batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存 sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) return DataLoader( self._train_dataset, batch_size=batch_size, sampler=sampler, collate_fn=self._train_dataset.collate, num_workers=self.num_workers, persistent_workers=True, prefetch_factor=16, ) def val_dataloader(self): return DataLoader( self._dev_dataset, batch_size=1, shuffle=False, collate_fn=self._train_dataset.collate, num_workers=max(self.num_workers, 12), persistent_workers=True, prefetch_factor=16, ) # 这个会使用到嘛? def test_dataloader(self): return DataLoader( self._dev_dataset, batch_size=1, shuffle=False, collate_fn=self._train_dataset.collate, )