import monai
import torch
import itk
import numpy as np
import matplotlib.pyplot as plt
import random
import glob
import os.path
import argparse
import sys

def create_batch_generator(dataloader_subdivided, weights=None):
    """
    Create a batch generator that samples data pairs with various segmentation availabilities.

    Arguments:
        dataloader_subdivided : a mapping from the labels in seg_availabilities to dataloaders
        weights : a list of probabilities, one for each label in seg_availabilities;
            if not provided, then we weight by the number of data items of each type,
            effectively sampling uniformly over the union of the datasets

    Returns: batch_generator
        A function that accepts a number of batches to sample and returns a generator.
        The generator weighted-randomly picks one of the seg_availabilities and
        yields the next batch from the corresponding dataloader.
    """
    seg_availabilities = ['00', '01', '10', '11']
    if weights is None:
        # Default: weight each availability class by the number of batches in its dataloader.
        weights = np.array([len(dataloader_subdivided[s]) for s in seg_availabilities])
    weights = np.array(weights)
    weights = weights / weights.sum()
    dataloader_subdivided_as_iterators = {s: iter(d) for s, d in dataloader_subdivided.items()}

    def batch_generator(num_batches_to_sample):
        for _ in range(num_batches_to_sample):
            seg_availability = np.random.choice(seg_availabilities, p=weights)
            try:
                yield next(dataloader_subdivided_as_iterators[seg_availability])
            except StopIteration:
                # The chosen dataloader is exhausted; restart it and yield from the fresh iterator.
                dataloader_subdivided_as_iterators[seg_availability] = \
                    iter(dataloader_subdivided[seg_availability])
                yield next(dataloader_subdivided_as_iterators[seg_availability])

    return batch_generator
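
# Illustrative usage sketch (not part of the original script; the names below are
# hypothetical). Assuming each segmentation-availability label ('00', '01', '10', '11')
# has been given its own torch DataLoader elsewhere, collected in a dict such as
# `dataloader_subdivided`, a training loop could draw mixed batches like this:
#
#     batch_generator = create_batch_generator(dataloader_subdivided)
#     for batch in batch_generator(num_batches_to_sample=100):
#         ...  # each batch comes from a weighted-randomly chosen dataloader
#
# To oversample fully labeled pairs ('11'), explicit weights could be passed instead,
# e.g. create_batch_generator(dataloader_subdivided, weights=[1, 1, 1, 3]); the weights
# are normalized internally, so they need not sum to 1.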