|
|
|
from torch.utils.data import DataLoader |
|
from torchvision.io import read_image |
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import v2 |
|
from torchvision import transforms |
|
from torchvision import datasets |
|
from PIL import Image |
|
import pandas as pd |
|
import idx2numpy, os |
|
import torch |
|
|
|
|
|
|
|
# Side length, in pixels, of the square that every image is resized to.
IMAGE_DIMS = 224

# Preprocessing + augmentation pipeline applied to each decoded image tensor.
# Order matters: resize -> float conversion -> random rotation -> normalise.
normal_transforms = v2.Compose([
    v2.Resize(size=(IMAGE_DIMS, IMAGE_DIMS)),
    # scale=True rescales uint8 pixels from [0, 255] to float32 [0, 1].
    # Without it the values are only cast (still 0..255), and the
    # normalisation below -- whose mean/std are fractions of the unit
    # range -- would produce wildly out-of-scale inputs.
    v2.ToDtype(torch.float32, scale=True),
    # Light augmentation: rotate by a random angle in [-15, 15] degrees.
    v2.RandomRotation(degrees=(-15, 15)),
    # Single-element mean/std broadcast across all channels. Uses the v2
    # Normalize so the whole pipeline stays within one transforms API.
    v2.Normalize(mean=(0.13066047430038452,), std=(0.30810782313346863,)),
])
|
|
|
|
|
|
|
|
|
class CustomImageDataset(Dataset):
    """Map-style dataset of image files listed in a CSV annotations file.

    Inherits from ``torch.utils.data.Dataset`` and provides the required
    ``__init__``, ``__len__`` and ``__getitem__`` methods.

    The annotations file is loaded with ``pandas.read_csv``. Column 0 of each
    row is used as the image file name (relative to ``img_dir``) and column 1
    as the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Load the annotations table and remember where the images live.

        Args:
            annotations_file: Path to the CSV file with file names and labels.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to the image tensor.
            target_transform: Optional callable applied to the label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``.

        The image is decoded and converted to RGB in memory, yielding a
        ``uint8`` tensor in channels-first layout before ``transform`` is
        applied. The file on disk is never modified.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])

        # Convert to RGB in memory instead of re-saving the file on every
        # access: rewriting the source image is destructive (lossy formats
        # degrade with each save), slow, and unsafe when several DataLoader
        # workers touch the same path. The context manager closes the file
        # handle deterministically.
        with Image.open(img_path) as pil_image:
            image = v2.functional.pil_to_tensor(pil_image.convert("RGB"))

        label = self.img_labels.iloc[idx, 1]

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
|
|
|
|
|
# Full labelled dataset: file names and labels come from labels.csv, images
# are looked up under the train/ directory.
# NOTE(review): both splits created below share this dataset object and
# therefore the same `normal_transforms` pipeline -- if that pipeline contains
# random augmentation it will also be applied to the test split. Confirm this
# is intended.
train_data = CustomImageDataset(
    annotations_file="./dataset/root/labels.csv",
    img_dir="./dataset/root/train/",
    transform=normal_transforms,
)

# 80/20 split sizes; the test share takes whatever is left after truncation
# so the two lengths always sum to the dataset size.
train_size = int(0.8 * len(train_data))
test_size = len(train_data) - train_size

# Randomly partition the samples into two non-overlapping subsets.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset=train_data,
    lengths=[train_size, test_size],
)

# Mini-batch iterators: training batches are reshuffled every epoch, test
# batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

print("Data loader and Test Loaders are ready to be used.")
|
|
|
|
|
|
|
|
|
|