|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import scipy.signal |
|
import torch |
|
from torch_utils import persistence |
|
from torch_utils import misc |
|
from torch_utils.ops import upfirdn2d |
|
from torch_utils.ops import grid_sample_gradfix |
|
from torch_utils.ops import conv2d_gradfix |
|
|
|
|
|
|
|
|
|
# Orthogonal wavelet low-pass filter taps, keyed by wavelet name ("haar"/"dbN" are
# Daubechies wavelets, "symN" are symlets; each N-th order filter has 2*N taps).
# "haar" and "db1" share values, as do "sym2"/"db2" and "sym3"/"db3"; each entry is
# nevertheless its own list object.
wavelets = dict(
    haar=[0.7071067811865476, 0.7071067811865476],
    db1=[0.7071067811865476, 0.7071067811865476],
    db2=[
        -0.12940952255092145, 0.22414386804185735,
        0.836516303737469, 0.48296291314469025,
    ],
    db3=[
        0.035226291882100656, -0.08544127388224149,
        -0.13501102001039084, 0.4598775021193313,
        0.8068915093133388, 0.3326705529509569,
    ],
    db4=[
        -0.010597401784997278, 0.032883011666982945,
        0.030841381835986965, -0.18703481171888114,
        -0.02798376941698385, 0.6308807679295904,
        0.7148465705525415, 0.23037781330885523,
    ],
    db5=[
        0.003335725285001549, -0.012580751999015526,
        -0.006241490213011705, 0.07757149384006515,
        -0.03224486958502952, -0.24229488706619015,
        0.13842814590110342, 0.7243085284385744,
        0.6038292697974729, 0.160102397974125,
    ],
    db6=[
        -0.00107730108499558, 0.004777257511010651,
        0.0005538422009938016, -0.031582039318031156,
        0.02752286553001629, 0.09750160558707936,
        -0.12976686756709563, -0.22626469396516913,
        0.3152503517092432, 0.7511339080215775,
        0.4946238903983854, 0.11154074335008017,
    ],
    db7=[
        0.0003537138000010399, -0.0018016407039998328,
        0.00042957797300470274, 0.012550998556013784,
        -0.01657454163101562, -0.03802993693503463,
        0.0806126091510659, 0.07130921926705004,
        -0.22403618499416572, -0.14390600392910627,
        0.4697822874053586, 0.7291320908465551,
        0.39653931948230575, 0.07785205408506236,
    ],
    db8=[
        -0.00011747678400228192, 0.0006754494059985568,
        -0.0003917403729959771, -0.00487035299301066,
        0.008746094047015655, 0.013981027917015516,
        -0.04408825393106472, -0.01736930100202211,
        0.128747426620186, 0.00047248457399797254,
        -0.2840155429624281, -0.015829105256023893,
        0.5853546836548691, 0.6756307362980128,
        0.3128715909144659, 0.05441584224308161,
    ],
    sym2=[
        -0.12940952255092145, 0.22414386804185735,
        0.836516303737469, 0.48296291314469025,
    ],
    sym3=[
        0.035226291882100656, -0.08544127388224149,
        -0.13501102001039084, 0.4598775021193313,
        0.8068915093133388, 0.3326705529509569,
    ],
    sym4=[
        -0.07576571478927333, -0.02963552764599851,
        0.49761866763201545, 0.8037387518059161,
        0.29785779560527736, -0.09921954357684722,
        -0.012603967262037833, 0.0322231006040427,
    ],
    sym5=[
        0.027333068345077982, 0.029519490925774643,
        -0.039134249302383094, 0.1993975339773936,
        0.7234076904024206, 0.6339789634582119,
        0.01660210576452232, -0.17532808990845047,
        -0.021101834024758855, 0.019538882735286728,
    ],
    sym6=[
        0.015404109327027373, 0.0034907120842174702,
        -0.11799011114819057, -0.048311742585633,
        0.4910559419267466, 0.787641141030194,
        0.3379294217276218, -0.07263752278646252,
        -0.021060292512300564, 0.04472490177066578,
        0.0017677118642428036, -0.007800708325034148,
    ],
    sym7=[
        0.002681814568257878, -0.0010473848886829163,
        -0.01263630340325193, 0.03051551316596357,
        0.0678926935013727, -0.049552834937127255,
        0.017441255086855827, 0.5361019170917628,
        0.767764317003164, 0.2886296317515146,
        -0.14004724044296152, -0.10780823770381774,
        0.004010244871533663, 0.010268176708511255,
    ],
    sym8=[
        -0.0033824159510061256, -0.0005421323317911481,
        0.03169508781149298, 0.007607487324917605,
        -0.1432942383508097, -0.061273359067658524,
        0.4813596512583722, 0.7771857517005235,
        0.3644418948353314, -0.05194583810770904,
        -0.027219029917056003, 0.049137179673607506,
        0.003808752013890615, -0.01495225833704823,
        -0.0003029205147213668, 0.0018899503327594609,
    ],
)
|
|
|
|
|
|
|
|
|
|
|
def matrix(*rows, device=None):
    """Build a (possibly batched) matrix from rows of scalars and/or tensors.

    Args:
        *rows: Equal-length sequences; each entry is a Python number or a
            ``torch.Tensor``. Tensor entries may differ in shape as long as the
            shapes are broadcastable; they are broadcast to a common batch shape.
        device: Target device when no entry is a tensor. If tensors are present
            it must be None or equal to the first tensor's device.

    Returns:
        Without tensor entries: a constant of shape ``[len(rows), len(rows[0])]``
        built via ``misc.constant``. With tensor entries: a tensor of shape
        ``batch_shape + [len(rows), len(rows[0])]``, where ``batch_shape`` is
        the broadcast shape of all tensor entries.
    """
    assert all(len(row) == len(rows[0]) for row in rows)
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return misc.constant(np.asarray(rows), device=device)
    assert device is None or device == ref[0].device

    # Common batch shape of all tensor entries. This lets e.g. a 0-d tensor be
    # mixed with a per-sample [batch] tensor, which torch.stack alone rejects.
    shape = torch.broadcast_tensors(*ref)[0].shape
    elems = [
        x.expand(shape)
        if isinstance(x, torch.Tensor)
        else misc.constant(x, shape=shape, device=ref[0].device)
        for x in elems
    ]
    # Stack entries along a trailing axis, then fold it into [rows, cols].
    return torch.stack(elems, dim=-1).reshape(shape + (len(rows), -1))
|
|
|
|
|
def translate2d(tx, ty, **kwargs):
    """Return the 3x3 homogeneous 2D matrix that translates by (tx, ty)."""
    row_x = [1, 0, tx]
    row_y = [0, 1, ty]
    row_w = [0, 0, 1]
    return matrix(row_x, row_y, row_w, **kwargs)
|
|
|
|
|
def translate3d(tx, ty, tz, **kwargs):
    """Return the 4x4 homogeneous 3D matrix that translates by (tx, ty, tz)."""
    offsets = (tx, ty, tz)
    rows = [[1 if col == row else 0 for col in range(3)] + [offsets[row]] for row in range(3)]
    rows.append([0, 0, 0, 1])
    return matrix(*rows, **kwargs)
|
|
|
|
|
def scale2d(sx, sy, **kwargs):
    """Return the 3x3 homogeneous 2D matrix that scales x by sx and y by sy."""
    diagonal = (sx, sy, 1)
    rows = [[value if col == row else 0 for col in range(3)] for row, value in enumerate(diagonal)]
    return matrix(*rows, **kwargs)
|
|
|
|
|
def scale3d(sx, sy, sz, **kwargs):
    """Return the 4x4 homogeneous 3D matrix that scales axes by (sx, sy, sz)."""
    diagonal = (sx, sy, sz, 1)
    rows = [[value if col == row else 0 for col in range(4)] for row, value in enumerate(diagonal)]
    return matrix(*rows, **kwargs)
|
|
|
|
|
def rotate2d(theta, **kwargs):
    """Return the 3x3 homogeneous 2D matrix rotating by angle ``theta`` (radians).

    ``theta`` must be a tensor (``torch.cos``/``torch.sin`` are applied to it).
    """
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    neg_sin_t = torch.sin(-theta)
    first = [cos_t, neg_sin_t, 0]
    second = [sin_t, cos_t, 0]
    third = [0, 0, 1]
    return matrix(first, second, third, **kwargs)
|
|
|
|
|
def rotate3d(v, theta, **kwargs):
    """Return the 4x4 homogeneous matrix rotating by ``theta`` about axis ``v``.

    Only the first three components of ``v`` (along its last axis) are read, and
    the axis is not normalized here. The result is a proper rotation only if the
    caller passes a unit-length axis (assumption; the function does not check).
    """
    vx, vy, vz = (v[..., k] for k in range(3))
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)
    one_minus_cos = 1 - cos_t

    # Pairwise axis products shared between symmetric matrix entries.
    xx, yy, zz = vx * vx, vy * vy, vz * vz
    xy, xz, yz = vx * vy, vx * vz, vy * vz

    row0 = [xx * one_minus_cos + cos_t, xy * one_minus_cos - vz * sin_t, xz * one_minus_cos + vy * sin_t, 0]
    row1 = [xy * one_minus_cos + vz * sin_t, yy * one_minus_cos + cos_t, yz * one_minus_cos - vx * sin_t, 0]
    row2 = [xz * one_minus_cos - vy * sin_t, yz * one_minus_cos + vx * sin_t, zz * one_minus_cos + cos_t, 0]
    row3 = [0, 0, 0, 1]
    return matrix(row0, row1, row2, row3, **kwargs)
|
|
|
|
|
def translate2d_inv(tx, ty, **kwargs):
    """Return the inverse of ``translate2d(tx, ty)``, i.e. translation by (-tx, -ty)."""
    neg_tx, neg_ty = -tx, -ty
    return translate2d(neg_tx, neg_ty, **kwargs)
|
|
|
|
|
def scale2d_inv(sx, sy, **kwargs):
    """Return the inverse of ``scale2d(sx, sy)``, i.e. scaling by (1/sx, 1/sy)."""
    inv_sx, inv_sy = 1 / sx, 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)
|
|
|
|
|
def rotate2d_inv(theta, **kwargs):
    """Return the inverse of ``rotate2d(theta)``, i.e. rotation by ``-theta``."""
    reversed_angle = -theta
    return rotate2d(reversed_angle, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
    """Randomized, differentiable image augmentation pipeline.

    Every augmentation ``X`` is applied independently per sample with
    probability ``X * self.p``. ``p`` is a scalar buffer (initialized to 1) that
    acts as a global strength multiplier. Weights of 0 disable an augmentation.

    Processing order in ``forward``:
      1. Geometric augmentations, combined into one inverse affine matrix per
         sample and applied with a single 2x-oversampled warp.
      2. Color augmentations, combined into one 4x4 affine color matrix per
         sample.
      3. Image-space filtering, additive noise and cutout.

    Args:
        xflip: Weight of horizontal flips.
        rotate90: Weight of rotations by a multiple of 90 degrees.
        xint: Weight of integer translation.
        xint_max: Range of integer translation, relative to image size.
        scale: Weight of isotropic scaling.
        rotate: Weight of arbitrary rotation (realized as a pre- and a
            post-rotation around the anisotropic scaling).
        aniso: Weight of anisotropic scaling.
        xfrac: Weight of fractional translation.
        scale_std: Std of the log2 isotropic scale factor.
        rotate_max: Range of arbitrary rotation, as a fraction of pi.
        aniso_std: Std of the log2 anisotropic scale factor.
        xfrac_std: Std of fractional translation, relative to image size.
        brightness: Weight of additive brightness change.
        contrast: Weight of contrast scaling.
        lumaflip: Weight of reflection across the luma axis.
        hue: Weight of hue rotation (ignored for 1-channel images).
        saturation: Weight of saturation scaling (ignored for 1-channel images).
        brightness_std: Std of the additive brightness offset.
        contrast_std: Std of the log2 contrast factor.
        hue_max: Range of hue rotation, as a fraction of pi.
        saturation_std: Std of the log2 saturation factor.
        imgfilter: Weight of per-band image-space filtering.
        imgfilter_bands: Per-band strength multipliers. Its length must equal the
            number of filter-bank bands (4).
        imgfilter_std: Std of the log2 per-band gain.
        noise: Weight of additive Gaussian noise.
        cutout: Weight of cutout.
        noise_std: Scale of the half-normal noise standard deviation.
        cutout_size: Side length of the cutout rectangle, relative to image size.
    """

    def __init__(
        self,
        xflip=0,
        rotate90=0,
        xint=0,
        xint_max=0.125,
        scale=0,
        rotate=0,
        aniso=0,
        xfrac=0,
        scale_std=0.2,
        rotate_max=1,
        aniso_std=0.2,
        xfrac_std=0.125,
        brightness=0,
        contrast=0,
        lumaflip=0,
        hue=0,
        saturation=0,
        brightness_std=0.2,
        contrast_std=0.5,
        hue_max=1,
        saturation_std=1,
        imgfilter=0,
        imgfilter_bands=(1, 1, 1, 1),  # Immutable default; copied to a list below.
        imgfilter_std=1,
        noise=0,
        cutout=0,
        noise_std=0.1,
        cutout_size=0.5,
    ):
        super().__init__()
        # Overall multiplier for every augmentation probability. It is registered
        # as a buffer, so it is part of the module state but is not a parameter.
        self.register_buffer("p", torch.ones([]))

        # Pixel blitting.
        self.xflip = float(xflip)
        self.rotate90 = float(rotate90)
        self.xint = float(xint)
        self.xint_max = float(xint_max)

        # General geometric transformations.
        self.scale = float(scale)
        self.rotate = float(rotate)
        self.aniso = float(aniso)
        self.xfrac = float(xfrac)
        self.scale_std = float(scale_std)
        self.rotate_max = float(rotate_max)
        self.aniso_std = float(aniso_std)
        self.xfrac_std = float(xfrac_std)

        # Color transformations.
        self.brightness = float(brightness)
        self.contrast = float(contrast)
        self.lumaflip = float(lumaflip)
        self.hue = float(hue)
        self.saturation = float(saturation)
        self.brightness_std = float(brightness_std)
        self.contrast_std = float(contrast_std)
        self.hue_max = float(hue_max)
        self.saturation_std = float(saturation_std)

        # Image-space filtering.
        self.imgfilter = float(imgfilter)
        self.imgfilter_bands = list(imgfilter_bands)
        self.imgfilter_std = float(imgfilter_std)

        # Image-space corruptions.
        self.noise = float(noise)
        self.cutout = float(cutout)
        self.noise_std = float(noise_std)
        self.cutout_size = float(cutout_size)

        # Low-pass filter (sym6 taps) used for the 2x up/downsampling around the
        # geometric warp.
        self.register_buffer("Hz_geom", upfirdn2d.setup_filter(wavelets["sym6"]))

        # Filter bank for image-space filtering, built from the sym2 taps.
        Hz_lo = np.asarray(wavelets["sym2"])  # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size))  # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2  # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2  # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)  # One row per band; starts as a unit impulse in band 0.
        for i in range(1, Hz_fbank.shape[0]):
            # Upsample every row by 2 (interleave zeros, drop the trailing zero),
            # then low-pass it.
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(
                Hz_fbank.shape[0], -1
            )[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            # Add the centered high-pass response to band i.
            start = (Hz_fbank.shape[1] - Hz_hi2.size) // 2
            stop = (Hz_fbank.shape[1] + Hz_hi2.size) // 2
            Hz_fbank[i, start:stop] += Hz_hi2
        self.register_buffer("Hz_fbank", torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, images, debug_percentile=None):
        """Apply randomly sampled augmentations to a batch of images.

        Args:
            images: 4-D tensor, unpacked as [batch, channels, height, width].
            debug_percentile: Optional scalar. When given, each enabled
                augmentation uses a deterministic parameter derived from it
                instead of a random one. Presumably in [0, 1]; confirm with
                callers.

        Returns:
            The augmented image batch. All stages reshape back to
            [batch, channels, height, width], so it is expected to match the
            input shape.

        Raises:
            ValueError: If a color transform is active and the images are neither
                3-channel nor 1-channel.
        """
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(
                debug_percentile, dtype=torch.float32, device=device
            )

        # ------------------------------------------------------------------
        # Geometric transformations. Build the inverse affine matrix (output
        # pixel -> input pixel) per sample by right-multiplying inverse steps.
        # ------------------------------------------------------------------
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Horizontal flip with probability (xflip * p).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size], device=device) < self.xflip * self.p,
                i,
                torch.zeros_like(i),
            )
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Rotation by a multiple of 90 degrees with probability (rotate90 * p).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(
                torch.rand([batch_size], device=device) < self.rotate90 * self.p,
                i,
                torch.zeros_like(i),
            )
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Integer (whole-pixel) translation with probability (xint * p).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(
                torch.rand([batch_size, 1], device=device) < self.xint * self.p,
                t,
                torch.zeros_like(t),
            )
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(
                torch.round(t[:, 0] * width), torch.round(t[:, 1] * height)
            )

        # Isotropic scaling (log-normal factor) with probability (scale * p).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.scale * self.p,
                s,
                torch.ones_like(s),
            )
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std),
                )
            G_inv = G_inv @ scale2d_inv(s, s)

        # Arbitrary rotation is split into a pre- and a post-rotation around the
        # anisotropic scaling. Each fires with p_rot, so
        # P(pre OR post) = 1 - (1 - p_rot)^2 = rotate * p (clamped to [0, 1]).
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1))
        if self.rotate > 0:
            theta = (
                (torch.rand([batch_size], device=device) * 2 - 1)
                * np.pi
                * self.rotate_max
            )
            theta = torch.where(
                torch.rand([batch_size], device=device) < p_rot,
                theta,
                torch.zeros_like(theta),
            )
            if debug_percentile is not None:
                theta = torch.full_like(
                    theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max
                )
            G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

        # Anisotropic scaling (s along x, 1/s along y) with probability (aniso * p).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(
                torch.rand([batch_size], device=device) < self.aniso * self.p,
                s,
                torch.ones_like(s),
            )
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std),
                )
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Post-rotation with probability p_rot (forced to zero in debug mode).
        if self.rotate > 0:
            theta = (
                (torch.rand([batch_size], device=device) * 2 - 1)
                * np.pi
                * self.rotate_max
            )
            theta = torch.where(
                torch.rand([batch_size], device=device) < p_rot,
                theta,
                torch.zeros_like(theta),
            )
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

        # Fractional (sub-pixel) translation with probability (xfrac * p).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(
                torch.rand([batch_size, 1], device=device) < self.xfrac * self.p,
                t,
                torch.zeros_like(t),
            )
            if debug_percentile is not None:
                t = torch.full_like(
                    t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std
                )
            G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)

        # Execute the geometric transform. Skipped when every geometric
        # augmentation is disabled, because G_inv is then still the same I_3 object.
        if G_inv is not I_3:
            # Transform the image corners to find how much padding the warp needs.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix(
                [-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device
            )  # [corner, xyz]
            cp = G_inv @ cp.t()  # [batch, xyz, corner]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1)  # [xy, batch * corner]
            margin = torch.cat([-margin, margin]).max(dim=1).values  # [x0, y0, x1, y1]
            margin = margin + misc.constant(
                [Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device
            )
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(
                misc.constant([width - 1, height - 1] * 2, device=device)
            )
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Reflect-pad the image and shift the origin to compensate.
            images = torch.nn.functional.pad(
                input=images, pad=[mx0, mx1, my0, my1], mode="reflect"
            )
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample 2x with the low-pass filter and express G_inv at that resolution.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = (
                scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            )
            G_inv = (
                translate2d(-0.5, -0.5, device=device)
                @ G_inv
                @ translate2d_inv(-0.5, -0.5, device=device)
            )

            # Warp. The pixel-space matrix is rescaled to the normalized [-1, 1]
            # coordinates that affine_grid expects.
            shape = [
                batch_size,
                num_channels,
                (height + Hz_pad * 2) * 2,
                (width + Hz_pad * 2) * 2,
            ]
            G_inv = (
                scale2d(2 / images.shape[3], 2 / images.shape[2], device=device)
                @ G_inv
                @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            )
            grid = torch.nn.functional.affine_grid(
                theta=G_inv[:, :2, :], size=shape, align_corners=False
            )
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample 2x and crop away the filter padding.
            images = upfirdn2d.downsample2d(
                x=images, f=self.Hz_geom, down=2, padding=-Hz_pad * 2, flip_filter=True
            )

        # ------------------------------------------------------------------
        # Color transformations, accumulated into one 4x4 affine matrix per
        # sample by left-multiplication.
        # ------------------------------------------------------------------
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Brightness offset with probability (brightness * p).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(
                torch.rand([batch_size], device=device) < self.brightness * self.p,
                b,
                torch.zeros_like(b),
            )
            if debug_percentile is not None:
                b = torch.full_like(
                    b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std
                )
            C = translate3d(b, b, b) @ C

        # Contrast scaling with probability (contrast * p).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(
                torch.rand([batch_size], device=device) < self.contrast * self.p,
                c,
                torch.ones_like(c),
            )
            if debug_percentile is not None:
                c = torch.full_like(
                    c,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std
                    ),
                )
            C = scale3d(c, c, c) @ C

        # Unit luma axis (equal RGB weights, no offset component) and its
        # outer-product projector. The explicit broadcast equals torch.ger(v, v)
        # without using the deprecated torch.ger API.
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)
        vvT = v.unsqueeze(1) * v.unsqueeze(0)

        # Luma flip with probability (lumaflip * p).
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(
                torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p,
                i,
                torch.zeros_like(i),
            )
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * vvT * i) @ C  # Householder reflection across the luma axis.

        # Hue rotation about the luma axis with probability (hue * p).
        if self.hue > 0 and num_channels > 1:
            theta = (
                (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            )
            theta = torch.where(
                torch.rand([batch_size], device=device) < self.hue * self.p,
                theta,
                torch.zeros_like(theta),
            )
            if debug_percentile is not None:
                theta = torch.full_like(
                    theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max
                )
            C = rotate3d(v, theta) @ C

        # Saturation with probability (saturation * p). Scales the color component
        # orthogonal to the luma axis.
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(
                torch.randn([batch_size, 1, 1], device=device) * self.saturation_std
            )
            s = torch.where(
                torch.rand([batch_size, 1, 1], device=device)
                < self.saturation * self.p,
                s,
                torch.ones_like(s),
            )
            if debug_percentile is not None:
                s = torch.full_like(
                    s,
                    torch.exp2(
                        torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std
                    ),
                )
            C = (vvT + (I_4 - vvT) * s) @ C

        # Execute the color transform. Skipped when C is still the identity object.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                # Collapse the RGB transform into one gain and one offset.
                C = C[:, :3, :].mean(dim=1, keepdim=True)
                images = images * C[:, :, :3].sum(dim=2, keepdim=True) + C[:, :, 3:]
            else:
                raise ValueError("Image must be RGB (3 channels) or L (1 channel)")
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------------------------------------------------
        # Image-space filtering: amplify each filter-bank band with probability
        # (imgfilter * p * band_strength).
        # ------------------------------------------------------------------
        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            # Assumed relative power of each band, used to renormalize the gains.
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device)

            g = torch.ones([batch_size, num_bands], device=device)  # Global gain vector.
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(
                    torch.randn([batch_size], device=device) * self.imgfilter_std
                )
                t_i = torch.where(
                    torch.rand([batch_size], device=device)
                    < self.imgfilter * self.p * band_strength,
                    t_i,
                    torch.ones_like(t_i),
                )
                if debug_percentile is not None:
                    if band_strength > 0:
                        t_i = torch.full_like(
                            t_i,
                            torch.exp2(
                                torch.erfinv(debug_percentile * 2 - 1)
                                * self.imgfilter_std
                            ),
                        )
                    else:
                        t_i = torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device)  # Temporary gains.
                t[:, i] = t_i  # Amplify only band i.
                # Renormalize so the expected total power stays constant.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdim=True).sqrt()
                g = g * t  # Accumulate into the global gain.

            # Combine the bands into one 1-D filter per (sample, channel).
            Hz_prime = g @ self.Hz_fbank  # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])  # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1])  # [batch * channels, 1, tap]

            # Separable filtering: grouped 1xK then Kx1 convolutions after reflect padding.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(
                input=images, pad=[p, p, p, p], mode="reflect"
            )
            images = conv2d_gradfix.conv2d(
                input=images,
                weight=Hz_prime.unsqueeze(2),
                groups=batch_size * num_channels,
            )
            images = conv2d_gradfix.conv2d(
                input=images,
                weight=Hz_prime.unsqueeze(3),
                groups=batch_size * num_channels,
            )
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------------------------------------------------
        # Image-space corruptions.
        # ------------------------------------------------------------------

        # Additive Gaussian noise with a per-sample half-normal sigma,
        # with probability (noise * p).
        if self.noise > 0:
            sigma = (
                torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
            )
            sigma = torch.where(
                torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p,
                sigma,
                torch.zeros_like(sigma),
            )
            if debug_percentile is not None:
                # NOTE(review): uses erfinv(p) rather than erfinv(2p - 1) like the
                # other debug paths (sigma is half-normal); erfinv(1) is infinite.
                sigma = torch.full_like(
                    sigma, torch.erfinv(debug_percentile) * self.noise_std
                )
            images = (
                images
                + torch.randn([batch_size, num_channels, height, width], device=device)
                * sigma
            )

        # Cutout: zero a rectangle of relative size cutout_size around a random
        # center, with probability (cutout * p).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
            size = torch.where(
                torch.rand([batch_size, 1, 1, 1, 1], device=device)
                < self.cutout * self.p,
                size,
                torch.zeros_like(size),
            )
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
            # Keep a pixel if its center lies outside the rectangle along x or y.
            mask_x = ((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2
            mask_y = ((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
|
|
|
|
|
|
|
|