vishred18's picture
Upload 364 files
d5ee97c
raw
history blame contribute delete
No virus
2.32 kB
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract Dataset modules."""
import abc
import tensorflow as tf
class AbstractDataset(metaclass=abc.ABCMeta):
    """Abstract Dataset module for Dataset Loader.

    Subclasses provide a Python generator plus the metadata ``create`` needs
    to wrap it into a ``tf.data.Dataset`` pipeline.
    """

    @abc.abstractmethod
    def get_args(self):
        """Return args for generator function.

        The value is forwarded unchanged as the ``args`` argument of
        ``tf.data.Dataset.from_generator`` and, from there, to ``generator``.
        """
        pass

    @abc.abstractmethod
    def generator(self):
        """Generator function, should have args from get_args function."""
        pass

    @abc.abstractmethod
    def get_output_dtypes(self):
        """Return output dtypes for each element from generator.

        Forwarded as ``output_types`` to ``tf.data.Dataset.from_generator``.
        """
        pass

    @abc.abstractmethod
    def get_len_dataset(self):
        """Return number of samples on dataset.

        Used by ``create`` as the shuffle buffer size when shuffling.
        """
        pass

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Builds the pipeline:
        from_generator -> [cache] -> [shuffle] -> [map] -> batch -> prefetch.

        Args:
            allow_cache: If True, apply ``Dataset.cache()`` right after the
                generator source.
            batch_size: Number of elements per batch passed to
                ``Dataset.batch``.
            is_shuffle: If True, shuffle with a buffer of
                ``get_len_dataset()`` elements.
            map_fn: Optional function applied to every element (with
                AUTOTUNE parallelism) before batching. Required when
                ``batch_size > 1``.
            reshuffle_each_iteration: Forwarded to ``Dataset.shuffle``.

        Returns:
            The batched and prefetched ``tf.data.Dataset``.

        Raises:
            ValueError: If ``batch_size > 1`` and ``map_fn`` is None.
        """
        # Validate arguments before calling any subclass hook or building
        # tf.data stages, so a misconfiguration fails fast and cheaply.
        # NOTE(review): map_fn is presumably required for batching because raw
        # generator elements may not have uniform shapes. Confirm with callers.
        if batch_size > 1 and map_fn is None:
            raise ValueError("map function must define when batch_size > 1.")

        output_types = self.get_output_dtypes()
        # get_args() is passed as-is. The former `(self.get_args())` was a
        # parenthesized expression, not a 1-tuple.
        dataset = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=self.get_args()
        )
        if allow_cache:
            dataset = dataset.cache()
        if is_shuffle:
            # A buffer as large as the dataset gives a full shuffle.
            dataset = dataset.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )
        if map_fn is not None:
            dataset = dataset.map(map_fn, tf.data.experimental.AUTOTUNE)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset