|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
import torchvision.transforms.v2 as T
|
|
|
|
from ...core import GLOBAL_CONFIG, register
|
|
from ._transforms import EmptyTransform
|
|
|
|
# Silence torchvision's warning about the beta ``torchvision.transforms.v2`` API,
# which this module imports as ``T``.
# NOTE(review): in newer torchvision releases this call is reportedly a no-op kept
# only for backward compatibility — confirm against the pinned torchvision version.
torchvision.disable_beta_transforms_warning()
|
|
|
|
|
|
@register()
class Compose(T.Compose):
    """Chain of transforms built from config dicts or ready-made ``nn.Module`` objects.

    Besides plain sequential application, an optional *policy* can disable
    selected transforms once a given epoch is reached (``"stop_epoch"``) or
    once a given number of samples has been processed (``"stop_sample"``).

    Args:
        ops: Items are either dicts with a ``"type"`` key naming a transform
            registered in ``GLOBAL_CONFIG`` (all other keys are passed to its
            constructor as keyword arguments) or already-built ``nn.Module``
            transforms. When ``None``, a single ``EmptyTransform`` is used.
            Dict items are not modified.
        policy: Dict whose ``"name"`` selects the forward variant
            (``"default"``, ``"stop_epoch"`` or ``"stop_sample"``). The
            non-default variants also read ``"ops"`` (class names of the
            transforms to skip) and ``"epoch"`` / ``"sample"`` respectively.
            Defaults to ``{"name": "default"}``.

    Raises:
        ValueError: If an item of ``ops`` is neither a dict nor an ``nn.Module``.
        KeyError: If a dict item has no ``"type"`` key or names an unregistered
            transform.
    """

    def __init__(
        self, ops: Optional[List[Any]], policy: Optional[Dict[str, Any]] = None
    ) -> None:
        if ops is None:
            transforms = [EmptyTransform()]
        else:
            transforms = [self._build_transform(op) for op in ops]

        super().__init__(transforms=transforms)

        if policy is None:
            policy = {"name": "default"}

        self.policy = policy
        # Number of samples processed so far; only advanced by stop_sample_forward.
        # NOTE(review): this counter lives on this object, so with multi-process
        # data loading each worker presumably keeps its own count — confirm that
        # "sample" thresholds are meant per worker.
        self.global_samples = 0

    @staticmethod
    def _build_transform(op: Any) -> Any:
        """Turn one ``ops`` entry into a transform instance."""
        if isinstance(op, dict):
            # Work on a shallow copy so the caller's config dict keeps its "type"
            # key (and key order) even if the transform constructor raises.
            kwargs = dict(op)
            name = kwargs.pop("type")
            entry = GLOBAL_CONFIG[name]
            return getattr(entry["_pymodule"], entry["_name"])(**kwargs)

        if isinstance(op, nn.Module):
            return op

        raise ValueError(
            f"Unsupported transform op of type {type(op).__name__!r}; "
            "expected a dict with a 'type' key or an nn.Module"
        )

    def forward(self, *inputs: Any) -> Any:
        """Run the transforms using the forward variant named by ``policy["name"]``."""
        return self.get_forward(self.policy["name"])(*inputs)

    def get_forward(self, name):
        """Return the bound forward method registered for policy ``name``.

        Raises:
            KeyError: If ``name`` is not a known policy.
        """
        forwards = {
            "default": self.default_forward,
            "stop_epoch": self.stop_epoch_forward,
            "stop_sample": self.stop_sample_forward,
        }
        try:
            return forwards[name]
        except KeyError:
            raise KeyError(
                f"Unknown Compose policy name {name!r}; expected one of {sorted(forwards)}"
            ) from None

    def default_forward(self, *inputs: Any) -> Any:
        """Apply every transform in order.

        Several positional inputs are passed through as one tuple; a single
        input is passed through unwrapped.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def stop_epoch_forward(self, *inputs: Any):
        """Apply the transforms, skipping ``policy["ops"]`` from ``policy["epoch"]`` on.

        The last positional input is read as a dataset-like object exposing an
        ``epoch`` attribute.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        dataset = sample[-1]
        cur_epoch = dataset.epoch
        policy_ops = self.policy["ops"]
        policy_epoch = self.policy["epoch"]

        for transform in self.transforms:
            if type(transform).__name__ in policy_ops and cur_epoch >= policy_epoch:
                continue  # disabled by the policy for this and later epochs
            sample = transform(sample)

        return sample

    def stop_sample_forward(self, *inputs: Any):
        """Apply the transforms, skipping ``policy["ops"]`` once
        ``policy["sample"]`` samples have been processed by this object.

        Every call counts as one processed sample.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        policy_ops = self.policy["ops"]
        policy_sample = self.policy["sample"]

        for transform in self.transforms:
            if (
                type(transform).__name__ in policy_ops
                and self.global_samples >= policy_sample
            ):
                continue  # disabled by the policy after enough samples
            sample = transform(sample)

        self.global_samples += 1

        return sample
|
|
|