Spaces:
No application file
No application file
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import random | |
| from functools import partial | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.data as data | |
| import torchvision | |
| import torchvision.transforms.v2 as VT | |
| from torch.utils.data import default_collate | |
| from torchvision.transforms.v2 import InterpolationMode | |
| from torchvision.transforms.v2 import functional as VF | |
| from ..core import register | |
# Presumably silences torchvision's warning about the beta transforms API
# (``torchvision.transforms.v2`` is imported above), judging by the name.
# NOTE(review): confirm this function still exists in the pinned torchvision
# version; it is called unconditionally at import time.
torchvision.disable_beta_transforms_warning()
# Names exported by ``from <this module> import *``. ``generate_scales`` is
# defined in this module but intentionally (or not) left out of this list.
__all__ = [
    "DataLoader",
    "BaseCollateFunction",
    "BatchImageCollateFunction",
    "batch_image_collate_fn",
]
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` that propagates the epoch and exposes ``shuffle``.

    ``set_epoch`` stores the epoch on the loader and forwards it to both the
    dataset and the collate function, so each of those must provide a
    ``set_epoch(epoch)`` method.

    NOTE(review): ``register`` is imported from ``..core`` but not applied to
    anything visible in this module; if this loader is meant to be built by the
    config system it probably needs a ``@register()`` decorator — confirm.
    """

    # Constructor arguments listed here are presumably resolved/injected by the
    # project's config machinery in ``..core`` — TODO confirm.
    __inject__ = ["dataset", "collate_fn"]

    def __repr__(self) -> str:
        """Return the class name and one `` name: value`` line per key setting."""
        fields = ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]
        body = "".join(f"\n {name}: {getattr(self, name)}" for name in fields)
        return f"{self.__class__.__name__}({body}\n)"

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to the dataset and the collate function."""
        self._epoch = epoch
        self.dataset.set_epoch(epoch)
        self.collate_fn.set_epoch(epoch)

    @property
    def epoch(self):
        """Epoch last passed to ``set_epoch``, or ``-1`` if it was never called."""
        return getattr(self, "_epoch", -1)

    @property
    def shuffle(self):
        """Value last assigned through the ``shuffle`` setter."""
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle):
        # ``assert`` kept so callers still see the original AssertionError.
        assert isinstance(shuffle, bool), "shuffle must be a boolean"
        self._shuffle = shuffle
def batch_image_collate_fn(items):
    """Collate ``(image, target)`` samples, batching only the images.

    Each ``sample[0]`` gets a new leading dimension and the results are
    concatenated along it; every ``sample[1]`` is passed through untouched,
    collected into a list in sample order.

    NOTE(review): ``register`` is imported from ``..core`` but not applied to
    anything visible in this module; if this function is meant to be looked
    up by the config system it may need ``@register()`` — confirm.
    """
    batch = torch.cat([sample[0][None] for sample in items], dim=0)
    labels = [sample[1] for sample in items]
    return batch, labels
class BaseCollateFunction(object):
    """Base class for epoch-aware collate callables.

    ``set_epoch`` stores the epoch and ``epoch`` (a read-only property) reports
    it back; subclasses must implement ``__call__``.
    """

    def set_epoch(self, epoch):
        """Store ``epoch`` so it can be read back through ``epoch``."""
        self._epoch = epoch

    @property
    def epoch(self):
        """Epoch last passed to ``set_epoch``, or ``-1`` if it was never called.

        Exposed as a property so subclasses can compare it directly with ints
        (e.g. ``self.epoch < self.stop_epoch``).
        """
        return getattr(self, "_epoch", -1)

    def __call__(self, items):
        """Collate ``items``; must be overridden by subclasses."""
        raise NotImplementedError("")
def generate_scales(base_size, base_size_repeat):
    """Build the list of candidate training sizes around ``base_size``.

    The result is three runs concatenated in this order:

    * ascending, starting at ``base_size * 0.75`` rounded down to a multiple
      of 32 and stepping up by 32;
    * ``base_size`` repeated ``base_size_repeat`` times;
    * descending, starting at ``base_size * 1.25`` rounded down to a multiple
      of 32 and stepping down by 32.

    Both outer runs have ``(base_size - low) // 32`` entries, where ``low`` is
    the first value of the ascending run.

    Example: ``generate_scales(640, 1)`` returns
    ``[480, 512, 544, 576, 608, 640, 800, 768, 736, 704, 672]``.
    """
    stride = 32
    low = int(base_size * 0.75 / stride) * stride
    high = int(base_size * 1.25 / stride) * stride
    steps = (base_size - low) // stride

    ascending = [low + k * stride for k in range(steps)]
    descending = [high - k * stride for k in range(steps)]
    return ascending + [base_size] * base_size_repeat + descending
class BatchImageCollateFunction(BaseCollateFunction):
    """Collate ``(image, target)`` samples, optionally resizing each batch to a random scale.

    Images are batched by giving each a new leading dimension and concatenating
    along it; targets are returned as a list in sample order. When
    ``base_size_repeat`` is given, candidate sizes come from
    ``generate_scales(base_size, base_size_repeat)`` and, while the stored epoch
    is below ``stop_epoch``, each batch is resized with ``F.interpolate`` to one
    size chosen at random.

    NOTE(review): ``register`` is imported from ``..core`` but not applied to
    anything visible in this module; confirm whether this class should carry a
    ``@register()`` decorator.
    """

    def __init__(
        self,
        stop_epoch=None,
        ema_restart_decay=0.9999,
        base_size=640,
        base_size_repeat=None,
    ) -> None:
        """
        Args:
            stop_epoch: epoch from which multi-scale resizing is switched off;
                ``None`` is replaced by 100000000 (effectively never).
            ema_restart_decay: stored as ``self.ema_restart_decay``; not read by
                this class — presumably consumed elsewhere (TODO confirm).
            base_size: centre size handed to ``generate_scales``.
            base_size_repeat: how many times ``base_size`` appears in the scale
                list; ``None`` disables multi-scale resizing (``self.scales`` is ``None``).
        """
        super().__init__()
        self.base_size = base_size
        self.scales = (
            generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
        )
        self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
        self.ema_restart_decay = ema_restart_decay

    def __call__(self, items):
        """Batch ``items`` and, when multi-scale is active, resize the image batch.

        Args:
            items: sequence of samples where ``sample[0]`` is an image tensor and
                ``sample[1]`` its target (dict-like, judging by the ``"masks"``
                membership test below).

        Returns:
            ``(images, targets)``: the concatenated image batch and the list of
            targets.

        Raises:
            NotImplementedError: if multi-scale resizing is active and the
                targets contain ``"masks"``.
        """
        images = torch.cat([x[0][None] for x in items], dim=0)
        targets = [x[1] for x in items]

        # Same default as the base class (-1 before any set_epoch call). Reading
        # the stored value keeps this an int comparison even if ``epoch`` is
        # exposed as a plain method instead of a property.
        current_epoch = getattr(self, "_epoch", -1)
        if self.scales is not None and current_epoch < self.stop_epoch:
            # Mask resizing is unsupported: fail before resizing anything or
            # mutating the caller's target dicts in place.
            if "masks" in targets[0]:
                raise NotImplementedError(
                    "multi-scale resizing is not supported for targets with 'masks'"
                )
            sz = random.choice(self.scales)
            # Default F.interpolate mode ("nearest"); an int size sets every
            # spatial dimension to ``sz``.
            images = F.interpolate(images, size=sz)

        return images, targets