GabrielML's picture
Init app
541501b
raw
history blame
1.47 kB
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from PIL import Image
import torch
class AnimalDataset(Dataset):
    """Map-style dataset that eagerly caches every image of a DataFrame in memory.

    Each image is opened once in ``__init__``, converted to RGB and resized to
    ``image_size``. ``__getitem__`` only applies the optional ``transform``, so
    no disk I/O happens after construction. Memory use therefore grows with
    the number of rows in ``df``.

    Args:
        df: DataFrame-like object with ``"path"``, ``"target"`` and
            ``"encoded_target"`` columns (each read via ``.values``).
        transform: Optional callable applied to the cached PIL image on every
            ``__getitem__`` call.
        image_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``
            when images are cached. Defaults to ``(224, 224)``, the size the
            original implementation hard-coded.
    """

    def __init__(self, df, transform=None, image_size=(224, 224)):
        self.paths = df["path"].values
        self.targets = df["target"].values
        self.encoded_target = df["encoded_target"].values
        self.transform = transform
        self.image_size = tuple(image_size)
        # Decode everything up front (with a tqdm progress bar) so that
        # __getitem__ is a pure in-memory lookup.
        self.images = [
            self._load_image(path, self.image_size) for path in tqdm(self.paths)
        ]

    @staticmethod
    def _load_image(path, size):
        """Open *path*, convert it to RGB and resize it to *size*.

        The ``with`` block guarantees the file handle is released. Pillow only
        closes it automatically after ``load()`` for single-frame images, so
        multi-frame files (GIF, TIFF, ...) would otherwise leak descriptors.
        ``convert`` and ``resize`` return new, fully loaded images, so the
        result stays valid after the source file is closed.
        """
        with Image.open(path) as img:
            return img.convert("RGB").resize(size)

    def __len__(self):
        """Return the number of samples (rows of the source DataFrame)."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, encoded_target, target)`` for sample *idx*.

        ``image`` is the cached PIL image, passed through ``transform`` when one
        was supplied. ``encoded_target`` is the encoded label as a CPU
        ``torch.long`` tensor. ``target`` is the raw label from the ``"target"``
        column.
        """
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)
        # .type(torch.LongTensor) forces int64 on the CPU regardless of the
        # dtype stored in the DataFrame column.
        encoded_target = torch.tensor(self.encoded_target[idx]).type(torch.LongTensor)
        return img, encoded_target, self.targets[idx]
# Per-channel (R, G, B) normalization statistics shared by both pipelines, so
# they cannot drift apart. These are the ImageNet mean/std values documented
# for torchvision's pretrained models. NOTE(review): presumably chosen because
# the model uses a pretrained backbone; confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation (random horizontal flip and a
# random rotation in [-10, +10] degrees), then tensor conversion and
# normalization. The Resize is a no-op for images AnimalDataset already cached
# at 224x224, but it keeps the pipeline usable on arbitrary PIL images.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic pipeline (no augmentation). Presumably used for validation and
# inference; verify against callers.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # Convert PIL images to float tensors in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])