DF / samplelib /SampleGeneratorImage.py
Jatin7860's picture
Upload 226 files
fcd5579 verified
raw
history blame
2.19 kB
import traceback
import cv2
import numpy as np
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
class SampleGeneratorImage(SampleGeneratorBase):
    """Endless batch iterator over image samples loaded from a path.

    Samples are loaded once via ``SampleLoader.load(SampleType.IMAGE, ...)``.
    Batches are then produced by ``batch_func``, which is wrapped in a
    ``ThisThreadGenerator`` when ``self.debug`` is truthy and in a
    ``SubprocessGenerator`` otherwise.

    ``self.initialized`` is ``True`` only when at least one sample was loaded
    and the batch generator was created.
    """

    def __init__(self, samples_path, debug, batch_size,
                 sample_process_options=None, output_sample_types=None,
                 raise_on_no_data=True, **kwargs):
        """Load the samples and set up the batch generator.

        Args:
            samples_path: Location handed to ``SampleLoader.load``.
            debug: Forwarded to the base class constructor.
            batch_size: Forwarded to the base class constructor.
            sample_process_options: Options passed to
                ``SampleProcessor.process``. ``None`` (the default) creates a
                fresh ``SampleProcessor.Options()`` for this instance.
            output_sample_types: Output description passed to
                ``SampleProcessor.process``. ``None`` (the default) means a
                new empty list.
            raise_on_no_data: When true, raise if no samples were loaded;
                otherwise leave the instance uninitialized and return.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            ValueError: No samples were loaded and ``raise_on_no_data`` is
                true.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        # Build defaults per instance: a default evaluated in the signature is
        # created once and shared by every instance that relies on it.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        # Defined before any early return so that __next__ never trips over a
        # missing attribute on an uninitialized instance.
        self.generators = []
        self.generator_counter = -1

        samples = SampleLoader.load(SampleType.IMAGE, samples_path)
        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            return

        generator_cls = ThisThreadGenerator if self.debug else SubprocessGenerator
        self.generators = [generator_cls(self.batch_func, samples)]
        self.initialized = True

    def __iter__(self):
        """Return ``self``; the instance is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch, cycling round-robin over the generators.

        Raises:
            StopIteration: The instance holds no generators because no
                samples were loaded.
        """
        if not self.generators:
            raise StopIteration
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, samples):
        """Yield batches built from ``samples`` forever.

        Sample indices are drawn without replacement from a shuffled copy of
        all indices; the copy is refilled and reshuffled whenever it runs
        out, which may happen in the middle of a batch.

        Args:
            samples: Indexable collection of samples supporting ``len``.

        Yields:
            A list with one ``np.array`` per output returned by
            ``SampleProcessor.process`` for a sample; each array stacks that
            output across the ``self.batch_size`` samples of the batch.
        """
        all_idxs = list(range(len(samples)))
        pending_idxs = []

        while True:
            batches = None
            for _ in range(self.batch_size):
                if not pending_idxs:
                    pending_idxs = all_idxs.copy()
                    np.random.shuffle(pending_idxs)

                sample = samples[pending_idxs.pop()]

                # process() takes a list of samples and returns one result
                # per sample; a single sample is passed, so unpack the only
                # result.
                x, = SampleProcessor.process([sample],
                                             self.sample_process_options,
                                             self.output_sample_types,
                                             self.debug)

                if batches is None:
                    # One accumulator per output of the processor.
                    batches = [[] for _ in range(len(x))]

                for i, output in enumerate(x):
                    batches[i].append(output)

            yield [np.array(batch) for batch in batches]