# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# * Loader Factory, Fast Collate, CUDA Prefetcher
# * Prefetcher and Fast Collate inspired by NVIDIA APEX example at
# * https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
# * Copyright 2019, Ross Wightman
# *--------------------------------------------------------------------------------------------*/
import os
from functools import partial

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image

from timm.data import FastCollateMixup
from timm.data.dataset import IterableImageDataset
from timm.data.distributed_sampler import (OrderedDistributedSampler,
                                           RepeatAugSampler)
from timm.data.loader import (MultiEpochsDataLoader, PrefetchLoader,
                              _worker_init, fast_collate)

from image_classification.pt.src.datasets.augmentations.augs import (
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
class PredictionDataset(Dataset):
    """Dataset over a flat folder of images for inference.

    Yields (image_tensor, image_path) pairs; no labels are required.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        # Sort for a deterministic ordering across runs and platforms.
        self.image_paths = sorted(
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        # Default to the standard ImageNet preprocessing pipeline.
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
def __getitem__(self, idx):
img_path = self.image_paths[idx]
img = Image.open(img_path).convert("RGB")
img = self.transform(img)
return img, img_path
def __len__(self):
return len(self.image_paths)
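
# Illustrative usage sketch (not part of the pipeline; the "images/" folder
# name is an assumption for the example):
#
#   ds = PredictionDataset("images/")
#   loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=False)
#   for imgs, paths in loader:
#       # imgs: (B, 3, 224, 224) float tensor; paths: tuple of file paths
#       ...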
def create_loader(
dataset,
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
no_aug=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
num_aug_repeats=0,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
collate_fn=None,
pin_memory=False,
fp16=False, # deprecated, use img_dtype
img_dtype=torch.float32,
device=torch.device('cuda'),
use_multi_epochs_loader=False,
persistent_workers=True,
worker_seeding='all',
):
if isinstance(input_size, int):
input_size = (3, input_size, input_size)
if isinstance(dataset, IterableImageDataset):
# give Iterable datasets early knowledge of num_workers so that sample estimates
# are correct before worker processes are launched
dataset.set_loader_cfg(num_workers=num_workers)
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
if num_aug_repeats:
sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
else:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
            # This adds duplicate entries so every process sees the same
            # number of samples; it can slightly alter validation results.
sampler = OrderedDistributedSampler(dataset)
else:
assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
if collate_fn is None:
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
loader_class = MultiEpochsDataLoader
loader_args = {
'batch_size': batch_size,
'shuffle': not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
'num_workers': num_workers,
'sampler': sampler,
'collate_fn': collate_fn,
'pin_memory': pin_memory,
'drop_last': is_training,
'worker_init_fn': partial(_worker_init, worker_seeding=worker_seeding),
'persistent_workers': persistent_workers
}
try:
loader = loader_class(dataset, **loader_args)
except TypeError:
        loader_args.pop('persistent_workers')  # 'persistent_workers' only exists in PyTorch 1.7+
loader = loader_class(dataset, **loader_args)
if use_prefetcher:
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
loader = PrefetchLoader( # pylint: disable=unexpected-keyword-arg
loader,
mean=mean,
std=std,
channels=input_size[0],
device=device,
fp16=fp16, # deprecated, use img_dtype
img_dtype=img_dtype,
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
)
return loader
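
# Illustrative call of create_loader (dataset and argument values are
# examples only, not the pipeline's actual configuration):
#
#   loader = create_loader(
#       my_dataset,                  # any map-style or timm Iterable dataset
#       input_size=(3, 224, 224),
#       batch_size=64,
#       is_training=True,
#       use_prefetcher=torch.cuda.is_available(),
#       num_workers=4,
#   )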
def get_dataloader(
dataset, batch_size=32, num_workers=4, fp16=False, distributed=False, shuffle=False,
collate_fn=None, device="cuda"
):
    if collate_fn is None:
        collate_fn = default_collate

    def half_precision(x):
        # Collate first, then cast floating-point tensors to fp16.
        x = collate_fn(x)
        x = [_x.half() if isinstance(_x, torch.Tensor) and _x.is_floating_point() else _x
             for _x in x]
        return x

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # A DataLoader cannot combine shuffle=True with an explicit sampler,
        # so shuffling is delegated to the distributed sampler when present.
        shuffle=shuffle and not distributed,
        num_workers=num_workers,
        collate_fn=half_precision if fp16 else collate_fn,
        sampler=DistributedSampler(dataset, shuffle=shuffle) if distributed else None,
    )
return dataloader
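
# Illustrative call (argument values are examples only):
#
#   val_loader = get_dataloader(my_dataset, batch_size=64, num_workers=2,
#                               fp16=True, shuffle=False)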
def prepare_kwargs_for_dataloader(cfg):
    #####################################
    # Hydra config to flat config with leaf keys
    #####################################
    from types import SimpleNamespace
    from common.utils import flatten_config
    # Special handling of the 'scale' variable: the config contains two
    # variables with the same name, and flattening the config would merge them.
    scale = None
if "data_augmentation" in cfg and "scale" in cfg.data_augmentation:
scale = cfg.data_augmentation.scale
if not scale:
scale = [0.08, 1.0]
    # Convert the hierarchical Hydra config to a flat config
    args_raw = flatten_config(cfg)
    # Convert the dict-based config to an argparse-style namespace
if isinstance(args_raw, dict):
args = SimpleNamespace(**args_raw)
else:
args = args_raw
#####################################
args.prefetcher = not getattr(args, 'no_prefetcher', False)
if not hasattr(args, "batch_size") or args.batch_size is None:
args.batch_size = 128
if not hasattr(args, "validation_batch_size") or args.validation_batch_size is None:
args.validation_batch_size = args.batch_size
num_aug_splits = 0
collate_fn = None
args.mixup = getattr(args, "mixup", 0)
args.cutmix = getattr(args, "cutmix", 0)
args.cutmix_minmax = getattr(args, "cutmix_minmax", None)
mixup_active = (args.mixup > 0) or (args.cutmix > 0) or (args.cutmix_minmax is not None)
    # With the current config, the following "if" is False
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=getattr(args, "mixup_prob", 1.0),
switch_prob=getattr(args, "mixup_switch_prob", 0.5),
mode=getattr(args, "mixup_mode", 'batch'),
label_smoothing=getattr(args, "smoothing", 0.1),
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
    device = args.device  # already set in the main script
if getattr(args, 'aug_splits', 0) > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
extra_loader_kwargs = {}
    if getattr(args, 'dataset_name', '') in ['imagenet', 'custom']:
extra_loader_kwargs = {
'train_split': getattr(args, 'train_split', 'train'),
'val_split': getattr(args, 'val_split', 'val'),
'class_map': getattr(args, 'class_map', None),
'seed': getattr(args, 'seed', None),
'repeats': getattr(args, 'epoch_repeats', None),
}
# create data loaders w/ augmentation pipeline
test_interpolation = 'bicubic'
train_interpolation = getattr(args, 'train_interpolation', "random")
args.no_aug = getattr(args, 'no_aug', False)
if args.no_aug or not train_interpolation:
train_interpolation = test_interpolation
    args.mean = getattr(args, 'mean', [0.485, 0.456, 0.406])
    args.std = getattr(args, 'std', [0.229, 0.224, 0.225])
batch_size = getattr(args, 'batch_size', 128)
validation_batch_size = getattr(args, 'validation_batch_size', batch_size)
loader_kwargs = dict(
img_size=tuple(args.input_shape),
batch_size=batch_size,
test_batch_size=validation_batch_size,
download=getattr(args, 'data_download', False),
        distributed=getattr(args, 'distributed', False),
        use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=getattr(args, 'reprob', 0),
re_mode=getattr(args, 'remode', 'pixel'),
re_count=getattr(args, 'recount', 1),
re_split=getattr(args, 'resplit', False),
scale=scale,
ratio=getattr(args, 'ratio', [0.75, 1.33]),
hflip=getattr(args, 'hflip', 0.5),
vflip=getattr(args, 'vflip', 0.0),
color_jitter=getattr(args, 'color_jitter', 0.4),
auto_augment=getattr(args, 'aa', None),
num_aug_repeats=getattr(args, 'aug_repeats', 0),
num_aug_splits=num_aug_splits,
train_interpolation=train_interpolation,
        test_interpolation=test_interpolation,
        mean=args.mean,  # taken directly from the config
        std=args.std,  # taken directly from the config
num_workers=getattr(args, 'workers', 4),
        collate_fn=collate_fn,
pin_memory=getattr(args, 'pin_mem', False),
device=device, # TODO multi gpu
use_multi_epochs_loader=getattr(args, 'use_multi_epochs_loader', False),
        worker_seeding=getattr(args, 'worker_seeding', 'all'),  # timm expects 'all', 'part', or a callable
        test_path=getattr(args, 'test_path', None),
        prediction_path=getattr(args, 'prediction_path', None),
        quantization_path=getattr(args, 'quantization_path', None),
        quantization_split=getattr(args, 'quantization_split', 0.1),
        num_classes=args.num_classes,  # required downstream
**extra_loader_kwargs)
return loader_kwargs
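
# Illustrative use of the returned kwargs (a minimal sketch; the actual
# downstream consumer and key mapping depend on the pipeline, and
# `train_dataset` here is a hypothetical placeholder):
#
#   loader_kwargs = prepare_kwargs_for_dataloader(cfg)
#   train_loader = create_loader(
#       train_dataset,
#       input_size=loader_kwargs['img_size'],
#       batch_size=loader_kwargs['batch_size'],
#       is_training=True,
#       use_prefetcher=loader_kwargs['use_prefetcher'],
#       mean=loader_kwargs['mean'],
#       std=loader_kwargs['std'],
#   )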
def get_from_config(cfg, key_path, default=None):
"""
Safely get a nested value from cfg using dot-separated keys.
Works for OmegaConf DictConfig, dict, or objects with attributes.
Example:
get_from_config(cfg, "a.b.c", default=42)
"""
keys = key_path.split(".")
current = cfg
    for key in keys:
        if isinstance(current, dict):
            if key not in current:
                return default
            current = current[key]
        elif hasattr(current, key):
            # Attribute access covers OmegaConf DictConfig and plain objects.
            current = getattr(current, key)
        else:
            return default
return current
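
# Example (illustrative values):
#
#   cfg = {"training": {"optimizer": {"lr": 1e-3}}}
#   get_from_config(cfg, "training.optimizer.lr", default=1e-2)    # -> 0.001
#   get_from_config(cfg, "training.scheduler.name", default=None)  # -> None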