import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_


class KAE(nn.Module):
    """k-sparse autoencoder: a tied-weight linear autoencoder that keeps
    only the nb_active largest hidden activations per sample."""

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, nb_active=16):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.nb_active = nb_active
        self.encode = nn.Sequential(
            nn.Linear(w * h * c, nb_hidden, bias=False)
        )
        # Decoder bias; the decoder weights are tied to the encoder's.
        self.bias = nn.Parameter(torch.zeros(w * h * c))
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        Xr, _ = self.decode(h)
        Xr = Xr.view(size)
        return Xr

    def decode(self, h):
        # Threshold each row at its (nb_active+1)-th largest value, so that
        # exactly the top nb_active activations survive.
        thetas, _ = torch.sort(h, dim=1, descending=True)
        thetas = thetas[:, self.nb_active:self.nb_active + 1]
        h = h * (h > thetas).float()
        # Tied-weight decoding: reuse the encoder weight matrix.
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        Xr = torch.sigmoid(Xr)
        return Xr, h


class ZAE(nn.Module):
    """Thresholded autoencoder: hidden activations at or below a fixed
    threshold theta are zeroed before tied-weight decoding."""

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, theta=1):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.theta = theta
        self.encode = nn.Sequential(
            nn.Linear(w * h * c, nb_hidden, bias=False)
        )
        # Decoder bias; the decoder weights are tied to the encoder's.
        self.bias = nn.Parameter(torch.zeros(w * h * c))
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        Xr, _ = self.decode(h)
        Xr = Xr.view(size)
        return Xr

    def decode(self, h):
        # Hard-threshold the activations, then decode with tied weights.
        h = h * (h > self.theta).float()
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        Xr = torch.sigmoid(Xr)
        return Xr, h


class DenseAE(nn.Module):
    """Fully-connected autoencoder with an optional k-sparse hidden code
    and optional denoising via random input masking."""

    def __init__(self, w=32, h=32, c=1, encode_hidden=(300,),
                 decode_hidden=(300,), ksparse=True, nb_active=10,
                 denoise=None):
        super().__init__()
        self.encode_hidden = encode_hidden
        self.decode_hidden = decode_hidden
        self.ksparse = ksparse
        self.nb_active = nb_active
        self.denoise = denoise

        # Encoder: a stack of Linear + ReLU layers.
        layers = []
        hid_prev = w * h * c
        for hid in encode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        self.encode = nn.Sequential(*layers)

        # Decoder: Linear + ReLU stack followed by a sigmoid output layer.
        layers = []
        for hid in decode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        layers.extend([
            nn.Linear(hid_prev, w * h * c),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)

        self.apply(_weights_init)

    def forward(self, X):
        size = X.size()
        if self.denoise is not None:
            # Keep each input unit with probability `denoise`; zero the rest.
            X = X * (torch.rand_like(X) <= self.denoise).float()
        X = X.view(X.size(0), -1)
        h = self.encode(X)
        if self.ksparse:
            h = ksparse(h, nb_active=self.nb_active)
        Xr = self.decode(h)
        Xr = Xr.view(size)
        return Xr


def ksparse(x, nb_active=10):
    # Keep the nb_active largest activations in each row of x and zero out
    # the rest. Vectorized with topk/scatter instead of a per-row numpy
    # argsort loop; behavior is unchanged (up to tie-breaking).
    if nb_active >= x.size(1):
        return x
    _, inds = torch.topk(x, nb_active, dim=1)
    mask = torch.zeros_like(x)
    mask.scatter_(1, inds, 1.0)
    return x * mask
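
# Hypothetical worked example: for the row [0.1, 0.9, 0.3, 0.8, 0.2] with
# nb_active=2, the two largest activations are kept and the rest zeroed:
#   ksparse(torch.tensor([[0.1, 0.9, 0.3, 0.8, 0.2]]), nb_active=2)
#   -> tensor([[0.0000, 0.9000, 0.0000, 0.8000, 0.0000]])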


class ConvAE(nn.Module):
    """Convolutional autoencoder with winner-take-all (WTA) sparsity
    applied to the encoder feature maps before decoding."""

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True,
                 channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride
        # Three 5x5 valid convolutions; for the default 32x32 input the
        # feature maps shrink 32 -> 28 -> 24 -> 20.
        self.encode = nn.Sequential(
            nn.Conv2d(c, nb_filters, 5, 1, 0),
            nn.ReLU(True),
            nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
            nn.ReLU(True),
            nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
        )
        # One 13x13 transposed convolution maps 20x20 back to 32x32.
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
            nn.Sigmoid()
        )
        self.apply(_weights_init)

    def forward(self, X):
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        # Apply winner-take-all sparsity to the encoder feature maps.
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h


class SimpleConvAE(nn.Module):
    """Single-layer convolutional autoencoder with winner-take-all
    sparsity on the feature maps."""

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True,
                 channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride
        # One 13x13 valid convolution: 32x32 -> 20x20 for the default input.
        self.encode = nn.Sequential(
            nn.Conv2d(c, nb_filters, 13, 1, 0),
            nn.ReLU(True),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(nb_filters, c, 13, 1, 0),
            nn.Sigmoid()
        )
        self.apply(_weights_init)

    def forward(self, X):
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        # Apply winner-take-all sparsity to the encoder feature maps.
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h


class DeepConvAE(nn.Module):
    """Deeper convolutional autoencoder: nb_layers 5x5 valid convolutions,
    mirrored by nb_layers 5x5 transposed convolutions."""

    def __init__(self, w=32, h=32, c=1, nb_filters=64, nb_layers=3,
                 spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride

        # Encoder: nb_layers 5x5 valid convolutions with ReLU.
        layers = [
            nn.Conv2d(c, nb_filters, 5, 1, 0),
            nn.ReLU(True),
        ]
        for _ in range(nb_layers - 1):
            layers.extend([
                nn.Conv2d(nb_filters, nb_filters, 5, 1, 0),
                nn.ReLU(True),
            ])
        self.encode = nn.Sequential(*layers)

        # Decoder mirrors the encoder with transposed convolutions.
        layers = []
        for _ in range(nb_layers - 1):
            layers.extend([
                nn.ConvTranspose2d(nb_filters, nb_filters, 5, 1, 0),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(nb_filters, c, 5, 1, 0),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)
        self.apply(_weights_init)

    def forward(self, X):
        h = self.encode(X)
        h = self.sparsify(h)
        Xr = self.decode(h)
        return Xr

    def sparsify(self, h):
        # Apply winner-take-all sparsity to the encoder feature maps.
        if self.spatial:
            h = spatial_sparsity(h)
        if self.channel:
            h = strided_channel_sparsity(h, stride=self.channel_stride)
        return h


def spatial_sparsity(x):
    # Winner-take-all spatial sparsity: within each feature map, keep only
    # the position(s) holding the maximum activation, zero everything else.
    maxes = x.amax(dim=(2, 3), keepdim=True)
    return x * equals(x, maxes).float()


def equals(x, y, eps=1e-8):
    # Elementwise approximate-equality mask (bool tensor).
    return torch.abs(x - y) <= eps


def strided_channel_sparsity(x, stride=1):
    # Winner-take-all across channels: tile each feature map into
    # stride x stride spatial blocks and, within each block, keep only the
    # single largest activation across all channels. Assumes the spatial
    # dimensions are divisible by `stride`.
    B, F = x.shape[0:2]
    h, w = x.shape[2:]
    x_ = x.view(B, F, h // stride, stride, w // stride, stride)
    mask = equals(x_, x_.amax(dim=(1, 3, 5), keepdim=True))
    mask = mask.view(x.shape).float()
    return x * mask
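
# Hypothetical worked example: for feature maps of shape (B, F, 20, 20)
# and stride=4, each map is tiled into 4x4 blocks; only the largest
# activation across all F channels survives per block, leaving at most
# (20/4) * (20/4) = 25 nonzero entries per sample.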


def _weights_init(m):
    # Xavier-initialize any submodule that has a weight tensor and zero its
    # bias, using the in-place init API (xavier_uniform_) rather than the
    # deprecated xavier_uniform.
    if hasattr(m, 'weight'):
        xavier_uniform_(m.weight)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
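

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): push a random
    # batch of 32x32 single-channel images through each autoencoder and
    # check that the reconstruction keeps the input shape. The batch size
    # of 4 is arbitrary.
    X = torch.rand(4, 1, 32, 32)
    models = [KAE(), ZAE(), DenseAE(), ConvAE(), SimpleConvAE(), DeepConvAE()]
    for model in models:
        Xr = model(X)
        assert Xr.shape == X.shape, type(model).__name__
        print(type(model).__name__, "->", tuple(Xr.shape))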