# vatrpp/models/util/augmentations.py
import math
import os.path
import pickle
import random
from abc import ABC, abstractmethod

import cv2
import numpy as np
import torch
import torchvision.transforms
import torchvision.transforms.functional as F
from matplotlib import pyplot as plt

# from data.dataset import CollectionTextDataset, TextDataset

def to_opencv(batch: torch.Tensor):
    """Convert a batch of tensors in [-1, 1] to a list of numpy images in [0, 1]."""
    images = []
    for image in batch:
        image = image.detach().cpu().numpy()
        image = (image + 1.0) / 2.0
        images.append(np.squeeze(image))
    return images

class RandomMorphological(torch.nn.Module):
    """Random erosion or dilation, thinning or thickening strokes in [-1, 1] images."""

    def __init__(self, max_size: int = 5, max_iterations: int = 1, operation=cv2.MORPH_ERODE):
        super().__init__()
        self.elements = [cv2.MORPH_RECT, cv2.MORPH_ELLIPSE]
        self.max_size = max_size
        self.max_iterations = max_iterations
        self.operation = operation

    def forward(self, x):
        device = x.device
        images = to_opencv(x)
        result = []
        size = random.randint(1, self.max_size)
        kernel = cv2.getStructuringElement(random.choice(self.elements), (size, size))
        for image in images:
            # Work at 2x resolution so the kernel acts more gently, then scale back down.
            image = cv2.resize(image, (image.shape[1] * 2, image.shape[0] * 2))
            morphed = cv2.morphologyEx(image, op=self.operation, kernel=kernel, iterations=random.randint(1, self.max_iterations))
            morphed = cv2.resize(morphed, (image.shape[1] // 2, image.shape[0] // 2))
            morphed = morphed * 2.0 - 1.0  # back to [-1, 1]
            result.append(torch.Tensor(morphed))
        return torch.unsqueeze(torch.stack(result).to(device), dim=1)
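
# A minimal usage sketch (not part of the original module): runs the erosion variant
# on a dummy batch in [-1, 1]. The batch shape and parameters are illustrative
# assumptions, not values used in training.
def demo_random_morphological():
    erode = RandomMorphological(max_size=4, max_iterations=2, operation=cv2.MORPH_ERODE)
    batch = torch.rand(2, 1, 32, 128) * 2.0 - 1.0  # two one-channel images in [-1, 1]
    out = erode(batch)
    print(out.shape)  # torch.Size([2, 1, 32, 128]); the 2x resize round-trip preserves size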

def gauss_noise_tensor(img):
    """Add Gaussian noise to a tensor image; see https://github.com/pytorch/vision/issues/6192."""
    assert isinstance(img, torch.Tensor)
    dtype = img.dtype
    if not img.is_floating_point():
        img = img.to(torch.float32)
    sigma = 0.075
    # Note: subtracting 0.5 from the zero-mean noise also darkens the image slightly.
    out = img + sigma * (torch.randn_like(img) - 0.5)
    out = torch.clamp(out, -1.0, 1.0)
    if out.dtype != dtype:
        out = out.to(dtype)
    return out

def compute_word_width(image: torch.Tensor) -> int:
    """Return the column index of the rightmost dark ("ink") pixel in a (1, H, W) image."""
    indices = torch.where((image < 0).int())[2]
    index = torch.max(indices) if len(indices) > 0 else image.size(-1)
    return int(index)

class Downsize(torch.nn.Module):
    """Randomly shrink the content (scale 0.8-1.0) and lightly blur it."""

    def __init__(self):
        super().__init__()
        self.aug = torchvision.transforms.Compose([
            torchvision.transforms.RandomAffine(0.0, scale=(0.8, 1.0), interpolation=torchvision.transforms.InterpolationMode.NEAREST, fill=1.0),
            torchvision.transforms.GaussianBlur(3, sigma=0.3)
        ])

    def forward(self, x):
        return self.aug(x)

class OCRAugment(torch.nn.Module):
    """With probability `prob`, apply `no` randomly chosen augmentations from the pool."""

    def __init__(self, prob: float = 0.5, no: int = 2):
        super().__init__()
        self.prob = prob
        self.no = no
        interp = torchvision.transforms.InterpolationMode.NEAREST
        fill = 1.0
        self.augmentations = [
            torchvision.transforms.RandomRotation(3.0, interpolation=interp, fill=fill),
            torchvision.transforms.RandomAffine(0.0, translate=(0.05, 0.05), interpolation=interp, fill=fill),
            Downsize(),
            torchvision.transforms.ElasticTransform(alpha=10.0, sigma=7.0, fill=fill, interpolation=interp),
            torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5),
            torchvision.transforms.GaussianBlur(3, sigma=(0.1, 1.0)),
            gauss_noise_tensor,
            RandomMorphological(max_size=4, max_iterations=2, operation=cv2.MORPH_ERODE),
            RandomMorphological(max_size=2, max_iterations=1, operation=cv2.MORPH_DILATE)
        ]

    def forward(self, x):
        if random.uniform(0.0, 1.0) > self.prob:
            return x
        # random.choices samples with replacement, so the same augmentation can repeat.
        augmentations = random.choices(self.augmentations, k=self.no)
        for augmentation in augmentations:
            x = augmentation(x)
        return x
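
# A minimal usage sketch (not part of the original module): applies three random
# augmentations to a synthetic batch. The canvas size and the dark "word" stroke
# are illustrative assumptions.
def demo_ocr_augment():
    augmenter = OCRAugment(prob=1.0, no=3)   # prob=1.0 so the pipeline always fires
    batch = torch.ones(4, 1, 32, 128)        # white canvases in [-1, 1]
    batch[:, :, 12:20, 16:112] = -1.0        # a dark horizontal stroke standing in for text
    augmented = augmenter(batch)
    print(augmented.shape)  # torch.Size([4, 1, 32, 128]); every pool member preserves shape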

class WordCrop(torch.nn.Module, ABC):
    """Crop a window of the current width, positioned randomly around the written word."""

    def __init__(self, use_padding: bool = False):
        super().__init__()
        self.use_padding = use_padding
        self.pad = torchvision.transforms.Pad([2, 2, 2, 2], 1.0)

    @abstractmethod
    def get_current_width(self):
        pass

    @abstractmethod
    def update(self, epoch: int):
        pass

    def forward(self, images):
        assert len(images.size()) == 4 and images.size(1) == 1, "Augmentation works on batches of one-channel images"
        if self.use_padding:
            images = self.pad(images)
        results = []
        width = self.get_current_width()
        for image in images:
            # Keep the crop window inside the image while biasing it toward the ink.
            index = compute_word_width(image)
            max_index = max(min(index - width // 2, image.size(2) - width), 0)
            start_index = random.randint(0, max_index)
            results.append(F.crop(image, 0, start_index, image.size(1), min(width, image.size(2))))
        return torch.stack(results)

class StaticWordCrop(WordCrop):
    def __init__(self, width: int, use_padding: bool = False):
        super().__init__(use_padding=use_padding)
        self.width = width

    def get_current_width(self):
        return int(self.width)

    def update(self, epoch: int):
        pass


class RandomWordCrop(WordCrop):
    def __init__(self, min_width: int, max_width: int, use_padding: bool = False):
        super().__init__(use_padding)
        self.min_width = min_width
        self.max_width = max_width
        self.current_width = random.randint(self.min_width, self.max_width)

    def update(self, epoch: int):
        self.current_width = random.randint(self.min_width, self.max_width)

    def get_current_width(self):
        return self.current_width
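
# A minimal usage sketch (not part of the original module): crops a fixed-width
# window around a synthetic dark word. The image content is an illustrative assumption.
def demo_word_crop():
    images = torch.ones(1, 1, 32, 256)   # white background in [-1, 1]
    images[:, :, 8:24, 0:100] = -1.0     # "ink" occupies the first 100 columns
    crop = StaticWordCrop(width=64)
    print(crop(images).shape)            # torch.Size([1, 1, 32, 64])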

class FullCrop(torch.nn.Module):
    """Crop a width x 32 window at a random position, padded and biased toward the ink."""

    def __init__(self, width: int):
        super().__init__()
        self.width = width
        self.height = 32
        self.pad = torchvision.transforms.Pad([6, 6, 6, 6], 1.0)

    def get_current_width(self):
        return self.width

    def forward(self, images):
        assert len(images.size()) == 4 and images.size(1) == 1, "Augmentation works on batches of one-channel images"
        images = self.pad(images)
        results = []
        for image in images:
            index = compute_word_width(image)
            max_index = max(min(index - self.width // 2, image.size(2) - self.width), 0)
            start_width = random.randint(0, max_index)
            start_height = random.randint(0, image.size(1) - self.height)
            results.append(F.crop(image, start_height, start_width, self.height, min(self.width, image.size(2))))
        return torch.stack(results)

class ProgressiveWordCrop(WordCrop):
    """Shrink the crop width linearly from start_width to width over warmup_epochs."""

    def __init__(self, width: int, warmup_epochs: int, start_width: int = 128, use_padding: bool = False):
        super().__init__(use_padding=use_padding)
        self.target_width = width
        self.warmup_epochs = warmup_epochs
        self.start_width = start_width
        self.current_width = float(start_width)

    def update(self, epoch: int):
        value = self.start_width - ((self.start_width - self.target_width) / self.warmup_epochs) * epoch
        self.current_width = max(value, self.target_width)

    def get_current_width(self):
        return int(round(self.current_width))


class CycleWordCrop(WordCrop):
    """Oscillate the crop width between start_width and width on a cosine with period cycle_epochs."""

    def __init__(self, width: int, cycle_epochs: int, start_width: int = 128, use_padding: bool = False):
        super().__init__(use_padding=use_padding)
        self.target_width = width
        self.start_width = start_width
        self.current_width = float(start_width)
        self.cycle_epochs = float(cycle_epochs)

    def update(self, epoch: int):
        value = (math.cos((float(epoch) * 2 * math.pi) / self.cycle_epochs) + 1) * ((self.start_width - self.target_width) / 2) + self.target_width
        self.current_width = value

    def get_current_width(self):
        return int(round(self.current_width))
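
# A minimal usage sketch (not part of the original module): prints how the two
# schedules move the crop width across epochs. The widths and epoch counts are
# illustrative assumptions.
def demo_crop_schedules():
    progressive = ProgressiveWordCrop(width=64, warmup_epochs=4, start_width=128)
    cyclic = CycleWordCrop(width=64, cycle_epochs=8, start_width=128)
    for epoch in range(9):
        progressive.update(epoch)
        cyclic.update(epoch)
        print(epoch, progressive.get_current_width(), cyclic.get_current_width())
    # progressive: 128, 112, 96, 80, 64, then clamped at 64
    # cyclic: cosine between 128 and 64 with a period of 8 epochs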

class HeightResize(torch.nn.Module):
    """Resize to a fixed height while preserving the aspect ratio."""

    def __init__(self, target_height: int):
        super().__init__()
        self.target_height = target_height

    def forward(self, x):
        width, height = F.get_image_size(x)
        scale = self.target_height / height
        return F.resize(x, [int(height * scale), int(width * scale)])
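
# A minimal usage sketch (not part of the original module): rescales a taller image
# to the 32-pixel line height used elsewhere in this module. The input shape is an
# illustrative assumption.
def demo_height_resize():
    resize = HeightResize(32)
    image = torch.ones(1, 1, 64, 256)
    print(resize(image).shape)  # torch.Size([1, 1, 32, 128]); aspect ratio preserved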

def show_crops():
    with open("../files/IAM-32-pa.pickle", 'rb') as f:
        data = pickle.load(f)
    for author in data['train'].keys():
        for image in data['train'][author]:
            image = torch.Tensor(np.expand_dims(np.expand_dims(np.array(image['img']), 0), 0))
            augmenter = torchvision.transforms.Compose([
                HeightResize(32),
                FullCrop(128)
            ])
            batch = augmenter(image)
            batch = batch.detach().cpu().numpy()
            result = [np.squeeze(im) for im in batch]
            # squeeze=False keeps axes as a 2D grid even when there is a single crop
            fig, axes = plt.subplots(1, len(result), squeeze=False)
            for i in range(len(result)):
                axes[0][i].imshow(result[i])
            plt.show()

if __name__ == "__main__":
    # Imported lazily here because the module-level import is commented out above.
    from data.dataset import CollectionTextDataset, TextDataset

    dataset = CollectionTextDataset(
        'IAM', '../files', TextDataset, file_suffix='pa', num_examples=15,
        collator_resolution=16, min_virtual_size=339, validation=False, debug=False
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        pin_memory=True, drop_last=True,
        collate_fn=dataset.collate_fn)
    augmenter = OCRAugment(no=3, prob=1.0)
    target_folder = r"C:\Users\bramv\Documents\Werk\Research\Unimore\VATr\VATr_ext\saved_images\debug\ocr_aug"
    image_no = 0
    for batch in train_loader:
        # Save five augmented versions of the first image in the batch ...
        for i in range(5):
            augmented = augmenter(batch["img"])
            img = np.squeeze((augmented[0].detach().cpu().numpy() + 1.0) / 2.0)
            img = (img * 255.0).astype(np.uint8)
            print(cv2.imwrite(os.path.join(target_folder, f"{image_no}_{i}.png"), img))
        # ... then the unaugmented original for comparison.
        img = np.squeeze((batch["img"][0].detach().cpu().numpy() + 1.0) / 2.0)
        img = (img * 255.0).astype(np.uint8)
        cv2.imwrite(os.path.join(target_folder, f"{image_no}.png"), img)
        if image_no > 5:
            break
        image_no += 1