GANime / ganime /data /mnist.py
Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
3.54 kB
import glob
import os
from typing import Literal
import numpy as np
from .base import SequenceDataset
import math
class MovingMNISTImage(SequenceDataset):
    """Moving MNIST served as single-frame pairs read from ``mnist_test_seq.npy``.

    Each batch is ``(batch_x, batch_y)`` where ``batch_x`` is frame 0 and
    ``batch_y`` is frame 1 of the same sequences.

    NOTE(review): ``self.data``, ``self.indices`` and ``self.batch_size`` are
    not assigned in this class; presumably ``SequenceDataset.__init__`` sets
    them by calling ``load_data`` then ``preprocess_data`` — verify against
    the base class.
    """

    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        """Read the sequence file and return the rows belonging to ``split``.

        Args:
            dataset_path: Root folder that contains a ``moving_mnist/`` directory.
            split: ``"train"`` keeps every sequence except the last 1000; any
                other value keeps only those last 1000.

        Returns:
            Array laid out as (n_samples, window, width, height, 3); the single
            source channel is repeated three times.
        """
        source_file = os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")
        holdout = 1000
        # On disk the layout is (window, n_samples, width, height); Keras wants
        # samples first, plus a trailing channel axis.
        sequences = np.moveaxis(np.load(source_file), 0, 1)[..., np.newaxis]
        if split != "train":
            selected = sequences[-holdout:]
        else:
            selected = sequences[:-holdout]
        # Replicate the grey channel so the images look 3-channel downstream.
        return np.concatenate([selected] * 3, axis=-1)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(frame 0, frame 1)`` of the selected sequences."""
        lo = idx * self.batch_size
        rows = self.indices[lo : lo + self.batch_size]
        return self.data[rows, 0], self.data[rows, 1]

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Rescale pixel values by 1/255 (presumably uint8 in [0, 255] -> [0, 1])."""
        return data / 255
class MovingMNIST(SequenceDataset):
    """Moving MNIST backed by one ``.npy`` file per sequence, loaded per batch.

    Files are discovered under ``<dataset_path>/moving_mnist/<split>/*.npy``
    and read lazily in ``load_indices``. This class does not call
    ``SequenceDataset.__init__``, so nothing is loaded up front.

    Each batch is ``(batch_x, batch_y)``:
      * ``batch_x`` — the first and last frame of every sequence, stacked on
        axis 1 (so axis 1 has length 2);
      * ``batch_y`` — every frame from index 1 to the end (inclusive of the
        last frame).
    """

    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
        split: Literal["train", "validation", "test"] = "train",
    ):
        """Index the per-sequence files of ``split``.

        Args:
            dataset_path: Root folder that contains a ``moving_mnist/`` directory.
            batch_size: Number of sequences per batch.
            split: Sub-directory of ``moving_mnist/`` to read from.

        If no ``.npy`` file matches, the dataset is simply empty (zero batches).
        """
        self.batch_size = batch_size
        self.split = split
        root_path = os.path.join(dataset_path, "moving_mnist", split)
        # glob's result order depends on the filesystem; sort it so the
        # index -> file mapping (and thus get_fixed_batch) is reproducible.
        self.paths = sorted(glob.glob(os.path.join(root_path, "*.npy")))
        self.indices = np.arange(len(self.paths))
        # on_epoch_end is not defined here; presumably SequenceDataset provides
        # it (e.g. to reshuffle self.indices) — verify against the base class.
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return math.ceil(len(self.paths) / self.batch_size)

    def __getitem__(self, idx):
        """Load batch ``idx`` and split it into ``(batch_x, batch_y)``.

        Raises:
            IndexError: if ``idx`` selects no sequences (past the last batch).
        """
        inds = self._batch_indices(idx)
        if len(inds) == 0:
            raise IndexError(
                f"batch index {idx} out of range for dataset with {len(self)} batches"
            )
        return self._split_frames(self.load_indices(inds))

    def get_fixed_batch(self, idx):
        """Return the same batch on every call, e.g. for visual monitoring.

        The sequences are chosen from batch ``idx`` on the FIRST call only;
        ``idx`` is ignored on subsequent calls.
        """
        if not hasattr(self, "fixed_indices"):
            # Slicing an ndarray yields a view; copy so the fixed batch is not
            # affected if self.indices is later modified in place.
            self.fixed_indices = self._batch_indices(idx).copy()
        return self._split_frames(self.load_indices(self.fixed_indices))

    def load_indices(self, indices):
        """Load and preprocess the sequence files at ``indices``.

        Returns:
            Array with the loaded sequences stacked along a new leading axis,
            after ``preprocess_data``.
        """
        data = np.array([np.load(self.paths[index]) for index in indices])
        return self.preprocess_data(data)

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Rescale pixel values by 1/255 (presumably uint8 in [0, 255] -> [0, 1])."""
        return data / 255

    def _batch_indices(self, idx):
        """Return the slice of ``self.indices`` that makes up batch ``idx``."""
        start = idx * self.batch_size
        return self.indices[start : start + self.batch_size]

    @staticmethod
    def _split_frames(data):
        """Split loaded sequences into (first+last frames, frames 1..end)."""
        batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
        batch_y = data[:, 1:, ...]
        return batch_x, batch_y