mfrashad's picture
Init code
8f87579
import math
import torch
from torch.nn import functional as F
def translate_mat(t_x, t_y):
    """Build a batch of 3x3 homogeneous translation matrices.

    Args:
        t_x: Tensor of x offsets; its first dimension is the batch size.
        t_y: Tensor of y offsets, stacked alongside ``t_x``.

    Returns:
        A ``(batch, 3, 3)`` tensor (torch's default dtype/device, as produced
        by ``torch.eye``) equal to the identity except that the first two
        rows of the last column hold ``(t_x, t_y)``.
    """
    n = t_x.shape[0]
    out = torch.eye(3).repeat(n, 1, 1)
    # Pair the offsets per sample, then drop them into the last column.
    offsets = torch.stack((t_x, t_y), dim=1)
    out[:, :2, 2] = offsets
    return out
def rotate_mat(theta):
    """Build a batch of 3x3 homogeneous rotation matrices.

    Args:
        theta: Tensor of angles in radians; its first dimension is the batch
            size.

    Returns:
        A ``(batch, 3, 3)`` tensor whose upper-left 2x2 block is
        ``[[cos, -sin], [sin, cos]]`` and whose remaining entries are those
        of the identity matrix.
    """
    n = theta.shape[0]
    sin_t, cos_t = torch.sin(theta), torch.cos(theta)
    # Row-major layout of the 2x2 rotation block for every sample.
    block = torch.stack((cos_t, -sin_t, sin_t, cos_t), dim=1).view(n, 2, 2)
    out = torch.eye(3).repeat(n, 1, 1)
    out[:, :2, :2] = block
    return out
def scale_mat(s_x, s_y):
    """Build a batch of 3x3 homogeneous scaling matrices.

    Args:
        s_x: Tensor of x scale factors; its first dimension is the batch size.
        s_y: Tensor of y scale factors.

    Returns:
        A ``(batch, 3, 3)`` tensor equal to the identity except that entries
        ``[0, 0]`` and ``[1, 1]`` of each matrix hold ``s_x`` and ``s_y``.
    """
    n = s_x.shape[0]
    out = torch.eye(3).repeat(n, 1, 1)
    # The two diagonal writes are independent of each other.
    out[:, 1, 1] = s_y
    out[:, 0, 0] = s_x
    return out
def lognormal_sample(size, mean=0, std=1):
    """Draw log-normally distributed samples.

    Args:
        size: Shape handed to ``torch.empty`` (an int or a tuple of ints).
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.

    Returns:
        A freshly allocated tensor of shape ``size`` filled in place by
        ``Tensor.log_normal_``.
    """
    out = torch.empty(size)
    out.log_normal_(mean=mean, std=std)
    return out
def category_sample(size, categories):
    """Pick ``size`` values uniformly at random from ``categories``.

    Args:
        size: Number of samples to draw.
        categories: Sequence of candidate values; converted with
            ``torch.tensor`` so the result has that tensor's dtype.

    Returns:
        A tensor with ``size`` entries, each one an element of ``categories``.
    """
    lookup = torch.tensor(categories)
    # Uniform integer indices in [0, len(categories)).
    picks = torch.randint(size=(size,), high=len(categories))
    return lookup[picks]
def uniform_sample(size, low, high):
    """Draw samples from a uniform distribution between ``low`` and ``high``.

    Args:
        size: Shape handed to ``torch.empty`` (an int or a tuple of ints).
        low: Lower bound passed to ``Tensor.uniform_``.
        high: Upper bound passed to ``Tensor.uniform_``.

    Returns:
        A freshly allocated tensor of shape ``size`` filled in place.
    """
    out = torch.empty(size)
    out.uniform_(low, high)
    return out
def normal_sample(size, mean=0, std=1):
    """Draw normally distributed samples.

    Args:
        size: Shape handed to ``torch.empty`` (an int or a tuple of ints).
        mean: Mean of the distribution.
        std: Standard deviation of the distribution.

    Returns:
        A freshly allocated tensor of shape ``size`` filled in place by
        ``Tensor.normal_``.
    """
    out = torch.empty(size)
    out.normal_(mean, std)
    return out
def bernoulli_sample(size, p):
    """Draw 0/1 samples where each entry is 1 with probability ``p``.

    Args:
        size: Shape handed to ``torch.empty`` (an int or a tuple of ints).
        p: Success probability passed to ``Tensor.bernoulli_``.

    Returns:
        A freshly allocated floating-point tensor of shape ``size`` whose
        entries are 0 or 1.
    """
    out = torch.empty(size)
    out.bernoulli_(p)
    return out
def random_affine_apply(p, transform, prev, eye):
    """Left-multiply ``prev`` by ``transform`` for a random subset of the batch.

    Each batch element independently uses ``transform`` with probability
    ``p`` and ``eye`` otherwise.

    Args:
        p: Probability of applying ``transform`` to a given batch element.
        transform: Batch of candidate matrices; the batch size is read from
            its first dimension.
        prev: Batch of matrices accumulated so far.
        eye: Batch of matrices used where ``transform`` is not selected.

    Returns:
        ``chosen @ prev``, where ``chosen`` is ``transform`` or ``eye`` per
        batch element.
    """
    n = transform.shape[0]
    # One 0/1 draw per sample, shaped to broadcast over the matrix dims.
    mask = bernoulli_sample(n, p).view(n, 1, 1)
    chosen = mask * transform + (1 - mask) * eye
    return torch.matmul(chosen, prev)
def sample_affine(p, size, height, width):
    """Sample a batch of random 2-D affine augmentation matrices.

    Starting from the identity, the following stages are composed in order
    (each new stage is left-multiplied onto the running matrix): x-flip,
    rotation by a multiple of 90 degrees, integer translation, isotropic
    scale, pre-rotation, anisotropic scale, post-rotation and fractional
    translation. Every stage is applied independently per batch element.

    NOTE(review): the stage list and parameter ranges appear to follow the
    geometric augmentations of Karras et al. 2020 (ADA) -- confirm.

    Args:
        p: Probability of applying each stage to a batch element. The two
            arbitrary rotations each use ``1 - sqrt(1 - p)`` instead.
        size: Batch size (number of matrices to sample).
        height: Image height, used to round the integer y translation to a
            multiple of ``1 / height``.
        width: Image width, used to round the integer x translation to a
            multiple of ``1 / width``.

    Returns:
        A ``(size, 3, 3)`` tensor of homogeneous affine matrices.
    """
    G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)
    # ``eye`` is never mutated below; ``G`` is only ever rebound.
    eye = G

    # x-flip: scale x by +1 or -1.
    param = category_sample(size, (0, 1))
    Gc = scale_mat(1 - 2.0 * param, torch.ones(size))
    G = random_affine_apply(p, Gc, G, eye)

    # Rotation by a multiple of 90 degrees. All four multiples are sampled;
    # drawing from (0, 3) alone would never produce 90 or 180 degree turns.
    param = category_sample(size, (0, 1, 2, 3))
    Gc = rotate_mat(-math.pi / 2 * param)
    G = random_affine_apply(p, Gc, G, eye)

    # Integer translate: x and y offsets are drawn independently so shifts
    # are not confined to a single diagonal direction.
    param_y = uniform_sample(size, -0.125, 0.125)
    param_x = uniform_sample(size, -0.125, 0.125)
    param_height = torch.round(param_y * height) / height
    param_width = torch.round(param_x * width) / width
    Gc = translate_mat(param_width, param_height)
    G = random_affine_apply(p, Gc, G, eye)

    # Isotropic scale: same log-normal factor on both axes.
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, param)
    G = random_affine_apply(p, Gc, G, eye)

    # The arbitrary rotation is split into a pre- and a post-rotation; with
    # this per-stage probability, at least one of the two fires with
    # probability p, since 1 - (1 - p_rot) ** 2 == p.
    p_rot = 1 - math.sqrt(1 - p)

    # Pre-rotate.
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param)
    G = random_affine_apply(p_rot, Gc, G, eye)

    # Anisotropic scale: reciprocal factors on the two axes.
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, 1 / param)
    G = random_affine_apply(p, Gc, G, eye)

    # Post-rotate.
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param)
    G = random_affine_apply(p_rot, Gc, G, eye)

    # Fractional translate: independent normal offsets for x and y.
    param_x = normal_sample(size, std=0.125)
    param_y = normal_sample(size, std=0.125)
    Gc = translate_mat(param_x, param_y)
    G = random_affine_apply(p, Gc, G, eye)

    return G
def apply_affine(img, G):
    """Warp a batch of images with a batch of 3x3 affine matrices.

    ``G`` is inverted before use and only its top two rows are passed to
    ``F.affine_grid``. Sampling is bilinear with reflection padding and
    ``align_corners=False``.

    Args:
        img: Image batch; its shape is used as the output grid size.
        G: Batch of ``3x3`` matrices; the inverse is converted to the dtype
            and device of ``img`` via ``.to(img)``.

    Returns:
        The resampled images, with the same shape as ``img``.
    """
    theta = torch.inverse(G).to(img)[:, :2, :]
    sampling_grid = F.affine_grid(theta, img.shape, align_corners=False)
    return F.grid_sample(
        img,
        sampling_grid,
        mode="bilinear",
        padding_mode="reflection",
        align_corners=False,
    )