# Copyright (c) OpenMMLab. All rights reserved.
import os
from copy import deepcopy
from random import choice, shuffle
from typing import Sequence

from torch.utils.data import BatchSampler, Dataset, Sampler

from DiT_VAE.diffusion.utils.logger import get_root_logger


class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images with a similar aspect ratio into the same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 aspect_ratios: dict,
                 drop_last: bool = False,
                 config=None,
                 valid_num=0,  # treat an aspect ratio as valid when its sample count >= valid_num
                 **kwargs) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.ratio_nums_gt = kwargs.get('ratio_nums', None)  # mapping: aspect ratio -> number of samples
        self.config = config
        assert self.ratio_nums_gt
        # one index bucket per predefined aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        # keep only the ratios that have at least `valid_num` samples
        self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")

    def __iter__(self) -> Sequence[int]:
        for idx in self.sampler:
            data_info = self.dataset.get_data_info(idx)
            height, width = data_info['height'], data_info['width']
            ratio = height / width
            # find the closest predefined aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # yield whatever is left in the buckets at the end of the epoch
        for bucket in self._aspect_ratio_buckets.values():
            while len(bucket) > 0:
                if len(bucket) <= self.batch_size:
                    if not self.drop_last:
                        yield bucket[:]
                    bucket = []
                else:
                    yield bucket[:self.batch_size]
                    bucket = bucket[self.batch_size:]
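
# Usage sketch (an illustration, not part of the original pipeline): one plausible way
# to plug AspectRatioBatchSampler into a torch DataLoader via `batch_sampler=`. The
# ratio table, the `ratio_nums` counts and `my_dataset` are hypothetical placeholders;
# the dataset only needs to expose `get_data_info(idx) -> {'height': ..., 'width': ...}`.
#
#   from torch.utils.data import DataLoader, RandomSampler
#
#   aspect_ratios = {'0.5': [384., 768.], '1.0': [512., 512.], '2.0': [768., 384.]}
#   ratio_nums = {0.5: 1200, 1.0: 4800, 2.0: 900}   # samples available per ratio
#   batch_sampler = AspectRatioBatchSampler(
#       sampler=RandomSampler(my_dataset),
#       dataset=my_dataset,
#       batch_size=16,
#       aspect_ratios=aspect_ratios,
#       drop_last=True,
#       ratio_nums=ratio_nums,
#       valid_num=1000,   # ratios with fewer than 1000 samples are skipped
#   )
#   loader = DataLoader(my_dataset, batch_sampler=batch_sampler, num_workers=4)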

class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assign samples to each bucket
        self.ratio_nums_gt = kwargs.get('ratio_nums', None)
        assert self.ratio_nums_gt
        self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
        self.original_buckets = {}
        # only ratios backed by at least 3000 samples take part in the balanced sampling
        self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
        self.all_available_keys = deepcopy(self.current_available_bucket_keys)
        self.exhausted_bucket_keys = []
        self.total_batches = len(self.sampler) // self.batch_size
        self._aspect_ratio_count = {}
        for k in self.all_available_keys:
            self._aspect_ratio_count[float(k)] = 0
            self.original_buckets[float(k)] = []
        logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log'))
        logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")

    def __iter__(self) -> Sequence[int]:
        i = 0
        for idx in self.sampler:
            data_info = self.dataset.get_data_info(idx)
            height, width = data_info['height'], data_info['width']
            ratio = height / width
            closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
            if closest_ratio not in self.all_available_keys:
                continue
            if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
                self._aspect_ratio_count[closest_ratio] += 1
                self._aspect_ratio_buckets[closest_ratio].append(idx)
                self.original_buckets[closest_ratio].append(idx)  # save the original samples for each bucket
            if not self.current_available_bucket_keys:
                # every available ratio has yielded a batch in this round; start a new round
                self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            key = closest_ratio
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) == self.batch_size:
                yield bucket[:self.batch_size]
                del bucket[:self.batch_size]
                i += 1
                # a ratio that just yielded waits until the other ratios catch up
                self.exhausted_bucket_keys.append(key)
                self.current_available_bucket_keys.remove(key)

        # top up to `total_batches` batches by re-sampling from the collected buckets
        for _ in range(self.total_batches - i):
            key = choice(self.all_available_keys)
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) >= self.batch_size:
                yield bucket[:self.batch_size]
                del bucket[:self.batch_size]

                # if a bucket is exhausted, refill it from the saved samples and reshuffle
                if not bucket:
                    self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                    shuffle(self._aspect_ratio_buckets[key])
            else:
                self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                shuffle(self._aspect_ratio_buckets[key])
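

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original training code). The dummy
    # dataset, the aspect-ratio table and every number below are illustrative assumptions
    # chosen only to exercise AspectRatioBatchSampler end to end.
    from torch.utils.data import RandomSampler

    class _DummyDataset(Dataset):
        """Tiny stand-in dataset exposing the `get_data_info` interface the sampler expects."""

        def __init__(self, num_samples=64):
            # random portrait / square / landscape image sizes
            self.sizes = [choice([(384, 768), (512, 512), (768, 384)]) for _ in range(num_samples)]

        def __len__(self):
            return len(self.sizes)

        def __getitem__(self, idx):
            return idx  # a real dataset would return an image or latent here

        def get_data_info(self, idx):
            height, width = self.sizes[idx]
            return {'height': height, 'width': width}

    dataset = _DummyDataset()
    # hypothetical ratio table: key is the ratio, value a target (height, width)
    aspect_ratios = {'0.5': [384., 768.], '1.0': [512., 512.], '2.0': [768., 384.]}
    ratio_nums = {0.5: 20, 1.0: 24, 2.0: 20}  # rough per-ratio sample counts

    batch_sampler = AspectRatioBatchSampler(
        sampler=RandomSampler(dataset),
        dataset=dataset,
        batch_size=4,
        aspect_ratios=aspect_ratios,
        drop_last=False,
        ratio_nums=ratio_nums,
        valid_num=10,
    )
    for batch_indices in batch_sampler:
        # every index in a batch falls into the same aspect-ratio bucket
        print(batch_indices)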