Spaces:
Sleeping
Sleeping
| # ------------------------------------------------------------------------------ | |
| # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport | |
| # Copyright (c) 2024 Borui Zhang. All Rights Reserved. | |
| # Licensed under the MIT License [see LICENSE for details] | |
| # ------------------------------------------------------------------------------ | |
| from typing import Union | |
| import torch | |
| from torch.utils.data import Subset | |
| def maybe_get_subset(dataset, subset_size: Union[int, float] = None, num_data_repeat: int = None): | |
| """ | |
| num_data_repeat is aimed at avoiding the consuming loading time for the small dataset | |
| """ | |
| if subset_size is None: | |
| return dataset | |
| else: | |
| if subset_size < 1.0: | |
| subset_size = len(dataset) * subset_size | |
| subset_size = int(subset_size) | |
| selected_indices = torch.randperm(len(dataset))[:subset_size] | |
| if num_data_repeat is not None: | |
| selected_indices = selected_indices.repeat(num_data_repeat) | |
| return Subset(dataset, selected_indices) | |
| class LoaderWrapper: | |
| """ | |
| write a dataloader class, given total steps, recursively loading data | |
| """ | |
| def __init__(self, loader, total_iterations: int = None): | |
| self.loader = loader | |
| self.total_iterations = total_iterations if total_iterations is not None else len(loader) | |
| def __iter__(self): | |
| self.generator = iter(self.loader) | |
| self.counter = 0 | |
| return self | |
| def __next__(self): | |
| if self.counter >= self.total_iterations: | |
| self.counter = 0 | |
| raise StopIteration | |
| else: | |
| self.counter += 1 | |
| try: | |
| return next(self.generator) | |
| except StopIteration: | |
| self.generator = iter(self.loader) | |
| return next(self.generator) |