|
import glob |
|
import os |
|
from typing import Literal |
|
|
|
import numpy as np |
|
|
|
from .base import SequenceDataset |
|
import math |
|
|
|
|
|
class MovingMNISTImage(SequenceDataset):
    """Moving MNIST dataset that yields pairs taken from positions 0 and 1 of axis 1.

    The whole ``mnist_test_seq.npy`` archive is loaded at once by ``load_data``;
    batching relies on ``self.indices``, ``self.batch_size`` and ``self.data``,
    which are expected to be provided by ``SequenceDataset``.
    """

    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        """Load the archive and return the slice belonging to ``split``.

        Args:
            dataset_path: Directory containing ``moving_mnist/mnist_test_seq.npy``.
            split: ``"train"`` selects everything except the final 1000 entries
                along the leading axis; any other value selects those final
                1000 entries.

        Returns:
            The selected entries with a trailing axis of length 3 holding three
            identical copies of the loaded values.
        """
        archive = os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")
        frames = np.load(archive)

        # Swap the first two axes and append a trailing axis of length 1.
        # NOTE(review): presumably the file is stored as (time, sequence, H, W),
        # so this yields (sequence, time, H, W, 1) -- confirm against the data.
        frames = np.expand_dims(np.moveaxis(frames, 0, 1), axis=-1)

        # The final 1000 entries are held out for every non-"train" split.
        frames = frames[:-1000] if split == "train" else frames[-1000:]

        # Triplicate along the trailing axis (presumably grayscale -> 3 channels).
        return np.concatenate((frames, frames, frames), axis=-1)

    def __getitem__(self, idx):
        """Return batch number ``idx`` as an ``(inputs, targets)`` tuple.

        Inputs are taken from position 0 of axis 1 and targets from position 1
        of axis 1, for every sample selected by the batch's slice of
        ``self.indices``.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        selected = self.indices[start:stop]

        inputs = self.data[selected, 0, ...]
        targets = self.data[selected, 1, ...]
        return inputs, targets

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Divide ``data`` by 255 (presumably mapping 8-bit pixels to [0, 1])."""
        scaled = data / 255
        return scaled
|
|
|
|
|
class MovingMNIST(SequenceDataset):
    """Moving MNIST dataset that lazily loads one ``.npy`` file per sample.

    Files are discovered under ``<dataset_path>/moving_mnist/<split>/*.npy`` and
    only the files needed for the requested batch are read from disk.

    Each batch is returned as ``(batch_x, batch_y)`` where ``batch_x`` holds the
    first and last entries along axis 1 of the stacked samples and ``batch_y``
    holds every entry from position 1 onwards along that axis.
    NOTE(review): axis 1 is presumably the time/frame axis -- confirm against
    the layout of the stored ``.npy`` files.
    """

    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
        split: Literal["train", "validation", "test"] = "train",
    ):
        """Index the sample files of ``split`` and prepare the first epoch.

        Args:
            dataset_path: Root directory that contains the ``moving_mnist`` folder.
            batch_size: Number of samples per batch.
            split: Name of the sub-directory of ``moving_mnist`` to read from.
        """
        # NOTE(review): the base-class ``__init__`` is not invoked here;
        # presumably intentional because samples are loaded lazily -- confirm.
        self.batch_size = batch_size
        self.split = split
        root_path = os.path.join(dataset_path, "moving_mnist", split)
        # glob returns matches in arbitrary, filesystem-dependent order. Sort so
        # that the index -> file mapping is reproducible across runs and machines.
        self.paths = sorted(glob.glob(os.path.join(root_path, "*.npy")))

        self.indices = np.arange(len(self.paths))
        # ``on_epoch_end`` is not defined in this class; it is expected to come
        # from ``SequenceDataset`` (presumably it reshuffles ``self.indices``).
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch; the last batch may be partial."""
        return math.ceil(len(self.paths) / self.batch_size)

    @staticmethod
    def _split_frames(data):
        """Split stacked samples into the ``(batch_x, batch_y)`` pair."""
        # Width-1 slices (0:1 and -1:) keep axis 1 so the two pieces can be
        # concatenated along it.
        batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
        batch_y = data[:, 1:, ...]
        return batch_x, batch_y

    def __getitem__(self, idx):
        """Load batch number ``idx`` from disk and return ``(batch_x, batch_y)``."""
        inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
        return self._split_frames(self.load_indices(inds))

    def get_fixed_batch(self, idx):
        """Return the same batch on every call, regardless of later index changes.

        The sample indices of batch ``idx`` are copied into
        ``self.fixed_indices`` on the first call. Subsequent calls reuse that
        stored selection, so ``idx`` is ignored after the first call.
        """
        if not hasattr(self, "fixed_indices"):
            # Copy so that later in-place changes to ``self.indices`` cannot
            # alter the remembered selection.
            self.fixed_indices = self.indices[
                idx * self.batch_size : (idx + 1) * self.batch_size
            ].copy()
        return self._split_frames(self.load_indices(self.fixed_indices))

    def load_indices(self, indices):
        """Load the files at ``indices``, stack them and apply ``preprocess_data``.

        Args:
            indices: Iterable of integer positions into ``self.paths``.

        Returns:
            The preprocessed array with one leading entry per requested index.
        """
        samples = [np.load(self.paths[index]) for index in indices]
        return self.preprocess_data(np.array(samples))

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Divide ``data`` by 255 (presumably mapping 8-bit pixels to [0, 1])."""
        return data / 255
|
|