# EverythingIsAFont / dataset_loader.py
# taellinglin's picture
# Upload 61 files
# 9dce563 verified
import os
import numpy as np
import gzip
from PIL import Image
from torchvision import transforms
class CustomMNISTDataset:
    """In-memory dataset backed by gzip-compressed MNIST IDX files.

    On construction, every file in ``dataset_path`` whose name contains
    ``train-images-idx3-ubyte.gz`` is paired with a file whose name contains
    ``train-labels-idx1-ubyte.gz``; each pair is decoded and appended to
    ``self.images`` / ``self.labels``.

    Args:
        dataset_path: Directory that holds the ``.gz`` IDX files.
        transform: Optional callable applied to each PIL image.
            NOTE: it is applied once, eagerly, while loading, so a random
            transform is sampled a single time per image, not per access.

    Raises:
        ValueError: If no image/label file is found, or a file is not a
            well-formed IDX file of the expected kind.
    """

    # Magic numbers that open an IDX file: 0x00000803 marks unsigned-byte
    # data with three dimensions (images), 0x00000801 one dimension (labels).
    _IMAGE_MAGIC = 2051
    _LABEL_MAGIC = 2049
    # Header sizes in bytes: magic + count (+ rows + cols for images).
    _IMAGE_HEADER_SIZE = 16
    _LABEL_HEADER_SIZE = 8

    def __init__(self, dataset_path, transform=None):
        self.dataset_path = dataset_path
        self.transform = transform
        self.images, self.labels = self.load_dataset()

    def load_dataset(self):
        """Find and decode all image/label file pairs in ``dataset_path``.

        Returns:
            A ``(images, labels)`` tuple of two lists of equal length.

        Raises:
            ValueError: If no image file or no label file is present.
        """
        image_paths = []
        label_paths = []
        # os.listdir() yields entries in arbitrary order; sorting makes the
        # image/label pairing below deterministic across platforms and runs.
        for file_name in sorted(os.listdir(self.dataset_path)):
            if 'train-images-idx3-ubyte.gz' in file_name:
                image_paths.append(os.path.join(self.dataset_path, file_name))
            elif 'train-labels-idx1-ubyte.gz' in file_name:
                label_paths.append(os.path.join(self.dataset_path, file_name))

        if not image_paths or not label_paths:
            raise ValueError(f"❌ Missing image or label files in {self.dataset_path}")

        images = []
        labels = []
        # Files are paired positionally; any surplus file on either side is
        # left unpaired and ignored.
        for img_path, label_path in zip(image_paths, label_paths):
            images_data, labels_data = self.load_mnist_data(img_path, label_path)
            images.extend(images_data)
            labels.extend(labels_data)
        return images, labels

    def load_mnist_data(self, img_path, label_path):
        """Load MNIST data from one pair of ``.gz`` IDX files.

        Args:
            img_path: Path to the gzip-compressed IDX image file.
            label_path: Path to the gzip-compressed IDX label file.

        Returns:
            A ``(images, labels)`` tuple: a list with one entry per image
            (PIL images, or whatever ``self.transform`` returns for them)
            and a ``uint8`` NumPy array of labels.

        Raises:
            ValueError: If either file has a wrong magic number, a truncated
                header, a payload whose size disagrees with its header, or
                if the two files hold a different number of items.
        """
        img_data = self._read_idx_images(img_path)
        label_data = self._read_idx_labels(label_path)

        # Validate before any image conversion so that misaligned data can
        # never be returned to the caller.
        if len(img_data) != len(label_data):
            raise ValueError(
                f"❌ {img_path} holds {len(img_data)} images but "
                f"{label_path} holds {len(label_data)} labels"
            )

        images = [Image.fromarray(img) for img in img_data]  # one PIL image per sample
        if self.transform:
            images = [self.transform(img) for img in images]
        return images, label_data

    @classmethod
    def _read_idx_images(cls, path):
        """Decode an IDX image file into a ``(count, rows, cols)`` uint8 array."""
        with gzip.open(path, 'rb') as f:
            header = f.read(cls._IMAGE_HEADER_SIZE)
            payload = f.read()

        if len(header) < cls._IMAGE_HEADER_SIZE:
            raise ValueError(f"❌ Truncated IDX image header in {path}")

        # The header is four big-endian 32-bit integers.
        magic, count, rows, cols = (
            int.from_bytes(header[offset:offset + 4], 'big')
            for offset in range(0, cls._IMAGE_HEADER_SIZE, 4)
        )
        if magic != cls._IMAGE_MAGIC:
            raise ValueError(
                f"❌ {path} is not an IDX image file (magic number {magic})"
            )

        pixels = np.frombuffer(payload, dtype=np.uint8)
        if pixels.size != count * rows * cols:
            raise ValueError(
                f"❌ {path} declares {count} images of {rows}x{cols} "
                f"but contains {pixels.size} bytes of pixel data"
            )
        # Use the dimensions from the header instead of assuming 28x28.
        return pixels.reshape(count, rows, cols)

    @classmethod
    def _read_idx_labels(cls, path):
        """Decode an IDX label file into a 1-D uint8 array."""
        with gzip.open(path, 'rb') as f:
            header = f.read(cls._LABEL_HEADER_SIZE)
            payload = f.read()

        if len(header) < cls._LABEL_HEADER_SIZE:
            raise ValueError(f"❌ Truncated IDX label header in {path}")

        # The header is two big-endian 32-bit integers.
        magic = int.from_bytes(header[0:4], 'big')
        count = int.from_bytes(header[4:8], 'big')
        if magic != cls._LABEL_MAGIC:
            raise ValueError(
                f"❌ {path} is not an IDX label file (magic number {magic})"
            )

        label_values = np.frombuffer(payload, dtype=np.uint8)
        if label_values.size != count:
            raise ValueError(
                f"❌ {path} declares {count} labels but contains {label_values.size}"
            )
        return label_values

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return a single image and its label at the given index."""
        image = self.images[idx]
        label = self.labels[idx]
        return image, label