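import gzip
import os
import struct

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomMNISTDataset(Dataset):
    """MNIST training set loaded from the original gzipped IDX files.

    Images are stored as PIL images. ``transform`` (e.g. ``transforms.ToTensor()``)
    is applied per sample in ``__getitem__``, so random augmentations produce a
    fresh result on every access instead of being frozen at load time.
    """

    def __init__(self, dataset_path, transform=None):
        self.dataset_path = dataset_path
        self.transform = transform
        self.images, self.labels = self.load_dataset()

    def load_dataset(self):
        image_paths = []
        label_paths = []

        # Look for the gzipped MNIST training files in the dataset directory.
        # Sorting makes the image/label pairing below deterministic.
        for file in sorted(os.listdir(self.dataset_path)):
            if 'train-images-idx3-ubyte.gz' in file:
                image_paths.append(os.path.join(self.dataset_path, file))
            elif 'train-labels-idx1-ubyte.gz' in file:
                label_paths.append(os.path.join(self.dataset_path, file))

        if not image_paths or not label_paths:
            raise ValueError(f"❌ Missing image or label files in {self.dataset_path}")

        images = []
        labels = []

        # Normally there is exactly one image file and one label file
        for img_path, label_path in zip(image_paths, label_paths):
            images_data, labels_data = self.load_mnist_data(img_path, label_path)
            images.extend(images_data)
            labels.extend(labels_data)

        return images, labels

    def load_mnist_data(self, img_path, label_path):
        """Load MNIST images and labels from gzipped IDX files."""
        with gzip.open(img_path, 'rb') as f:
            # Header: magic number (2051), image count, rows, cols as big-endian uint32
            magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
            if magic != 2051:
                raise ValueError(f"Invalid magic number {magic} in {img_path}")
            # The rest of the file is raw uint8 pixel data, one byte per pixel
            img_data = np.frombuffer(f.read(), dtype=np.uint8)
            img_data = img_data.reshape(num_images, rows, cols)

        with gzip.open(label_path, 'rb') as f:
            # Header: magic number (2049) and label count as big-endian uint32
            magic, num_labels = struct.unpack('>II', f.read(8))
            if magic != 2049:
                raise ValueError(f"Invalid magic number {magic} in {label_path}")
            label_data = np.frombuffer(f.read(), dtype=np.uint8)

        if num_images != num_labels:
            raise ValueError(f"{num_images} images but {num_labels} labels in {self.dataset_path}")

        # Convert each 28x28 array to a grayscale PIL image
        images = [Image.fromarray(img) for img in img_data]
        # Plain Python ints collate to int64 tensors, which CrossEntropyLoss expects
        labels = [int(label) for label in label_data]

        return images, labels

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return a single (image, label) pair, applying the transform if one is set."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label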
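

# Hypothetical test helper (not part of the original loader): writes a tiny
# synthetic dataset in the gzipped IDX layout that CustomMNISTDataset reads,
# so the class can be smoke-tested without downloading MNIST. It assumes the
# standard MNIST header layout: big-endian uint32 fields, magic 2051 + count +
# rows + cols for images (16 bytes) and magic 2049 + count for labels (8 bytes),
# followed by raw uint8 data. Pixel values and labels are random, not digits.
def write_synthetic_mnist(directory, num_samples=8, seed=0):
    """Write synthetic train-images/train-labels .gz files into `directory`."""
    # Local imports so the helper works even when copied on its own
    import gzip
    import os

    import numpy as np

    os.makedirs(directory, exist_ok=True)
    rng = np.random.default_rng(seed)
    images = rng.integers(0, 256, size=(num_samples, 28, 28), dtype=np.uint8)
    labels = rng.integers(0, 10, size=num_samples, dtype=np.uint8)

    # IDX header fields are big-endian 32-bit unsigned integers
    image_header = np.array([2051, num_samples, 28, 28], dtype=">u4").tobytes()
    label_header = np.array([2049, num_samples], dtype=">u4").tobytes()

    image_path = os.path.join(directory, "train-images-idx3-ubyte.gz")
    label_path = os.path.join(directory, "train-labels-idx1-ubyte.gz")
    with gzip.open(image_path, "wb") as f:
        f.write(image_header + images.tobytes())
    with gzip.open(label_path, "wb") as f:
        f.write(label_header + labels.tobytes())
    return images, labels


# Example usage (the directory path is illustrative):
#     write_synthetic_mnist("/tmp/mnist_demo", num_samples=4)
#     dataset = CustomMNISTDataset("/tmp/mnist_demo", transform=transforms.ToTensor())
#     image, label = dataset[0]  # image: 1x28x28 float tensor with values in [0, 1]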