Fucius's picture
Upload 52 files
ad5354d verified
raw
history blame
No virus
8.05 kB
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import copy
import warnings
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler
from src.efficientvit.apps.data_provider.random_resolution import RRSController
from src.efficientvit.models.utils import val2tuple
# Public names exported by ``from <this module> import *``.
__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
def parse_image_size(size: int or str) -> tuple[int, int]:
    """Normalize an image-size spec into a 2-tuple of ints.

    Args:
        size: Either a string or anything accepted by ``val2tuple``.
            - A string ``"A-B"`` (e.g. ``"160-224"``) becomes ``(A, B)``.
              Any parts after the second are ignored, as before.
            - A single-number string ``"A"`` becomes ``(A, A)``. It used to
              raise ``IndexError``.
            - Any other value (e.g. an int) is delegated to
              ``val2tuple(size, 2)``.

    Returns:
        A 2-tuple of ints.
    """
    if isinstance(size, str):
        values = [int(val) for val in size.split("-")]
        if len(values) == 1:
            # Square size given as a bare number string, e.g. "224".
            return values[0], values[0]
        return values[0], values[1]
    return val2tuple(size, 2)
def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
    """Randomly split ``dataset`` into a remaining part and a dropped part.

    A permutation of ``range(len(dataset))`` is drawn from a ``torch.Generator``
    seeded with ``seed``, so the same seed always yields the same split. The
    first ``drop_size`` permuted indexes form the dropped part; the rest remain.

    Only the indexable attributes named in ``keys`` are subset. Every other
    attribute of the dropped part comes from a deep copy of ``dataset``.

    Args:
        dataset: Object supporting ``len()`` whose ``keys`` attributes are
            indexable sequences.
        drop_size: Number of items moved to the dropped part. Values larger
            than ``len(dataset)`` drop everything.
        seed: Seed for the permutation.
        keys: Names of the attributes to subset.

    Returns:
        ``(dataset, dropped_dataset)``. NOTE: ``dataset`` is modified in place
        (its ``keys`` attributes are replaced by the remaining items) and
        returned as the first element.

    Raises:
        ValueError: If ``drop_size`` is negative. Previously a negative value
            silently sliced from the end of the permutation.
    """
    if drop_size < 0:
        raise ValueError(f"drop_size must be non-negative, got {drop_size}")

    g = torch.Generator()
    g.manual_seed(seed)  # set random seed before sampling validation set
    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
    dropped_indexes = rand_indexes[:drop_size]
    remaining_indexes = rand_indexes[drop_size:]

    dropped_dataset = copy.deepcopy(dataset)
    for key in keys:
        # Look each attribute up once, not once per element.
        # Dropped items are taken from the deep copy, remaining items from
        # the original, so the two parts never share element objects.
        copied_values = getattr(dropped_dataset, key)
        original_values = getattr(dataset, key)
        setattr(dropped_dataset, key, [copied_values[idx] for idx in dropped_indexes])
        setattr(dataset, key, [original_values[idx] for idx in remaining_indexes])
    return dataset, dropped_dataset
class DataProvider:
    """Base class that builds the train / valid / test data loaders.

    Subclasses implement ``build_datasets``, ``build_train_transform`` and
    ``build_valid_transform``. This class handles:
      - image-size parsing;
      - optional train-set subsampling;
      - validation-split sampling;
      - loader construction, either plain or ``DistributedSampler``-based.
    """

    data_keys = ("samples",)  # dataset attributes subset by random_drop_data
    # presumably the usual ImageNet RGB normalization stats -- confirm with users
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: int or None,
        valid_size: int or float or None,
        n_worker: int,
        image_size: int or list[int] or str or list[str],
        num_replicas: int or None = None,
        rank: int or None = None,
        train_ratio: float or None = None,
        drop_last: bool = False,
    ):
        """Parse the configuration, build the datasets and the data loaders.

        Args:
            train_batch_size: Batch size of the train loader.
            test_batch_size: Batch size of the valid/test loaders. Falls back
                to ``train_batch_size`` when falsy.
            valid_size: Stored for ``sample_val_dataset``. It is a fraction in
                (0, 1), an absolute count >= 1, or ``None`` for no split.
            n_worker: ``num_workers`` for every loader.
            image_size: A size spec, or a list of specs (see
                ``parse_image_size``). A list is sorted ascending, its largest
                entry becomes active, and the train loader becomes an
                ``RRSDataLoader``.
            num_replicas: When not ``None``, loaders use a
                ``DistributedSampler`` with ``num_replicas`` and ``rank``.
            rank: Rank passed to ``DistributedSampler``.
            train_ratio: If in (0, 1), only a seeded random fraction of the
                training samples is kept.
            drop_last: ``drop_last`` for the train loader. Valid and test
                loaders never drop.
        """
        # NOTE: this silences *all* warnings for the whole process, not only
        # those raised by data loading.
        warnings.filterwarnings("ignore")
        super().__init__()

        # batch_size & valid_size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size
        if isinstance(image_size, list):
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()  # ascending, e.g., 160 -> 224; last is largest
            # Publish the candidate sizes on the (class-level) RRSController.
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configs
        self.num_replicas = num_replicas
        self.rank = rank

        # build datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()

        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            # Keep the seeded random "dropped" part, which holds
            # train_ratio * len samples, as the new training set.
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # build data loader
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        # Use the resolved self.test_batch_size, not the raw argument: the
        # argument may be None, and DataLoader(batch_size=None) disables
        # automatic batching altogether.
        self.valid = self.build_dataloader(
            val_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            self.valid = self.test
        self.sub_train = None  # lazily filled cache, see build_sub_train_loader

    @property
    def data_shape(self) -> tuple[int, ...]:
        """Return ``(3, active_image_size[0], active_image_size[1])``."""
        return 3, self.active_image_size[0], self.active_image_size[1]

    def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
        """Return the evaluation transform for ``image_size`` (subclass hook)."""
        raise NotImplementedError

    def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
        """Return the training transform for ``image_size`` (subclass hook)."""
        raise NotImplementedError

    def build_datasets(self) -> tuple[any, any, any]:
        """Return ``(train, valid, test)`` datasets. Any of them may be ``None``
        (subclass hook)."""
        raise NotImplementedError

    def build_dataloader(
        self,
        dataset: any or None,
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
        """Wrap ``dataset`` in a data loader, or return ``None`` if it is ``None``.

        An ``RRSDataLoader`` is used only for training when several image
        sizes were configured; otherwise a ``torch.utils.data.DataLoader`` is
        used. Without ``num_replicas`` the loader shuffles. With it, a
        ``DistributedSampler(dataset, num_replicas, rank)`` is used instead.
        """
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            # Same package root as the module-level imports; the bare
            # ``efficientvit`` path fails to import in this layout.
            from src.efficientvit.apps.data_provider.random_resolution._data_loader import \
                RRSDataLoader

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader

        loader_kwargs = dict(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=n_worker,
            pin_memory=True,
            drop_last=drop_last,
        )
        if self.num_replicas is None:
            loader_kwargs["shuffle"] = True
        else:
            # shuffle and sampler are mutually exclusive; the sampler shuffles.
            loader_kwargs["sampler"] = DistributedSampler(dataset, self.num_replicas, self.rank)
        return dataloader_class(**loader_kwargs)

    def set_epoch(self, epoch: int) -> None:
        """Propagate ``epoch`` to RRSController and to the train sampler.

        RRSController receives ``epoch`` and the number of train batches. The
        train loader's ``DistributedSampler``, if any, gets
        ``set_epoch(epoch)`` so its shuffling changes from epoch to epoch.
        """
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)

    def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
        """Set the active size and rebuild the valid/test transforms for it.

        Loaders that are ``None`` (no such split) are skipped instead of
        raising ``AttributeError``.
        """
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # change the transform of the valid and test set
        for loader in (self.valid, self.test):
            if loader is not None:
                loader.dataset.transform = new_transform

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
        """Carve a seeded validation split out of ``train_dataset``.

        ``self.valid_size`` is interpreted as follows:
          - strictly between 0 and 1: a fraction of ``len(train_dataset)``;
          - otherwise: an absolute count, which must be >= 1;
          - ``None``: no split, and the validation set is ``None``.

        The returned train dataset is ``train_dataset`` itself, shrunk in
        place by ``random_drop_data``. The validation part gets
        ``valid_transform`` as its ``transform``.
        """
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
        """Return a cached list of prefetched train batches at the active size.

        Used for resetting BN running statistics.

        The cache is keyed only by ``active_image_size``. A later call at the
        same size returns the cached batches, even if ``n_samples`` or
        ``batch_size`` differ.
        """
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct dataset and dataloader
        # Deep-copy first: random_drop_data mutates its input, and the live
        # training dataset must not be touched.
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(
            image_size=self.active_image_size
        )
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, True, False
        )

        # pre-fetch data; each batch is repeated consecutively so that roughly
        # n_samples samples are produced even when the dataset is smaller.
        repeat = max(1, n_samples // len(train_dataset))
        self.sub_train[self.active_image_size] = [
            data for data in data_loader for _ in range(repeat)
        ]
        return self.sub_train[self.active_image_size]