trafficflow-api / ZIP /utils /data_utils.py
Ha Trong Nguyen
Initial commit for HuggingFace Space backend
eb8e8ab
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose
import os, sys
from argparse import ArgumentParser
from typing import Union, Tuple, List, Dict
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import datasets
def calc_bin_center(
    bins: List[Tuple[float, float]],
    count_stats: Dict[int, int],
) -> Tuple[List[float], List[int]]:
    """
    Compute a representative value ("center") and a sample total for every bin.

    `bins` is a list of closed intervals, e.g.
        [(0, 0), (1, 1), (2, 3), (4, 6), (7, float('inf'))]
    `count_stats` maps a count value to how many samples have that count, e.g.
        {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}

    Each count value is assigned to the first bin whose interval contains it;
    values that match no bin are ignored. The center of a bin is the
    frequency-weighted mean of the count values assigned to it. In the example
    above, bin (2, 3) receives 30 samples of 2 and 40 samples of 3, so its
    center is (30 * 2 + 40 * 3) / (30 + 40) ≈ 2.57.

    A bin that receives no samples falls back to its lower bound when it is
    open-ended (upper bound is infinity), otherwise to the interval midpoint.

    Returns a pair `(bin_centers, bin_counts)`, both the same length as `bins`.
    """
    num_bins = len(bins)
    totals = [0] * num_bins    # number of samples assigned to each bin
    weighted = [0] * num_bins  # sum of (count value * frequency) per bin

    for count_value, frequency in count_stats.items():
        # Index of the first bin containing this count value, or None.
        hit = next(
            (idx for idx, (lo, hi) in enumerate(bins) if lo <= int(count_value) <= hi),
            None,
        )
        if hit is None:
            continue
        totals[hit] += int(frequency)
        weighted[hit] += int(frequency) * int(count_value)

    def _center(idx):
        # Weighted mean when the bin has samples, geometric fallback otherwise.
        if totals[idx] > 0:
            return weighted[idx] / totals[idx]
        lo, hi = bins[idx]
        if hi == float('inf'):
            return float(lo)
        return float(lo + hi) / 2

    return [_center(idx) for idx in range(num_bins)], totals
def get_dataloader(args: ArgumentParser, split: str = "train") -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """
    Build the DataLoader for the given dataset split.

    Args:
        args: Configuration object read via attribute access (annotated as
            ArgumentParser in this file; presumably the parsed namespace --
            confirm against the caller). Attributes read here: nprocs,
            local_rank, seed, batch_size, num_workers, dataset, num_crops,
            num_classes, in_memory_dataset, input_size, sliding_window,
            resize_to_multiple, window_size, stride and the aug_* options.
        split: "train" selects the augmentation pipeline, `args.batch_size`
            and shuffling. Any other value builds an evaluation loader with
            batch size 1 and no shuffling.

    Returns:
        For split == "train": a tuple `(data_loader, sampler)` where `sampler`
        is the DistributedSampler when `args.nprocs > 1`, otherwise None.
        For any other split: the DataLoader alone. A DistributedSampler is
        attached only when `args.nprocs > 1` and split == "val".
    """
    ddp = args.nprocs > 1
    is_train = split == "train"

    if is_train:  # train, strong augmentation
        transforms = [
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.aug_min_scale, args.aug_max_scale)),
            datasets.RandomHorizontalFlip(),
        ]
        # Each optional augmentation is added only when its strength is non-zero.
        if args.aug_brightness > 0 or args.aug_contrast > 0 or args.aug_saturation > 0 or args.aug_hue > 0:
            transforms.append(datasets.ColorJitter(
                brightness=args.aug_brightness, contrast=args.aug_contrast, saturation=args.aug_saturation, hue=args.aug_hue
            ))
        if args.aug_blur_prob > 0 and args.aug_kernel_size > 0:
            transforms.append(datasets.RandomApply([
                datasets.GaussianBlur(kernel_size=args.aug_kernel_size),
            ], p=args.aug_blur_prob))
        if args.aug_saltiness > 0 or args.aug_spiciness > 0:
            transforms.append(datasets.PepperSaltNoise(
                saltiness=args.aug_saltiness, spiciness=args.aug_spiciness,
            ))
        transforms = Compose(transforms)
    elif args.sliding_window and args.resize_to_multiple:
        transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
    else:
        transforms = None

    dataset_class = datasets.InMemoryCrowd if args.in_memory_dataset else datasets.Crowd
    dataset = dataset_class(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if is_train else 1,
        num_classes=args.num_classes,
    )

    # Options shared by every loader. prefetch_factor and persistent_workers
    # are only meaningful with worker processes, so they are disabled when
    # loading happens in the main process.
    use_workers = args.num_workers != 0
    loader_kwargs = dict(
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=datasets.collate_fn,
        prefetch_factor=3 if use_workers else None,
        persistent_workers=use_workers,
    )

    if is_train:
        if ddp:  # data_loader for training in DDP
            # The seed must be identical on every process: each rank builds
            # the same permutation and then takes its own strided slice of it.
            # A per-rank seed would make the shards overlap and leave some
            # samples unseen within an epoch.
            # NOTE(review): rank is taken from args.local_rank, which equals
            # the global rank only for single-node training -- confirm.
            sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=True, seed=args.seed)
            data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, **loader_kwargs)
        else:  # data_loader for single-process training
            sampler = None
            data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
        return data_loader, sampler

    # Evaluation loader. Only the DDP "val" split is sharded across processes;
    # every other non-train case iterates the full dataset in order.
    sampler = None
    if ddp and split == "val":
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=False)
    return DataLoader(
        dataset,
        batch_size=1,  # Use batch size 1 for evaluation
        sampler=sampler,
        shuffle=False,
        **loader_kwargs,
    )