#!/usr/bin/env python3
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import pytorch_lightning as pl
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that builds train/val/test datasets from configs.

    Each of ``train`` / ``validation`` / ``test`` is an iterable of dataset
    configs understood by ``instantiate_from_config``; for every split, the
    instantiated datasets are concatenated into one ``ConcatDataset`` that
    backs the corresponding dataloader.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        """
        Args:
            batch_size: batch size for the train and test dataloaders
                (the val dataloader uses a fixed batch size of 4).
            num_workers: worker processes per dataloader.
            train: iterable of train dataset configs, or None to skip.
            validation: iterable of validation dataset configs, or None.
            test: iterable of test dataset configs, or None.
            **kwargs: ignored; accepted so extra config keys don't error.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Only splits whose config was provided are registered here.
        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs["train"] = train
        if validation is not None:
            self.dataset_configs["validation"] = validation
        if test is not None:
            self.dataset_configs["test"] = test

    def setup(self, stage):
        """Instantiate every configured dataset.

        All configured splits are built regardless of ``stage``, so that
        ``trainer.validate()`` and ``trainer.test()`` work in addition to
        ``fit``. (The previous implementation raised ``NotImplementedError``
        for any stage other than "fit", which broke those entry points even
        though val/test dataloaders are defined.)
        """
        # Imported lazily to avoid a hard project dependency at module import.
        from src.utils.train_util import instantiate_from_config

        self.datasets = {
            split: [instantiate_from_config(cfg) for cfg in configs]
            for split, configs in self.dataset_configs.items()
        }
        print(self.datasets)

    def train_dataloader(self):
        """Distributed training dataloader over the concatenated train sets."""
        datasets = ConcatDataset(self.datasets["train"])
        # DistributedSampler shuffles by default; DataLoader's shuffle must
        # therefore stay False (mutually exclusive with a sampler).
        sampler = DistributedSampler(datasets)
        return DataLoader(
            datasets,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=sampler,
            prefetch_factor=2,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Distributed validation dataloader (fixed batch size of 4)."""
        datasets = ConcatDataset(self.datasets["validation"])
        # NOTE(review): DistributedSampler defaults to shuffle=True, so the
        # validation set is shuffled despite DataLoader shuffle=False —
        # presumably intentional; confirm or pass shuffle=False to the sampler.
        sampler = DistributedSampler(datasets)
        return DataLoader(datasets, batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def test_dataloader(self):
        """Plain (non-distributed) test dataloader."""
        datasets = ConcatDataset(self.datasets["test"])
        return DataLoader(datasets, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)