# mimosa-ai/dataset/loader.py
import os

import numpy as np
import torch
from albumentations import (
    CLAHE,
    Compose,
    Emboss,
    GaussNoise,
    HorizontalFlip,
    HueSaturationValue,
    OneOf,
    RandomBrightnessContrast,
    RandomRotate90,
    Sharpen,
    ShiftScaleRotate,
    Transpose,
    VerticalFlip,
)
from PIL import Image
from torchvision import transforms, datasets

# Keep hub and Hugging Face downloads inside the project directory.
torch.hub.set_dir("./cache")
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"


def strong_aug(p=0.5):
    """Build an Albumentations pipeline of strong augmentations.

    Each sub-transform fires with its own probability; the pipeline as a
    whole is applied with probability ``p``.
    """
    return Compose(
        [
            RandomRotate90(p=0.2),
            Transpose(p=0.2),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            # A single-member OneOf simply gates GaussNoise at p=0.2.
            OneOf([GaussNoise()], p=0.2),
            ShiftScaleRotate(p=0.2),
            # Apply exactly one of these contrast/texture transforms 20% of the time.
            OneOf(
                [
                    CLAHE(clip_limit=2),
                    Sharpen(),
                    Emboss(),
                    RandomBrightnessContrast(),
                ],
                p=0.2,
            ),
            HueSaturationValue(p=0.2),
        ],
        p=p,
    )
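
# Illustrative usage (the dummy array below is a hypothetical example, not
# part of this module): the pipeline takes and returns HWC uint8 numpy arrays.
#   aug = strong_aug(p=0.9)
#   out = aug(image=np.zeros((224, 224, 3), dtype=np.uint8))["image"]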


def augment(aug, image):
    """Apply an Albumentations pipeline to a numpy image."""
    return aug(image=image)["image"]


class Aug(object):
    """torchvision-style callable wrapping the strong_aug pipeline."""

    def __init__(self):
        self.aug = strong_aug(p=0.9)  # build the pipeline once, not per call

    def __call__(self, img):
        return Image.fromarray(augment(self.aug, np.array(img)))


def normalize_data():
    """Return per-split torchvision transforms (ImageNet mean/std)."""
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    return {
        "train": transforms.Compose(
            [Aug(), transforms.ToTensor(), transforms.Normalize(mean, std)]
        ),
        "valid": transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        ),
        "test": transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        ),
        # "vid" skips ToTensor, so it expects an already-tensorized frame.
        "vid": transforms.Compose([transforms.Normalize(mean, std)]),
    }
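
# Illustrative usage (the file name is hypothetical): the "train" transform
# maps a PIL image to an augmented, normalized 3xHxW float tensor.
#   tensor = normalize_data()["train"](Image.open("img.jpg").convert("RGB"))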


def load_data(data_dir="sample/", batch_size=4):
    """Build ImageFolder datasets and dataloaders for the train/valid/test splits."""
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), normalize_data()[x])
        for x in ["train", "valid", "test"]
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "valid", "test"]}
    # Only the training loader shuffles; evaluation order stays fixed.
    train_dataloaders = torch.utils.data.DataLoader(
        image_datasets["train"],
        batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    validation_dataloaders = torch.utils.data.DataLoader(
        image_datasets["valid"],
        batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    test_dataloaders = torch.utils.data.DataLoader(
        image_datasets["test"],
        batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    # Note: the validation loader is keyed "validation" here even though the
    # split directory is named "valid".
    dataloaders = {
        "train": train_dataloaders,
        "validation": validation_dataloaders,
        "test": test_dataloaders,
    }
    return dataloaders, dataset_sizes


def load_checkpoint(model, optimizer, filename=None):
    """Restore model/optimizer state from a checkpoint file, if it exists.

    Returns the (possibly updated) model, optimizer, starting epoch, and the
    logged minimum loss.
    """
    start_epoch = 0
    log_loss = 0
    if filename is not None and os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["state_dict"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        log_loss = checkpoint["min_loss"]
        print(
            "=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint["epoch"])
        )
    else:
        print("=> no checkpoint found at '{}'".format(filename))
    return model, optimizer, start_epoch, log_loss
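

# Minimal smoke test, assuming an ImageFolder layout under "sample/" with
# train/valid/test subdirectories (class-named subfolders of images). The
# directory name and batch size mirror the defaults above.
if __name__ == "__main__":
    loaders, sizes = load_data(data_dir="sample/", batch_size=4)
    print("dataset sizes:", sizes)
    # Pull one training batch to verify the full augment/normalize pipeline.
    images, labels = next(iter(loaders["train"]))
    print("batch shape:", tuple(images.shape), "labels:", labels.tolist())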