# Chris Xiao -- upload files (commit 2ca2f68)
import monai
import torch
import itk
import numpy as np
import matplotlib.pyplot as plt
import random
import glob
import os.path
import argparse
import sys
def create_batch_generator(dataloader_subdivided, weights=None):
    """
    Create a batch generator that samples data pairs with various segmentation availabilities.

    Arguments:
        dataloader_subdivided : a mapping from the labels in seg_availabilities
            ('00', '01', '10', '11') to dataloaders
        weights : a list of weights, one for each label in seg_availabilities, in that
            order; they are normalized to sum to 1 before use.
            If not provided then each label is weighted by the `len()` of its dataloader.

    Returns: batch_generator
        A function that accepts a number of batches to sample and that returns a generator.
        The generator will weighted-randomly pick one of the seg_availabilities and
        yield the next batch from the corresponding dataloader. Iterator state is shared
        between successive calls of batch_generator, and an exhausted dataloader is
        restarted transparently.

    Raises:
        ValueError : if the weights do not have exactly one entry per seg_availability,
            do not have a finite non-zero sum, or are invalid as probabilities
            after normalization.
        RuntimeError : (raised while iterating the generator) if the sampled
            dataloader yields no batches at all, even after being restarted.
    """
    seg_availabilities = ['00', '01', '10', '11']

    if weights is None:
        weights = [len(dataloader_subdivided[s]) for s in seg_availabilities]
    weights = np.asarray(weights, dtype=float)

    # Validate up front: otherwise a bad weight vector only surfaces later, as an
    # error from np.random.choice the first time the generator is iterated.
    if weights.shape != (len(seg_availabilities),):
        raise ValueError(
            f"weights must have exactly {len(seg_availabilities)} entries "
            f"(one per seg_availability {seg_availabilities}), got shape {weights.shape}"
        )
    total = weights.sum()
    if not np.isfinite(total) or total == 0:
        raise ValueError(
            "weights must have a finite, non-zero sum; "
            "are all dataloaders empty or all weights zero?"
        )
    weights = weights / total
    if not np.all(np.isfinite(weights)) or np.any(weights < 0):
        raise ValueError("weights must be finite and must not have mixed signs")

    # One live iterator per dataloader, shared by every generator this factory hands out.
    dataloader_subdivided_as_iterators = {s: iter(d) for s, d in dataloader_subdivided.items()}

    def batch_generator(num_batches_to_sample):
        for _ in range(num_batches_to_sample):
            seg_availability = np.random.choice(seg_availabilities, p=weights)
            try:
                batch = next(dataloader_subdivided_as_iterators[seg_availability])
            except StopIteration:  # If dataloader runs out, restart it
                dataloader_subdivided_as_iterators[seg_availability] = \
                    iter(dataloader_subdivided[seg_availability])
                try:
                    batch = next(dataloader_subdivided_as_iterators[seg_availability])
                except StopIteration:
                    # A StopIteration escaping a generator body is converted into an
                    # opaque "generator raised StopIteration" RuntimeError; raise an
                    # explicit one that names the offending dataloader instead.
                    raise RuntimeError(
                        f"dataloader for seg_availability {str(seg_availability)!r} is empty "
                        "but was selected for sampling; give it zero weight"
                    ) from None
            yield batch

    return batch_generator