# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import random
from copy import deepcopy
from random import choice, shuffle
from typing import Sequence

from torch.utils.data import BatchSampler, Dataset, Sampler

from diffusion.utils.logger import get_root_logger


class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images with a similar aspect ratio into the same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        aspect_ratios: dict,
        drop_last: bool = False,
        config=None,
        valid_num=0,  # an aspect ratio is treated as valid when its sample count >= valid_num
        hq_only=False,
        cache_file=None,
        caching=False,
        **kwargs,
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError(f"sampler should be an instance of ``Sampler``, but got {sampler}")
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.hq_only = hq_only
        self.config = config
        self.caching = caching
        self.cache_file = cache_file
        self.order_check_pass = False
        self.ratio_nums_gt = kwargs.get("ratio_nums", None)
        assert self.ratio_nums_gt, "ratio_nums_gt must be provided."
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios.keys()}
        self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
        logger = (
            get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, "train_log.log"))
        )
        logger.warning(
            f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}"
        )
        self.data_all = {} if caching else None
        if cache_file is not None and os.path.exists(cache_file):
            logger.info(f"Loading cached file for multi-scale training: {cache_file}")
            try:
                with open(cache_file) as f:
                    self.cached_idx = json.load(f)
            except Exception:
                logger.info(f"Failed to load cache file: {cache_file}")
                self.cached_idx = {}
        else:
            logger.info(f"No cache file found at {cache_file}; the dataloader will be slower.")
            self.cached_idx = {}
        self.exist_ids = len(self.cached_idx)

    def __iter__(self) -> Sequence[int]:
        for idx in self.sampler:
            data_info, closest_ratio = self._get_data_info_and_ratio(idx)
            if not data_info:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                self._update_cache(bucket)
                yield bucket[:]
                del bucket[:]
        # yield any remaining, partially filled buckets
        for bucket in self._aspect_ratio_buckets.values():
            while bucket:
                if not self.drop_last:
                    yield bucket[:]
                # clear the bucket either way so the loop terminates when drop_last=True
                del bucket[:]

    def _get_data_info_and_ratio(self, idx):
        str_idx = str(idx)
        if self.caching:
            if str_idx in self.cached_idx:
                return self.cached_idx[str_idx], self.cached_idx[str_idx]["closest_ratio"]
            data_info = self.dataset.get_data_info(int(idx))
            if data_info is None or (
                self.hq_only and "version" in data_info and data_info["version"] not in ["high_quality"]
            ):
                return None, None
            closest_ratio = self._get_closest_ratio(data_info["height"], data_info["width"])
            self.data_all[str_idx] = {
                "height": data_info["height"],
                "width": data_info["width"],
                "closest_ratio": closest_ratio,
                "key": data_info["key"],
            }
            return data_info, closest_ratio
        else:
            if self.cached_idx:
                if self.cached_idx.get(str_idx):
                    if not self.order_check_pass or random.random() < 0.01:
                        # Ensure the cached dataset is in the same order as the original tar file
                        self._order_check(str_idx)
                    closest_ratio = self.cached_idx[str_idx]["closest_ratio"]
                    return self.cached_idx[str_idx], closest_ratio
            data_info = self.dataset.get_data_info(int(idx))
            if data_info is None or (
                self.hq_only and "version" in data_info and data_info["version"] not in ["high_quality"]
            ):
                return None, None
            closest_ratio = self._get_closest_ratio(data_info["height"], data_info["width"])
            return data_info, closest_ratio

    def _get_closest_ratio(self, height, width):
        ratio = height / width
        return min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
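
    # Illustrative example, assuming aspect_ratios is keyed by ratio strings such as
    # "0.5", "1.0" and "2.0": a 512x1024 image (height / width = 0.5) maps to the
    # "0.5" bucket, while a 1024x768 image (ratio ~1.33) maps to the "1.0" bucket.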

    def _order_check(self, str_idx):
        ori_data = self.cached_idx[str_idx]
        real_key = self.dataset.get_data_info(int(str_idx))["key"]
        assert real_key and ori_data["key"] == real_key, (
            f"index: {str_idx}, real key: {real_key}, cached key: {ori_data['key']}"
        )
        self.order_check_pass = True

    def _update_cache(self, bucket):
        if self.caching:
            for idx in bucket:
                if str(idx) in self.cached_idx:
                    continue
                self.cached_idx[str(idx)] = self.data_all.pop(str(idx))


class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assign samples to each bucket
        self.ratio_nums_gt = kwargs.get("ratio_nums", None)
        assert self.ratio_nums_gt
        self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
        self.original_buckets = {}
        self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
        self.all_available_keys = deepcopy(self.current_available_bucket_keys)
        self.exhausted_bucket_keys = []
        self.total_batches = len(self.sampler) // self.batch_size
        self._aspect_ratio_count = {}
        for k in self.all_available_keys:
            self._aspect_ratio_count[float(k)] = 0
            self.original_buckets[float(k)] = []
        logger = get_root_logger(os.path.join(self.config.work_dir, "train_log.log"))
        logger.warning(
            f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}"
        )

    def __iter__(self) -> Sequence[int]:
        i = 0
        for idx in self.sampler:
            data_info = self.dataset.get_data_info(idx)
            height, width = data_info["height"], data_info["width"]
            ratio = height / width
            closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
            if closest_ratio not in self.all_available_keys:
                continue
            if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
                self._aspect_ratio_count[closest_ratio] += 1
                self._aspect_ratio_buckets[closest_ratio].append(idx)
                self.original_buckets[closest_ratio].append(idx)  # Save the original samples for each bucket
            if not self.current_available_bucket_keys:
                self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            key = closest_ratio
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) == self.batch_size:
                yield bucket[: self.batch_size]
                del bucket[: self.batch_size]
                i += 1
                self.exhausted_bucket_keys.append(key)
                self.current_available_bucket_keys.remove(key)

        # Top up to total_batches by resampling from the already-filled buckets
        for _ in range(self.total_batches - i):
            key = choice(self.all_available_keys)
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) >= self.batch_size:
                yield bucket[: self.batch_size]
                del bucket[: self.batch_size]
                # If the bucket is exhausted, refill it from the saved originals and reshuffle
                if not bucket:
                    self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                    shuffle(self._aspect_ratio_buckets[key])
            else:
                # Not enough samples left for a full batch: refill and reshuffle, yield nothing this round
                self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                shuffle(self._aspect_ratio_buckets[key])
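

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the toy dataset, the aspect-ratio table and
    # the ratio_nums counts below are assumptions made up for this example; in real training
    # they come from the data pipeline and the model config.
    from torch.utils.data import RandomSampler

    class _ToyDataset(Dataset):
        """Tiny in-memory dataset exposing the get_data_info() hook the sampler relies on."""

        def __init__(self):
            self.items = []
            for i in range(32):
                # alternate wide images (ratio 0.5) and square images (ratio 1.0)
                h, w = (512, 1024) if i % 2 == 0 else (1024, 1024)
                self.items.append({"height": h, "width": w, "key": f"sample_{i}", "version": "high_quality"})

        def __len__(self):
            return len(self.items)

        def get_data_info(self, idx):
            return self.items[idx]

    dataset = _ToyDataset()
    # hypothetical bucket table: ratio string -> target (height, width)
    aspect_ratios = {"0.5": [512, 1024], "1.0": [1024, 1024], "2.0": [1024, 512]}
    batch_sampler = AspectRatioBatchSampler(
        sampler=RandomSampler(dataset),
        dataset=dataset,
        batch_size=4,
        aspect_ratios=aspect_ratios,
        drop_last=False,
        ratio_nums={"0.5": 16, "1.0": 16, "2.0": 0},  # assumed per-bucket sample counts
    )
    for batch in batch_sampler:
        # every index in a batch belongs to the same aspect-ratio bucket
        print(batch)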