import traceback

import cv2
import numpy as np

from core.joblib import SubprocessGenerator, ThisThreadGenerator
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)


class SampleGeneratorImage(SampleGeneratorBase):
    """Endless batch generator over the image samples found in ``samples_path``.

    Samples are loaded with ``SampleLoader.load(SampleType.IMAGE, ...)``.
    Batches are produced by ``batch_func``. In debug mode it runs in the
    calling thread via ``ThisThreadGenerator``; otherwise it runs through a
    ``SubprocessGenerator``.
    """

    def __init__(self, samples_path, debug, batch_size,
                 sample_process_options=None,
                 output_sample_types=None,
                 raise_on_no_data=True,
                 **kwargs):
        """Load the samples and set up the underlying batch generator.

        :param samples_path: location passed to ``SampleLoader.load``.
        :param debug: if true, generate batches in the current thread.
        :param batch_size: number of samples per yielded batch.
        :param sample_process_options: ``SampleProcessor.Options`` to apply;
            a fresh default ``SampleProcessor.Options()`` when omitted.
        :param output_sample_types: output type descriptions forwarded to
            ``SampleProcessor.process``; an empty list when omitted.
        :param raise_on_no_data: raise ``ValueError`` when no samples are
            found; otherwise leave the generator uninitialized.
        :param kwargs: accepted and ignored.
        :raises ValueError: no samples found and ``raise_on_no_data`` is true.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        # Build defaults per call. Defaults written in the signature are
        # evaluated once and would be shared by every instance.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        # Set these before any early return so __next__ never hits a
        # missing attribute on an uninitialized instance.
        self.generators = []
        self.generator_counter = -1

        samples = SampleLoader.load(SampleType.IMAGE, samples_path)
        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            return

        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, samples)]
        else:
            self.generators = [SubprocessGenerator(self.batch_func, samples)]

        self.initialized = True

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch from the underlying generators in round-robin order.

        :raises StopIteration: no generators exist, i.e. no data was found
            and ``raise_on_no_data`` was false.
        """
        if not self.generators:
            raise StopIteration
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, samples):
        """Yield batches from ``samples`` forever.

        Samples are drawn in random order without replacement. Once every
        sample has been used, the index pool is reshuffled and drawing
        starts again. Each yielded value is a list with one ``np.array``
        per entry of the processed result. Presumably that is one array per
        ``output_sample_types`` entry (TODO confirm against
        ``SampleProcessor.process``).
        """
        all_idxs = list(range(len(samples)))
        shuffle_idxs = []
        while True:
            batches = None
            for _ in range(self.batch_size):
                if not shuffle_idxs:
                    shuffle_idxs = all_idxs.copy()
                    np.random.shuffle(shuffle_idxs)
                sample = samples[shuffle_idxs.pop()]

                # process() receives a one-element list; its result is
                # unpacked as exactly one entry.
                x, = SampleProcessor.process([sample],
                                             self.sample_process_options,
                                             self.output_sample_types,
                                             self.debug)

                if batches is None:
                    batches = [[] for _ in range(len(x))]
                for batch, item in zip(batches, x):
                    batch.append(item)

            yield [np.array(batch) for batch in batches]