import collections
import os

import torch
from torch.utils.data import get_worker_info
from torch.utils.data._utils.collate import (
    default_collate_err_msg_format,
    np_str_obj_array_pattern,
)

from lightning_fabric.utilities.seed import pl_worker_init_function


def collate(batch):
    """Difference with PyTorch default_collate: it can stack other tensor-like objects.

    Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
    https://github.com/cvg/pixloc
    Released under the Apache License 2.0
    """
    if not isinstance(batch, list):  # no batching
        return batch
    # Drop entries that are None so that optional fields can be collated.
    batch = [elem for elem in batch if elem is not None]
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if get_worker_info() is not None:
            # In a background process, concatenate directly into a shared
            # memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # Arrays of strings or objects cannot be converted to tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # Check that all elements in the batch have a consistent size.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    else:
        # Try to stack anyway, in case the object supports torch.stack.
        try:
            return torch.stack(batch, 0)
        except TypeError as e:
            if "expected Tensor as element" in str(e):
                return batch
            raise
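
# Usage sketch (hypothetical `dataset`): `collate` can be passed to a
# DataLoader as a drop-in replacement for the default collate function.
# Unlike default_collate, None entries are dropped and tensor-like objects
# that support torch.stack are stacked instead of raising.
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=4, collate_fn=collate
#   )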


def set_num_threads(nt):
    """Force numpy and other libraries to use a limited number of threads."""
    try:
        import mkl
    except ImportError:
        pass
    else:
        mkl.set_num_threads(nt)
    # PyTorch itself is kept single-threaded inside each worker process.
    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"
    for o in [
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    ]:
        os.environ[o] = str(nt)
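
# Note (assumption): most BLAS/OpenMP backends read these environment
# variables when they initialize, so this is most effective when called
# early in a fresh worker process (as done in worker_init_fn below), e.g.
#
#   set_num_threads(1)  # keep the current worker single-threaded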


def worker_init_fn(i):
    """Seed the DataLoader worker and optionally limit its thread count."""
    info = get_worker_info()
    pl_worker_init_function(info.id)
    num_threads = info.dataset.cfg.get("num_threads")
    if num_threads is not None:
        set_num_threads(num_threads)
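
# Wiring sketch (hypothetical `dataset` exposing a `cfg` mapping that may
# contain "num_threads", as worker_init_fn expects): pass both helpers to
# the DataLoader so each worker is seeded and thread-limited.
#
#   loader = torch.utils.data.DataLoader(
#       dataset,
#       batch_size=8,
#       num_workers=4,
#       collate_fn=collate,
#       worker_init_fn=worker_init_fn,
#   )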