from copy import deepcopy
from dataclasses import dataclass

import lightning.pytorch as pl

import torch
from torch import LongTensor
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Tuple, Union, Callable
import os

import numpy as np

from .raw_data import RawData
from .asset import Asset
from .transform import TransformConfig, transform_asset
from .datapath import DatapathConfig, Datapath
from .spec import ConfigSpec

from ..tokenizer.spec import TokenizerSpec, TokenizerConfig
from ..tokenizer.parse import get_tokenizer
from ..model.spec import ModelInput
|
|
@dataclass
class DatasetConfig(ConfigSpec):
    '''
    Config controlling how a dataset is loaded and batched.
    '''
    shuffle: bool
    batch_size: int
    num_workers: int
    datapath_config: DatapathConfig
    pin_memory: bool = True
    persistent_workers: bool = True

    @classmethod
    def parse(cls, config) -> 'DatasetConfig':
        cls.check_keys(config)
        return DatasetConfig(
            shuffle=config.shuffle,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            persistent_workers=config.persistent_workers,
            datapath_config=DatapathConfig.parse(config.datapath_config),
        )

    def split_by_cls(self) -> Dict[str, 'DatasetConfig']:
        res: Dict[str, DatasetConfig] = {}
        datapath_config_dict = self.datapath_config.split_by_cls()
        for cls in self.datapath_config.data_path:
            res[cls] = deepcopy(self)
            res[cls].datapath_config = datapath_config_dict[cls]
        return res
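# A minimal usage sketch (hedged): `raw_cfg` is a hypothetical attribute-style
# config node (e.g. loaded via OmegaConf), and 'vroid' is a hypothetical class
# name; real keys come from `datapath_config.data_path`.
#
#   cfg = DatasetConfig.parse(raw_cfg)
#   per_cls = cfg.split_by_cls()    # same loader settings, per-class paths
#   vroid_cfg = per_cls['vroid']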
|
|
class UniRigDatasetModule(pl.LightningDataModule):
    def __init__(
        self,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        predict_dataset_config: Union[Dict[str, DatasetConfig], None]=None,
        predict_transform_config: Union[TransformConfig, None]=None,
        tokenizer_config: Union[TokenizerConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
        datapath: Union[Datapath, None]=None,
        cls: Union[str, None]=None,
    ):
        super().__init__()
        self.process_fn = process_fn
        self.predict_dataset_config = predict_dataset_config
        self.predict_transform_config = predict_transform_config
        self.tokenizer_config = tokenizer_config
        self.debug = debug
        self.data_name = data_name

        if debug:
            print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")

        if datapath is not None:
            # An explicit datapath overrides any configured prediction data:
            # build a single-class predict set with a minimal loader config.
            assert cls is not None, 'cls must be given when datapath is set'
            self.train_datapath = None
            self.validate_datapath = None
            self.predict_datapath = {
                cls: deepcopy(datapath),
            }
            self.predict_dataset_config = {
                cls: DatasetConfig(
                    shuffle=False,
                    batch_size=1,
                    num_workers=0,
                    # note: stores the Datapath itself; only the loader
                    # settings above are consumed at predict time
                    datapath_config=deepcopy(datapath),
                    pin_memory=False,
                    persistent_workers=False,
                )
            }
        elif self.predict_dataset_config is not None:
            self.predict_datapath = {
                cls: Datapath(self.predict_dataset_config[cls].datapath_config)
                for cls in self.predict_dataset_config
            }
        else:
            self.predict_datapath = None

        if tokenizer_config is None:
            self.tokenizer = None
        else:
            self.tokenizer = get_tokenizer(config=tokenizer_config)
|
    def prepare_data(self):
        pass

    def setup(self, stage=None):
        if self.predict_datapath is not None:
            self._predict_ds = {}
            for cls in self.predict_datapath:
                self._predict_ds[cls] = UniRigDataset(
                    process_fn=self.process_fn,
                    data=self.predict_datapath[cls].get_data(),
                    name=f"predict-{cls}",
                    tokenizer=self.tokenizer,
                    transform_config=self.predict_transform_config,
                    debug=self.debug,
                    data_name=self.data_name,
                )
|
    def predict_dataloader(self):
        # setup() may not have run yet when predicting standalone
        if not hasattr(self, "_predict_ds"):
            self.setup()
        return self._create_dataloader(
            dataset=self._predict_ds,
            config=self.predict_dataset_config,
            is_train=False,
            drop_last=False,
        )
|
    def _create_dataloader(
        self,
        dataset: Union[Dataset, Dict[str, Dataset]],
        config: Union[DatasetConfig, Dict[str, DatasetConfig]],
        is_train: bool,
        **kwargs,
    ) -> Union[DataLoader, Dict[str, DataLoader]]:
        def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
            return DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=config.shuffle,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory,
                persistent_workers=config.persistent_workers,
                collate_fn=dataset.collate_fn,
                **kwargs,
            )
        if isinstance(dataset, dict):
            return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
        else:
            return create_single_dataloader(dataset, config, **kwargs)
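# A construction sketch (hedged): `my_process_fn`, `predict_cfgs`,
# `transform_cfg`, and `tok_cfg` are hypothetical placeholders.
# `my_process_fn` must map List[ModelInput] -> Dict of collatable arrays.
#
#   dm = UniRigDatasetModule(
#       process_fn=my_process_fn,
#       predict_dataset_config=predict_cfgs,    # Dict[str, DatasetConfig]
#       predict_transform_config=transform_cfg,
#       tokenizer_config=tok_cfg,
#   )
#   loaders = dm.predict_dataloader()           # Dict[str, DataLoader]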
|
|
class UniRigDataset(Dataset):
    def __init__(
        self,
        data: List[Tuple[str, str]],  # (cls, dir_path) pairs
        name: str,
        process_fn: Union[Callable[[List[ModelInput]], Dict], None]=None,
        tokenizer: Union[TokenizerSpec, None]=None,
        transform_config: Union[TransformConfig, None]=None,
        debug: bool=False,
        data_name: str='raw_data.npz',
    ) -> None:
        super().__init__()

        self.data = data
        self.name = name
        self.process_fn = process_fn
        self.tokenizer = tokenizer
        self.transform_config = transform_config
        self.debug = debug
        self.data_name = data_name

        if not debug:
            assert self.process_fn is not None, 'missing data processing function'
|
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
        cls, dir_path = self.data[idx]
        raw_data = RawData.load(path=os.path.join(dir_path, self.data_name))
        asset = Asset.from_raw_data(raw_data=raw_data, cls=cls, path=dir_path, data_name=self.data_name)

        # apply configured transforms; the returned augment records are
        # not used here
        first_augments, second_augments = transform_asset(
            asset=asset,
            transform_config=self.transform_config,
        )
        if self.tokenizer is not None and asset.parents is not None:
            tokens = self.tokenizer.tokenize(input=asset.get_tokenize_input())
        else:
            tokens = None
        return ModelInput(
            tokens=tokens,
            pad=None if self.tokenizer is None else self.tokenizer.pad,
            vertices=asset.sampled_vertices.astype(np.float32),
            normals=asset.sampled_normals.astype(np.float32),
            joints=None if asset.joints is None else asset.joints.astype(np.float32),
            tails=None if asset.tails is None else asset.tails.astype(np.float32),
            asset=asset,
            augments=None,
        )
|
    def _collate_fn_debug(self, batch):
        # debug mode: hand back the raw ModelInput list uncollated
        return batch

    def _collate_fn(self, batch):
        return data.dataloader.default_collate(self.process_fn(batch))

    def collate_fn(self, batch):
        if self.debug:
            return self._collate_fn_debug(batch)
        return self._collate_fn(batch)
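# An end-to-end sketch (hedged): `pairs` and `my_process_fn` are hypothetical;
# `pairs` is a list of (cls, dir_path) tuples where each directory contains a
# `raw_data.npz` file.
#
#   ds = UniRigDataset(data=pairs, name='predict-demo', process_fn=my_process_fn)
#   dl = DataLoader(ds, batch_size=1, collate_fn=ds.collate_fn)
#   batch = next(iter(dl))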