Resnet50-cats_vs_dogs / load_datasets.py
TheEeeeLin's picture
Upload 25 files
b83973e
raw
history blame contribute delete
No virus
1.34 kB
import csv
import os
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
class DatasetLoader(Dataset):
    """Map-style dataset backed by a CSV file of ``image_path,label`` rows.

    Every CSV row must hold exactly two columns: an image path, resolved
    relative to the ``datasets`` directory located next to this module, and a
    label that can be converted with ``int()``.
    """

    # Transform pipeline, built lazily on first use and then reused so it is
    # not reconstructed for every sample.
    _transform = None

    def __init__(self, csv_path):
        """Read all rows of the CSV file into memory.

        Args:
            csv_path: Path of the CSV file listing ``image_path,label`` rows.
        """
        self.csv_file = csv_path
        # newline='' lets the csv module handle line endings itself, as the
        # csv documentation recommends.
        with open(self.csv_file, 'r', newline='') as file:
            # csv.reader yields [] for blank lines; keeping them would inflate
            # __len__ and make __getitem__ fail when unpacking the row.
            self.data = [row for row in csv.reader(file) if row]
        self.current_dir = os.path.dirname(os.path.abspath(__file__))

    def preprocess_image(self, image_path):
        """Load one image from disk and return it as a normalized tensor.

        Args:
            image_path: Image path relative to the ``datasets`` directory
                next to this module.

        Returns:
            The image resized to 256x256, converted to a tensor and
            normalized per channel.
        """
        full_path = os.path.join(self.current_dir, 'datasets', image_path)

        if self._transform is None:
            self._transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                # Per-channel mean/std: the standard ImageNet statistics.
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

        # The context manager closes the underlying file promptly instead of
        # leaving it to garbage collection. convert('RGB') loads the pixel
        # data and guarantees three channels, so grayscale, palette or RGBA
        # files do not break the 3-channel Normalize step.
        with Image.open(full_path) as image:
            rgb_image = image.convert('RGB')

        return self._transform(rgb_image)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the row at ``index``.

        The label is converted to ``int`` before being returned.
        """
        image_path, label = self.data[index]
        image = self.preprocess_image(image_path)
        return image, int(label)

    def __len__(self):
        """Return the number of non-blank rows read from the CSV file."""
        return len(self.data)