import os | |
import numpy as np | |
from .base import SequenceDataset | |
class KNYImage(SequenceDataset):
    """Sequence dataset backed by the array stored in ``kny/kny_images_64x128.npy``.

    The whole array is loaded from a single ``.npy`` file and split along its
    first axis: the last ``num_test_samples`` entries are held out, and every
    earlier entry forms the ``"train"`` split.
    """

    # Number of trailing samples (first axis) held out from the training split.
    # 5000 reproduces the original hard-coded split; override in a subclass to
    # change it.
    num_test_samples: int = 5000

    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        """Load the KNY image array and return the requested split.

        Args:
            dataset_path: Root directory that contains the ``kny`` sub-directory.
            split: ``"train"`` selects every sample except the last
                ``num_test_samples``; ANY other value (e.g. a test/validation
                split name) selects the last ``num_test_samples`` samples.

        Returns:
            A basic slice of the loaded array along its first axis. NumPy basic
            slicing returns a view, so the full loaded array stays referenced.
        """
        data = np.load(os.path.join(dataset_path, "kny", "kny_images_64x128.npy"))
        # Compute the boundary once instead of slicing with a negative count:
        # ``data[:-0]`` would be empty (and ``data[-0:]`` the whole array) if
        # num_test_samples were 0. Clamping at 0 keeps the original behavior for
        # a dataset smaller than the hold-out size (empty train split, every
        # sample in the held-out split).
        split_idx = max(len(data) - self.num_test_samples, 0)
        if split == "train":
            return data[:split_idx]
        return data[split_idx:]

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Return ``data`` divided by 255 (true division, so the result is float).

        Presumably the stored images are 8-bit pixel values in [0, 255], which
        this maps to [0, 1] -- TODO confirm against the dataset's dtype.
        """
        return data / 255