"""Dataset loading utilities for the ganime project: build train/validation/test
tf.data pipelines from the KNY image TFRecord datasets."""
import math
import os
from abc import ABC, abstractmethod
from typing import Literal, Tuple

import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import Sequence

from ganime.data.experimental import ImageDataset

# class SequenceDataset(Sequence):
#     def __init__(
#         self,
#         dataset_path: str,
#         batch_size: int,
#         split: Literal["train", "validation", "test"] = "train",
#     ):
#         self.batch_size = batch_size
#         self.split = split
#         self.data = self.load_data(dataset_path, split)
#         self.data = self.preprocess_data(self.data)
#         self.indices = np.arange(self.data.shape[0])
#         self.on_epoch_end()
#
#     @abstractmethod
#     def load_data(self, dataset_path: str, split: str) -> np.ndarray:
#         pass
#
#     def preprocess_data(self, data: np.ndarray) -> np.ndarray:
#         return data
#
#     def __len__(self):
#         return math.ceil(len(self.data) / self.batch_size)
#
#     def __getitem__(self, idx):
#         inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
#         batch_x = self.data[inds]
#         batch_y = batch_x
#         return batch_x, batch_y
#
#     def get_fixed_batch(self, idx):
#         self.fixed_indices = (
#             self.fixed_indices
#             if hasattr(self, "fixed_indices")
#             else self.indices[
#                 idx * self.batch_size : (idx + 1) * self.batch_size
#             ].copy()
#         )
#         batch_x = self.data[self.fixed_indices]
#         batch_y = batch_x
#         return batch_x, batch_y
#
#     def on_epoch_end(self):
#         np.random.shuffle(self.indices)

# def load_kny_images(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     import skvideo.io
#
#     if os.path.exists(os.path.join(dataset_path, "kny", "kny_images.npy")):
#         data = np.load(os.path.join(dataset_path, "kny", "kny_images.npy"))
#     else:
#         data = skvideo.io.vread(os.path.join(dataset_path, "kny", "01.mp4"))
#     np.random.shuffle(data)
#
#     def _preprocess(sample):
#         image = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         image = tf.image.resize(image, [64, 64])
#         return image, image
#
#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:5000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[5000:6000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     return train_dataset, test_dataset, data.shape[1:]

# def load_moving_mnist_vae(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
#     # The data is of shape (window, n_samples, width, height),
#     # but for Keras we want (n_samples, window, width, height).
#     data = np.moveaxis(data, 0, 1)
#     # Also expand dimensions to have channels at the end
#     # (n_samples, window, width, height, channels).
#     data = np.expand_dims(data, axis=-1)
#
#     def _preprocess(sample):
#         video = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         return video, video
#
#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:9000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[9000:])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     return train_dataset, test_dataset, data.shape[1:]

# def load_moving_mnist(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
#     # The data is of shape (window, n_samples, width, height),
#     # but for Keras we want (n_samples, window, width, height).
#     data = np.moveaxis(data, 0, 1)
#     # Also expand dimensions to have channels at the end
#     # (n_samples, window, width, height, channels).
#     data = np.expand_dims(data, axis=-1)
#
#     def _preprocess(sample):
#         video = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         first_frame = video[0:1, ...]
#         last_frame = video[-1:, ...]
#         first_last = tf.concat([first_frame, last_frame], axis=0)
#         return first_last, video
#
#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:9000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[9000:])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     return train_dataset, test_dataset, data.shape[1:]

# def load_mnist(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
#     # The data is of shape (window, n_samples, width, height),
#     # but for Keras we want (n_samples, window, width, height).
#     data = np.moveaxis(data, 0, 1)
#     # Also expand dimensions to have channels at the end
#     # (n_samples, window, width, height, channels).
#     data = np.expand_dims(data, axis=-1)
#
#     def _preprocess(sample):
#         video = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         first_frame = video[0, ...]
#         first_frame = tf.image.grayscale_to_rgb(first_frame)
#         return first_frame, first_frame
#
#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:9000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[9000:])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     return train_dataset, test_dataset, data.shape[1:]


def preprocess_image(element):
    """Give the image an explicit (height, width, 3) shape, scale it to [0, 1],
    and return it twice as an (input, target) pair for reconstruction training."""
    element = tf.reshape(element, (tf.shape(element)[0], tf.shape(element)[1], 3))
    element = tf.cast(element, tf.float32) / 255.0
    return element, element


def load_kny_images_light(dataset_path, batch_size):
    """Load the light KNY image dataset and return (train, validation, test) pipelines."""
    dataset_length = 34045
    path = os.path.join(dataset_path, "kny", "images_tfrecords_light")
    dataset = ImageDataset(path).load()
    # Shuffle once with a fixed seed and without reshuffling between epochs, so
    # that the take/skip splits below stay disjoint across epochs.
    dataset = dataset.shuffle(
        dataset_length, reshuffle_each_iteration=False, seed=10
    ).map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    # 80% train / 10% validation / 10% test.
    train_size = int(dataset_length * 0.8)
    validation_size = int(dataset_length * 0.1)
    train_ds = dataset.take(train_size)
    validation_ds = dataset.skip(train_size).take(validation_size)
    test_ds = dataset.skip(train_size + validation_size).take(validation_size)
    train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
    return train_ds, validation_ds, test_ds


def load_kny_images(dataset_path, batch_size):
    """Load the full KNY image dataset and return (train, validation, test) pipelines."""
    dataset_length = 52014
    path = os.path.join(dataset_path, "kny", "images_tfrecords")
    dataset = ImageDataset(path).load()
    # Shuffle once without reshuffling between epochs so the take/skip splits
    # below stay disjoint; pass a seed here to make the split reproducible.
    dataset = dataset.shuffle(dataset_length, reshuffle_each_iteration=False).map(
        preprocess_image, num_parallel_calls=tf.data.AUTOTUNE
    )
    # 80% train / 10% validation / 10% test.
    train_size = int(dataset_length * 0.8)
    validation_size = int(dataset_length * 0.1)
    train_ds = dataset.take(train_size)
    validation_ds = dataset.skip(train_size).take(validation_size)
    test_ds = dataset.skip(train_size + validation_size).take(validation_size)
    train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
    return train_ds, validation_ds, test_ds


def load_dataset(
    dataset_name: str, dataset_path: str, batch_size: int
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Return the (train, validation, test) datasets matching ``dataset_name``."""
    # if dataset_name == "moving_mnist_vae":
    #     return load_moving_mnist_vae(dataset_path, batch_size)
    # elif dataset_name == "moving_mnist":
    #     return load_moving_mnist(dataset_path, batch_size)
    # elif dataset_name == "mnist":
    #     return load_mnist(dataset_path, batch_size)
    # elif dataset_name == "kny_images":
    #     return load_kny_images(dataset_path, batch_size)
    if dataset_name == "kny_images":
        return load_kny_images(dataset_path, batch_size)
    elif dataset_name == "kny_images_light":
        return load_kny_images_light(dataset_path, batch_size)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
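

# A minimal usage sketch (assumption: the TFRecord shards exist under
# "<dataset_root>/kny/images_tfrecords_light"); the dataset root and batch
# size below are hypothetical placeholders, not values taken from this project.
if __name__ == "__main__":
    train_ds, validation_ds, test_ds = load_dataset(
        dataset_name="kny_images_light",
        dataset_path="./data",  # hypothetical dataset root
        batch_size=16,
    )
    # Each element is an (input, target) batch of identical images in [0, 1].
    for images, targets in train_ds.take(1):
        print(images.shape, targets.shape)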