|
import os |
|
import torch |
|
from torchvision import transforms, datasets |
|
from albumentations import ( |
|
HorizontalFlip, |
|
VerticalFlip, |
|
ShiftScaleRotate, |
|
CLAHE, |
|
RandomRotate90, |
|
Transpose, |
|
ShiftScaleRotate, |
|
HueSaturationValue, |
|
GaussNoise, |
|
Sharpen, |
|
Emboss, |
|
RandomBrightnessContrast, |
|
OneOf, |
|
Compose, |
|
) |
|
import numpy as np |
|
from PIL import Image |
|
|
|
torch.hub.set_dir('./cache') |
|
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache" |
|
|
|
def strong_aug(p=0.5): |
|
return Compose( |
|
[ |
|
RandomRotate90(p=0.2), |
|
Transpose(p=0.2), |
|
HorizontalFlip(p=0.5), |
|
VerticalFlip(p=0.5), |
|
OneOf( |
|
[ |
|
GaussNoise(), |
|
], |
|
p=0.2, |
|
), |
|
ShiftScaleRotate(p=0.2), |
|
OneOf( |
|
[ |
|
CLAHE(clip_limit=2), |
|
Sharpen(), |
|
Emboss(), |
|
RandomBrightnessContrast(), |
|
], |
|
p=0.2, |
|
), |
|
HueSaturationValue(p=0.2), |
|
], |
|
p=p, |
|
) |
|
|
|
|
|
def augment(aug, image): |
|
return aug(image=image)["image"] |
|
|
|
|
|
class Aug(object): |
|
def __call__(self, img): |
|
aug = strong_aug(p=0.9) |
|
return Image.fromarray(augment(aug, np.array(img))) |
|
|
|
|
|
def normalize_data(): |
|
mean = [0.485, 0.456, 0.406] |
|
std = [0.229, 0.224, 0.225] |
|
|
|
return { |
|
"train": transforms.Compose( |
|
[Aug(), transforms.ToTensor(), transforms.Normalize(mean, std)] |
|
), |
|
"valid": transforms.Compose( |
|
[transforms.ToTensor(), transforms.Normalize(mean, std)] |
|
), |
|
"test": transforms.Compose( |
|
[transforms.ToTensor(), transforms.Normalize(mean, std)] |
|
), |
|
"vid": transforms.Compose([transforms.Normalize(mean, std)]), |
|
} |
|
|
|
|
|
def load_data(data_dir="sample/", batch_size=4): |
|
data_dir = data_dir |
|
image_datasets = { |
|
x: datasets.ImageFolder(os.path.join(data_dir, x), normalize_data()[x]) |
|
for x in ["train", "valid", "test"] |
|
} |
|
|
|
|
|
|
|
|
|
|
|
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "valid", "test"]} |
|
|
|
train_dataloaders = torch.utils.data.DataLoader( |
|
image_datasets["train"], |
|
batch_size, |
|
shuffle=True, |
|
num_workers=0, |
|
pin_memory=True, |
|
) |
|
validation_dataloaders = torch.utils.data.DataLoader( |
|
image_datasets["valid"], |
|
batch_size, |
|
shuffle=False, |
|
num_workers=0, |
|
pin_memory=True, |
|
) |
|
test_dataloaders = torch.utils.data.DataLoader( |
|
image_datasets["test"], |
|
batch_size, |
|
shuffle=False, |
|
num_workers=0, |
|
pin_memory=True, |
|
) |
|
|
|
dataloaders = { |
|
"train": train_dataloaders, |
|
"validation": validation_dataloaders, |
|
"test": test_dataloaders, |
|
} |
|
|
|
return dataloaders, dataset_sizes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|