# Spaces:
# Sleeping
# Sleeping
import gzip
import os
import struct

import numpy as np
from PIL import Image
from torchvision import transforms
class CustomMNISTDataset:
    """MNIST-style dataset loaded from gzipped IDX files in one directory.

    All matching files are decoded once, at construction, into a single
    ``uint8`` image array of shape ``(N, rows, cols)`` and an ``int64`` label
    array of shape ``(N,)``. Conversion to a PIL image and the optional
    ``transform`` happen per sample in ``__getitem__``, so random
    augmentations are re-drawn on every access instead of being frozen once
    at load time.

    Args:
        dataset_path: Directory whose file names contain
            ``train-images-idx3-ubyte.gz`` / ``train-labels-idx1-ubyte.gz``
            (matched by substring).
        transform: Optional callable applied to each PIL image on access.
    """

    # IDX magic numbers: 0x00000803 (unsigned-byte data, 3 dims) for image
    # files and 0x00000801 (unsigned-byte data, 1 dim) for label files.
    _IMAGES_MAGIC = 0x00000803
    _LABELS_MAGIC = 0x00000801

    def __init__(self, dataset_path, transform=None):
        self.dataset_path = dataset_path
        self.transform = transform
        self.images, self.labels = self.load_dataset()

    def load_dataset(self):
        """Locate the image/label files under ``dataset_path`` and decode them.

        Returns:
            ``(images, labels)``: a ``uint8`` array of shape ``(N, rows, cols)``
            and an ``int64`` array of shape ``(N,)``.

        Raises:
            ValueError: if image or label files are missing, if the number of
                image files differs from the number of label files, or if any
                file is malformed (see ``load_mnist_data``).
        """
        image_paths = []
        label_paths = []
        # Sorting makes the image/label pairing below deterministic;
        # os.listdir returns entries in arbitrary order.
        for name in sorted(os.listdir(self.dataset_path)):
            path = os.path.join(self.dataset_path, name)
            if 'train-images-idx3-ubyte.gz' in name:
                image_paths.append(path)
            elif 'train-labels-idx1-ubyte.gz' in name:
                label_paths.append(path)

        if not image_paths or not label_paths:
            raise ValueError(f"❌ Missing image or label files in {self.dataset_path}")
        # zip() would silently drop unmatched files; fail loudly instead.
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"Found {len(image_paths)} image file(s) but {len(label_paths)} "
                f"label file(s) in {self.dataset_path}"
            )

        parts = [
            self.load_mnist_data(img_path, label_path)
            for img_path, label_path in zip(image_paths, label_paths)
        ]
        images = np.concatenate([imgs for imgs, _ in parts])
        labels = np.concatenate([lbls for _, lbls in parts])
        return images, labels

    @staticmethod
    def _read_header(f, n_fields, path):
        """Read ``n_fields`` big-endian unsigned 32-bit IDX header fields from ``f``."""
        size = 4 * n_fields
        raw = f.read(size)
        if len(raw) != size:
            raise ValueError(f"Truncated IDX header in {path}")
        return struct.unpack(f">{n_fields}I", raw)

    def load_mnist_data(self, img_path, label_path):
        """Decode one gzipped IDX image file and its matching label file.

        The image dimensions are taken from the IDX header rather than
        assumed to be 28x28.

        Args:
            img_path: Path to a gzipped IDX3 image file.
            label_path: Path to a gzipped IDX1 label file.

        Returns:
            ``(images, labels)``: a ``uint8`` array of shape ``(n, rows, cols)``
            and an ``int64`` array of shape ``(n,)``.

        Raises:
            ValueError: on a wrong magic number, a truncated header or payload,
                or when the image and label counts differ.
        """
        with gzip.open(img_path, 'rb') as f:
            magic, count, rows, cols = self._read_header(f, 4, img_path)
            if magic != self._IMAGES_MAGIC:
                raise ValueError(f"Bad image magic number {magic:#010x} in {img_path}")
            img_data = np.frombuffer(f.read(), dtype=np.uint8)
        if img_data.size != count * rows * cols:
            raise ValueError(
                f"{img_path}: expected {count * rows * cols} pixel bytes, "
                f"got {img_data.size}"
            )
        img_data = img_data.reshape(count, rows, cols)

        with gzip.open(label_path, 'rb') as f:
            magic, n_labels = self._read_header(f, 2, label_path)
            if magic != self._LABELS_MAGIC:
                raise ValueError(f"Bad label magic number {magic:#010x} in {label_path}")
            label_data = np.frombuffer(f.read(), dtype=np.uint8)
        if label_data.size != n_labels:
            raise ValueError(
                f"{label_path}: expected {n_labels} labels, got {label_data.size}"
            )
        if n_labels != count:
            raise ValueError(
                f"{img_path} has {count} images but {label_path} has {n_labels} labels"
            )

        return img_data, label_data.astype(np.int64)

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The stored ``uint8`` array is converted to a PIL image here and
        ``transform`` (if any) is applied on every call. The label is returned
        as a plain Python ``int`` rather than a NumPy scalar.
        """
        image = Image.fromarray(self.images[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[idx])