# -*- coding: utf-8 -*- import torch import numpy as np def worker_init_fn(_): worker_info = torch.utils.data.get_worker_info() worker_id = worker_info.id # dataset = worker_info.dataset # split_size = dataset.num_records // worker_info.num_workers # # reset num_records to the true number to retain reliable length information # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] # current_id = np.random.choice(len(np.random.get_state()[1]), 1) # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) return np.random.seed(np.random.get_state()[1][0] + worker_id) def collation_fn(samples, combine_tensors=True, combine_scalars=True): """ Args: samples (list[dict]): combine_tensors: combine_scalars: Returns: """ result = {} keys = samples[0].keys() for key in keys: result[key] = [] for sample in samples: for key in keys: val = sample[key] result[key].append(val) for key in keys: val_list = result[key] if isinstance(val_list[0], (int, float)): if combine_scalars: result[key] = np.array(result[key]) elif isinstance(val_list[0], torch.Tensor): if combine_tensors: result[key] = torch.stack(val_list) elif isinstance(val_list[0], np.ndarray): if combine_tensors: result[key] = np.stack(val_list) return result