myTest01 / models /util /array_util.py
meng2003's picture
Upload 85 files
bc32eea
import torch
import torch.nn as nn
import torch.nn.functional as F
class Flip(nn.Module):
def forward(self, x, cond, sldj, reverse=False):
assert isinstance(x, tuple) and len(x) == 2
return (x[1], x[0]), sldj
def mean_dim(tensor, dim=None, keepdims=False):
"""Take the mean along multiple dimensions.
Args:
tensor (torch.Tensor): Tensor of values to average.
dim (list): List of dimensions along which to take the mean.
keepdims (bool): Keep dimensions rather than squeezing.
Returns:
mean (torch.Tensor): New tensor of mean value(s).
"""
if dim is None:
return tensor.mean()
else:
if isinstance(dim, int):
dim = [dim]
dim = sorted(dim)
for d in dim:
tensor = tensor.mean(dim=d, keepdim=True)
if not keepdims:
for i, d in enumerate(dim):
tensor.squeeze_(d-i)
return tensor
def checkerboard(x, reverse=False):
"""Split x in a checkerboard pattern. Collapse horizontally."""
# Get dimensions
if reverse:
b, c, h, w = x[0].size()
w *= 2
device = x[0].device
else:
b, c, h, w = x.size()
device = x.device
# Get list of indices in alternating checkerboard pattern
y_idx = []
z_idx = []
for i in range(h):
for j in range(w):
if (i % 2) == (j % 2):
y_idx.append(i * w + j)
else:
z_idx.append(i * w + j)
y_idx = torch.tensor(y_idx, dtype=torch.int64, device=device)
z_idx = torch.tensor(z_idx, dtype=torch.int64, device=device)
if reverse:
y, z = (t.contiguous().view(b, c, h // 2 * w) for t in x)
x = torch.zeros(b, c, h * w, dtype=y.dtype, device=y.device)
x[:, :, y_idx] += y
x[:, :, z_idx] += z
x = x.view(b, c, h, w)
return x
else:
if h % 2 != 0:
raise RuntimeError('Checkerboard got odd height input: {}'.format(h))
x = x.view(b, c, h * w)
y = x[:, :, y_idx].view(b, c, h // 2, w)
z = x[:, :, z_idx].view(b, c, h // 2, w)
return y, z
def channelwise(x, reverse=False):
"""Split x channel-wise."""
if reverse:
x = torch.cat(x, dim=1)
return x
else:
y, z = x.chunk(2, dim=1)
return y, z
def squeeze(x):
"""Trade spatial extent for channels. I.e., convert each
1x4x4 volume of input into a 4x1x1 volume of output.
Args:
x (torch.Tensor): Input to squeeze.
Returns:
x (torch.Tensor): Squeezed or unsqueezed tensor.
"""
# import pdb; pdb.set_trace()
b, c, h, w = x.size()
x = x.view(b, c, h // 2, 2, w, 1)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(b, c * 2, h // 2, w)
return x
def unsqueeze(x):
"""Trade channels channels for spatial extent. I.e., convert each
4x1x1 volume of input into a 1x4x4 volume of output.
Args:
x (torch.Tensor): Input to unsqueeze.
Returns:
x (torch.Tensor): Unsqueezed tensor.
"""
b, c, h, w = x.size()
x = x.view(b, c // 2, 2, 1, h, w)
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(b, c // 2, h * 2, w)
return x
def concat_elu(x):
"""Concatenated ReLU (http://arxiv.org/abs/1603.05201), but with ELU."""
return F.elu(torch.cat((x, -x), dim=1))
def safe_log(x):
return torch.log(x.clamp(min=1e-22))