|
import traceback |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
from core.joblib import SubprocessGenerator, ThisThreadGenerator |
|
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, |
|
SampleType) |
|
|
|
|
|
class SampleGeneratorImage(SampleGeneratorBase):
    """Infinite batch generator over plain image samples.

    Samples are loaded once via ``SampleLoader.load(SampleType.IMAGE, ...)``;
    every batch is built by running each drawn sample through
    ``SampleProcessor.process`` and stacking the per-output results.
    """

    def __init__(self, samples_path, debug, batch_size,
                 sample_process_options=None, output_sample_types=None,
                 raise_on_no_data=True, **kwargs):
        """Load samples and set up the batch generator.

        Args:
            samples_path: location passed to ``SampleLoader.load``.
            debug: when true, batches come from a ``ThisThreadGenerator``;
                otherwise from a ``SubprocessGenerator``. Also forwarded to
                ``SampleProcessor.process``.
            batch_size: forwarded to ``SampleGeneratorBase``; ``batch_func``
                reads ``self.batch_size`` to size each batch.
            sample_process_options: options for ``SampleProcessor.process``.
                ``None`` (default) means a fresh ``SampleProcessor.Options()``.
            output_sample_types: output descriptors for
                ``SampleProcessor.process``. ``None`` (default) means ``[]``.
            raise_on_no_data: if true, raise when no samples are found;
                otherwise leave the generator uninitialized
                (``self.initialized`` stays ``False``).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: no samples were found and ``raise_on_no_data`` is true.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        # Resolve defaults per call. Default-argument expressions are evaluated
        # only once, so the previous ``SampleProcessor.Options()`` / ``[]``
        # defaults were a single object shared by every instance.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        # Define these up front so that an instance built without data can
        # still be iterated safely (see __next__) instead of hitting an
        # AttributeError.
        self.generators = []
        self.generator_counter = -1

        samples = SampleLoader.load(SampleType.IMAGE, samples_path)

        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            return

        generator_cls = ThisThreadGenerator if self.debug else SubprocessGenerator
        self.generators = [generator_cls(self.batch_func, samples)]
        self.initialized = True

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, cycling round-robin over ``self.generators``.

        Raises:
            StopIteration: the instance was constructed without samples
                (only possible with ``raise_on_no_data=False``), so there is
                nothing to iterate.
        """
        if not self.generators:
            raise StopIteration
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, samples):
        """Yield batches forever, drawing samples in reshuffled passes.

        Indices are popped from a shuffled copy of ``range(len(samples))``.
        When that pool is exhausted, it is refilled and reshuffled, so one
        batch may straddle two passes.

        Each yielded value is a list with one ``np.array`` per output returned
        by ``SampleProcessor.process``. Each array stacks ``self.batch_size``
        entries.
        """
        all_idxs = list(range(len(samples)))
        shuffle_idxs = []

        while True:
            batches = None
            for _ in range(self.batch_size):
                if not shuffle_idxs:
                    shuffle_idxs = all_idxs.copy()
                    np.random.shuffle(shuffle_idxs)

                sample = samples[shuffle_idxs.pop()]

                # process() is given a one-element list; unpack its single result.
                x, = SampleProcessor.process([sample], self.sample_process_options,
                                             self.output_sample_types, self.debug)

                if batches is None:
                    batches = [[] for _ in range(len(x))]
                for i, item in enumerate(x):
                    batches[i].append(item)

            yield [np.array(batch) for batch in batches]
|
|