# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy
import warnings
from typing import Any

import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from src.efficientvit.apps.data_provider.random_resolution import RRSController
from src.efficientvit.models.utils import val2tuple

__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]

def parse_image_size(size: int | str) -> tuple[int, int]:
    if isinstance(size, str):  # e.g., "160-224" -> (160, 224)
        size = [int(val) for val in size.split("-")]
        return size[0], size[1]
    else:
        return val2tuple(size, 2)
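
# Usage sketch (illustrative only, not executed at import time):
# `parse_image_size` accepts either a single int for a square resolution or an
# "H-W" string:
#   parse_image_size(224)       # -> (224, 224), via val2tuple
#   parse_image_size("160-256") # -> (160, 256)
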
def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
    g = torch.Generator()
    g.manual_seed(seed)  # set random seed before sampling validation set
    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()

    dropped_indexes = rand_indexes[:drop_size]
    remaining_indexes = rand_indexes[drop_size:]

    # move the dropped samples into a deep copy; keep the rest in place
    dropped_dataset = copy.deepcopy(dataset)
    for key in keys:
        setattr(
            dropped_dataset,
            key,
            [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
        )
        setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
    return dataset, dropped_dataset
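
# Example (sketch; assumes a dataset object whose `samples` attribute is a
# list, as the `keys` default expects):
#   remaining, dropped = random_drop_data(dataset, drop_size=5000, seed=0)
#   # len(dropped.samples) == 5000; a fixed seed always yields the same split.
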
class DataProvider:
    data_keys = ("samples",)
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: int | None,
        valid_size: int | float | None,
        n_worker: int,
        image_size: int | list[int] | str | list[str],
        num_replicas: int | None = None,
        rank: int | None = None,
        train_ratio: float | None = None,
        drop_last: bool = False,
    ):
        warnings.filterwarnings("ignore")
        super().__init__()

        # batch_size & valid_size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size, sorted ascending so the largest resolution comes last,
        # e.g., [(160, 160), ..., (224, 224)]
        if isinstance(image_size, list):
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configs
        self.num_replicas = num_replicas
        self.rank = rank

        # build datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()
        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # build data loaders
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        self.valid = self.build_dataloader(
            val_dataset, test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            self.valid = self.test
        self.sub_train = None

    def data_shape(self) -> tuple[int, ...]:
        return 3, self.active_image_size[0], self.active_image_size[1]

    def build_valid_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        raise NotImplementedError

    def build_train_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        raise NotImplementedError

    def build_datasets(self) -> tuple[Any, Any, Any]:
        raise NotImplementedError
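
    # Subclass sketch (illustrative; `MyDataset` and `data_root` are
    # placeholders, not part of this codebase). A concrete provider implements
    # the three builders above and can reuse `sample_val_dataset`:
    #   class MyDataProvider(DataProvider):
    #       name = "my_dataset"
    #
    #       def build_datasets(self):
    #           train = MyDataset(data_root, transform=self.build_train_transform())
    #           test = MyDataset(data_root, split="test", transform=self.build_valid_transform())
    #           train, val = self.sample_val_dataset(train, self.build_valid_transform())
    #           return train, val, test
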
    def build_dataloader(
        self,
        dataset: Any | None,
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            # multi-resolution training uses the random-resolution data loader
            from src.efficientvit.apps.data_provider.random_resolution._data_loader import (
                RRSDataLoader,
            )

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader
        if self.num_replicas is None:
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )
        else:
            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )

    def set_epoch(self, epoch: int) -> None:
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)
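
    # Training-loop sketch (`provider` is a concrete subclass instance): call
    # `set_epoch` once per epoch so the random-resolution schedule and the
    # DistributedSampler are re-seeded consistently across workers:
    #   for epoch in range(n_epochs):
    #       provider.set_epoch(epoch)
    #       for batch in provider.train:
    #           ...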

    def assign_active_image_size(self, new_size: int | tuple[int, int]) -> None:
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self.test.dataset.transform = new_transform
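
    # Usage sketch: switching the evaluation resolution on the fly rebuilds the
    # valid/test transforms for the new size:
    #   provider.assign_active_image_size(288)  # active_image_size == (288, 288)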

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[Any, Any]:
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                # fractional valid_size: take a share of the training set
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                # absolute valid_size: take a fixed number of samples
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset
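
    # Worked example of `valid_size` semantics: with a 50000-sample training
    # set, valid_size=0.1 carves out 5000 samples; valid_size=5000 does the
    # same via an absolute count; valid_size=None skips the split, and
    # `__init__` then falls back to using the test loader as `self.valid`.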

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> Any:
        # used for resetting BN running statistics
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct dataset and dataloader
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(
            image_size=self.active_image_size
        )
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, drop_last=True, train=False
        )

        # pre-fetch all batches, repeating each one when n_samples exceeds the
        # dataset size so that roughly n_samples samples are cached
        self.sub_train[self.active_image_size] = [
            data
            for data in data_loader
            for _ in range(max(1, n_samples // len(train_dataset)))
        ]
        return self.sub_train[self.active_image_size]
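
    # BN-recalibration sketch (illustrative; assumes (image, label) batches):
    # the cached batches at the active resolution can re-estimate BatchNorm
    # running statistics via a few forward passes in training mode:
    #   calib_batches = provider.build_sub_train_loader(n_samples=2000, batch_size=100)
    #   model.train()
    #   with torch.no_grad():
    #       for images, _ in calib_batches:
    #           model(images)  # updates BN running mean/var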