import monai
import torch
import itk
import numpy as np
import matplotlib.pyplot as plt
import random
import glob
import os.path
import argparse
import sys


def create_batch_generator(dataloader_subdivided, weights=None):
    """
    Create a batch generator that samples data pairs with various segmentation availabilities.

    Arguments:
        dataloader_subdivided : a mapping from the labels in seg_availabilities to dataloaders
        weights : a list of non-negative weights, one for each label in seg_availabilities
            (in the order '00', '01', '10', '11'); they are normalized to sum to 1.
            If not provided then we weight by the length of each dataloader, effectively
            sampling uniformly over the union of the datasets.

    Returns: batch_generator
        A function that accepts a number of batches to sample and that returns a generator.
        The generator will weighted-randomly pick one of the seg_availabilities and
        yield the next batch from the corresponding dataloader. A dataloader that runs
        out is restarted, and its position is shared by all generators obtained from the
        same batch_generator.

    Raises:
        ValueError : if weights does not have exactly one entry per seg_availability,
            contains negative or non-finite entries, or sums to zero.
        KeyError : if weights is None and a seg_availability label is missing from
            dataloader_subdivided.
    """
    seg_availabilities = ['00', '01', '10', '11']

    if weights is None:
        weights = [len(dataloader_subdivided[s]) for s in seg_availabilities]
    weights = np.asarray(weights, dtype=float)

    # Validate up front so that bad weights fail here with a clear message rather
    # than later, inside np.random.choice, in the middle of a training loop.
    if weights.shape != (len(seg_availabilities),):
        raise ValueError(
            "weights must contain exactly one entry per segmentation availability "
            "{}; got shape {}".format(seg_availabilities, weights.shape)
        )
    if not np.all(np.isfinite(weights)) or np.any(weights < 0):
        raise ValueError("weights must be finite and non-negative; got {}".format(weights))
    total_weight = weights.sum()
    if total_weight <= 0:
        # Dividing by zero would silently produce NaN probabilities.
        raise ValueError(
            "weights sum to zero, so there is nothing to sample from "
            "(are all dataloaders empty?)"
        )
    weights = weights / total_weight

    # The iterators live in the enclosing scope so that successive generators
    # continue where the previous one stopped instead of restarting every dataloader.
    dataloader_subdivided_as_iterators = {
        s: iter(d) for s, d in dataloader_subdivided.items()
    }

    def batch_generator(num_batches_to_sample):
        for _ in range(num_batches_to_sample):
            seg_availability = np.random.choice(seg_availabilities, p=weights)
            try:
                batch = next(dataloader_subdivided_as_iterators[seg_availability])
            except StopIteration:
                # If dataloader runs out, restart it
                restarted = iter(dataloader_subdivided[seg_availability])
                dataloader_subdivided_as_iterators[seg_availability] = restarted
                try:
                    batch = next(restarted)
                except StopIteration:
                    # A freshly restarted dataloader that yields nothing is empty.
                    # Letting StopIteration escape a generator body would surface as
                    # an opaque "generator raised StopIteration" RuntimeError, so
                    # raise the same exception type with an explanatory message.
                    raise RuntimeError(
                        "dataloader for segmentation availability {!r} is empty but "
                        "was assigned a nonzero sampling weight".format(
                            str(seg_availability)
                        )
                    ) from None
            yield batch

    return batch_generator