from copy import deepcopy
from dataclasses import dataclass

import lightning.pytorch as pl
# from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import torch
from torch import LongTensor
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Union, Callable
import os
import numpy as np

from .raw_data import RawData
from .asset import Asset
from .transform import TransformConfig, transform_asset
from .datapath import DatapathConfig, Datapath
from .spec import ConfigSpec
from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
from ..tokenizer.parse import get_tokenizer
from ..model.spec import ModelInput
@dataclass
class DatasetConfig(ConfigSpec):
    '''
    Config to handle dataset format.
    '''
    # shuffle dataset
    shuffle: bool
    # batch size
    batch_size: int
    # number of workers
    num_workers: int
    # datapath
    datapath_config: DatapathConfig
    # use pin memory
    pin_memory: bool = True
    # use persistent workers
    persistent_workers: bool = True

    @classmethod
    def parse(cls, config) -> 'DatasetConfig':
        # validate the incoming config keys, then build the dataclass
        cls.check_keys(config)
        return DatasetConfig(
            shuffle=config.shuffle,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            persistent_workers=config.persistent_workers,
            datapath_config=DatapathConfig.parse(config.datapath_config),
        )

    def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
        # one DatasetConfig per data class, each with its own datapath slice
        res: Dict[str, DatasetConfig] = {}
        datapath_config_dict = self.datapath_config.split_by_cls()
        for cls in self.datapath_config.data_path:
            res[cls] = deepcopy(self)
            res[cls].datapath_config = datapath_config_dict[cls]
        return res
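
# Rough usage sketch for DatasetConfig (not part of the module's API). It assumes
# an OmegaConf/namespace-like `cfg` whose attribute names match the fields above,
# including a `datapath_config` section understood by DatapathConfig.parse;
# `cfg` itself is a placeholder for whatever the surrounding pipeline loads.
#
#   dataset_config = DatasetConfig.parse(cfg)   # validate keys, build the config
#   per_cls = dataset_config.split_by_cls()     # {'<cls>': DatasetConfig(...), ...}
#   for cls_name, cls_config in per_cls.items():
#       print(cls_name, cls_config.batch_size)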

class UniRigDatasetModule(pl.LightningDataModule):
    def __init__(
        self,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
        predict_transform_config: Union[TransformConfig, None]=None,
        tokenizer_config: Union[TokenizerConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
        datapath: Union[Datapath, None]=None,
        cls: Union[str, None]=None,
    ):
        super().__init__()
        self.process_fn = process_fn
        self.predict_dataset_config = predict_dataset_config
        self.predict_transform_config = predict_transform_config
        self.tokenizer_config = tokenizer_config
        self.debug = debug
        self.data_name = data_name
        if debug:
            print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")

        if datapath is not None:
            # a single explicit datapath: build a one-class predict setup
            self.train_datapath = None
            self.validate_datapath = None
            self.predict_datapath = {
                cls: deepcopy(datapath),
            }
            self.predict_dataset_config = {
                cls: DatasetConfig(
                    shuffle=False,
                    batch_size=1,
                    num_workers=0,
                    datapath_config=deepcopy(datapath),
                    pin_memory=False,
                    persistent_workers=False,
                )
            }
        else:
            # build predict datapath
            if self.predict_dataset_config is not None:
                self.predict_datapath = {
                    cls: Datapath(self.predict_dataset_config[cls].datapath_config)
                    for cls in self.predict_dataset_config
                }
            else:
                self.predict_datapath = None

        # get tokenizer
        if tokenizer_config is None:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(config=tokenizer_config)
    def prepare_data(self):
        pass

    def setup(self, stage=None):
        if self.predict_datapath is not None:
            self._predict_ds = {}
            for cls in self.predict_datapath:
                self._predict_ds[cls] = UniRigDataset(
                    process_fn=self.process_fn,
                    data=self.predict_datapath[cls].get_data(),
                    name=f"predict-{cls}",
                    tokenizer=self.tokenizer,
                    transform_config=self.predict_transform_config,
                    debug=self.debug,
                    data_name=self.data_name,
                )

    def predict_dataloader(self):
        if not hasattr(self, "_predict_ds"):
            self.setup()
        return self._create_dataloader(
            dataset=self._predict_ds,
            config=self.predict_dataset_config,
            is_train=False,
            drop_last=False,
        )
    def _create_dataloader(
        self,
        dataset: Union[Dataset, Dict[str, Dataset]],
        config: Union[DatasetConfig, Dict[str, DatasetConfig]],
        is_train: bool,
        **kwargs,
    ) -> Union[DataLoader, Dict[str, DataLoader]]:
        def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
            return DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=config.shuffle,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory,
                persistent_workers=config.persistent_workers,
                collate_fn=dataset.collate_fn,
                **kwargs,
            )
        if isinstance(dataset, Dict):
            # one dataloader per data class, keyed the same way as the datasets
            return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
        else:
            return create_single_dataloader(dataset, config, **kwargs)
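
# Minimal, hedged sketch of wiring UniRigDatasetModule into a predict run.
# `my_process_fn`, `my_transform_config`, `my_tokenizer_config` and `my_datapath`
# are placeholders for objects built elsewhere in the repo; only the datamodule
# calls below come from this file.
#
#   datamodule = UniRigDatasetModule(
#       process_fn=my_process_fn,                      # batches List[ModelInput] -> Dict of tensors
#       predict_transform_config=my_transform_config,  # optional TransformConfig
#       tokenizer_config=my_tokenizer_config,          # optional; enables tokenization
#       datapath=my_datapath,                          # single Datapath => one predict class
#       cls='my_cls',                                  # key used in the predict dataloader dict
#   )
#   datamodule.setup()
#   loaders = datamodule.predict_dataloader()          # {'my_cls': DataLoader}
#   for batch in loaders['my_cls']:
#       ...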

class UniRigDataset(Dataset):
    def __init__(
        self,
        data: List[Tuple[str, str]],  # (cls, part)
        name: str,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        tokenizer: Union[TokenizerSpec, None]=None,
        transform_config: Union[TransformConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
    ) -> None:
        super().__init__()
        self.data = data
        self.name = name
        self.process_fn = process_fn
        self.tokenizer = tokenizer
        self.transform_config = transform_config
        self.debug = debug
        self.data_name = data_name
        if not debug:
            assert self.process_fn is not None, 'missing data processing function'

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
        cls, dir_path = self.data[idx]
        raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
        asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)
        first_augments, second_augments = transform_asset(
            asset=asset,
            transform_config=self.transform_config,
        )
        if self.tokenizer is not None and asset.parents is not None:
            tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
        else:
            tokens = None
        return ModelInput(
            tokens=tokens,
            pad=None if self.tokenizer is None else self.tokenizer.pad,
            vertices=asset.sampled_vertices.astype(np.float32),
            normals=asset.sampled_normals.astype(np.float32),
            joints=None if asset.joints is None else asset.joints.astype(np.float32),
            tails=None if asset.tails is None else asset.tails.astype(np.float32),
            asset=asset,
            augments=None,
        )

    def _collate_fn_debug(self, batch):
        return batch

    def _collate_fn(self, batch):
        return data.dataloader.default_collate(self.process_fn(batch))

    def collate_fn(self, batch):
        if self.debug:
            return self._collate_fn_debug(batch)
        return self._collate_fn(batch)
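
# Hedged sketch of using UniRigDataset directly, outside the LightningDataModule.
# `samples` is a hypothetical list of (cls, directory) pairs, each directory holding
# a `raw_data.npz` produced by the repo's preprocessing; in debug mode no process_fn
# is required and collate_fn simply returns the list of ModelInput objects. Note that
# the dataset supplies its own collate_fn, which must be passed to the DataLoader
# explicitly (as _create_dataloader above does).
#
#   ds = UniRigDataset(
#       data=samples,                 # e.g. [('<cls>', '/path/to/item_0'), ...]
#       name='predict-debug',
#       debug=True,                   # skip process_fn; yield raw ModelInput lists
#   )
#   loader = DataLoader(ds, batch_size=1, shuffle=False, collate_fn=ds.collate_fn)
#   for batch in loader:
#       ...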