"""

contains functionality for creating pytorch dataloaders for image classification data

"""
import os
from pathlib import Path
from typing import Tuple, Dict, List

import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

NUM_WORKERS = os.cpu_count() or 0  # fall back to 0 workers if cpu_count() is unavailable

# helper for the custom dataset below
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folder names in a target directory.

    Returns:
        A tuple of (class_names, class_to_idx) where class_names is a sorted list
        of folder names and class_to_idx maps each class name to an integer label.
    """
    # 1. get the class names by scanning the target directory
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())

    # 2. raise an error if class names couldn't be found
    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    # 3. create a dictionary mapping class names to index labels
    class_to_idx = {class_name: i for i, class_name in enumerate(classes)}
    return classes, class_to_idx
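
# Quick sketch of what find_classes returns (the "data/pizza", "data/steak",
# "data/sushi" folder layout below is an assumption for illustration only):
#
#     classes, class_to_idx = find_classes("data")
#     # classes      -> ["pizza", "steak", "sushi"]
#     # class_to_idx -> {"pizza": 0, "steak": 1, "sushi": 2}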

# 1. subclass torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
    # 2. initialize the constructor
    def __init__(self, targ_dir: str, heads: List[str], transform=None, is_training: bool = True):
        # 3. create several attributes
        # collect the image paths per class folder and split each class 80/20 into train/test
        self.training = []
        self.testing = []
        for tag in heads:
            # sort so the train/test split is deterministic across runs
            img_list = sorted((Path(targ_dir) / tag).glob("*.jpg"))
            train_length = int(len(img_list) * 0.8)
            self.training.extend(img_list[:train_length])
            self.testing.extend(img_list[train_length:])

        # keep only the split requested via is_training
        if is_training:
            self.paths = self.training
        else:
            self.paths = self.testing
        # setup transforms
        self.transform = transform
        # create classes and class_to_idx
        self.classes, self.class_to_idx = find_classes(targ_dir)

    # 4. create a function to load images 
    def load_image(self, index: int) -> Image.Image: 
        "opens an image via a path and returns it"
        image_path = self.paths[index]
        return Image.open(image_path)
    
    # 5. overwrite __len__()
    def __len__(self) -> int: 
        return len(self.paths)
    
    # 6. overwrite __getitem__() to return a particular sample
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "returns one sample of data, data and the label (X, y)"
        img = self.load_image(index)
        class_name = self.paths[index].parent.name # expects path in format: data_folder/class_name/image.jpg
        class_idx = self.class_to_idx[class_name]

        # transform if necessary 
        if self.transform:
            return self.transform(img), class_idx
        else: 
            return img, class_idx
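
# Sketch of using the dataset directly (the directory layout, class folders and
# transform below are assumptions for illustration, not defined in this module):
#
#     dataset = ImageFolderCustom(
#         targ_dir="data/food_images",            # hypothetical root folder
#         heads=["pizza", "steak", "sushi"],      # hypothetical class folders
#         transform=transforms.ToTensor(),
#         is_training=True,
#     )
#     img, label = dataset[0]                     # img is a tensor, label is an int index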

def create_dataloaders(
    image_dir: str,
    heads: List[str],
    train_transform: transforms.Compose,
    test_transform: transforms.Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS,
):
    """

    creates training and testing DataLoaders. 



    Takes in a training directory and testing directory path and turns them

    into pytorch datasets and then into pytorch dataloaders. 



    Args:

        train_dir: path to training directory. 

        test_dir: path to testing directory 

        transform: torchvision transforms to perform on training and testing data. 

        batch_size: number of samples per batch in each of the dataloaders. 

        num_workers: an integer for number of workers per dataloader.



    returns: 

        A tuple of (train_dataloader, test_dataloader, class_names).

        where class_names is a list of the target classes. 



        Example usage: 

            train_dataloader, test_dataloader, class_names = create_dataloaders(train_dir=path/to/train_dir, 

                                                                                test_dir=path/to/test_dir, 

                                                                                transform=some_transform,

                                                                                batch_size=32,

                                                                                num_workers=4)                                                                                                                 

    """

    # use ImageFolderCustom to create datasets
    train_data = ImageFolderCustom(targ_dir=image_dir, heads=heads, transform=train_transform, is_training=True)

    test_data = ImageFolderCustom(targ_dir=image_dir, heads=heads, transform=test_transform, is_training=False)

    # get class names 
    class_names = train_data.classes 

    # turn images into dataloaders 
    train_dataloader = DataLoader(
        train_data, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=num_workers, 
        pin_memory=True
    )
    test_dataloader = DataLoader(
        test_data, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=num_workers, 
        pin_memory=True
    )

    return train_dataloader, test_dataloader, class_names
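
# Minimal usage sketch. The image folder, the "pizza"/"steak"/"sushi" class names
# and the transform below are assumptions for illustration only, not part of this
# module; adjust them to your own data before running.
if __name__ == "__main__":
    example_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    train_dl, test_dl, class_names = create_dataloaders(
        image_dir="data/food_images",           # hypothetical path
        heads=["pizza", "steak", "sushi"],      # hypothetical class folders
        train_transform=example_transform,
        test_transform=example_transform,
        batch_size=32,
        num_workers=0,                          # keep 0 for a quick local sanity check
    )
    print(f"Classes: {class_names}")
    images, labels = next(iter(train_dl))
    print(f"Batch shape: {images.shape}, labels shape: {labels.shape}")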