import torch
import torch.nn as nn


def get_model(params):
    """Instantiate the model architecture specified in the params dict."""
    if params['model'] == 'ResidualFCNet':
        return ResidualFCNet(params['input_dim'], params['num_classes'], params['num_filts'], params['depth'])
    elif params['model'] == 'LinNet':
        return LinNet(params['input_dim'], params['num_classes'])
    else:
        raise NotImplementedError('Invalid model specified.')
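
# Example usage (a sketch; the specific values below are illustrative
# assumptions, not defaults defined in this module):
#
#   params = {'model': 'ResidualFCNet', 'input_dim': 4,
#             'num_classes': 1000, 'num_filts': 256, 'depth': 4}
#   model = get_model(params)
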

class ResLayer(nn.Module):
    """Fully connected residual block: output = x + MLP(x), width preserved."""

    def __init__(self, linear_size):
        super(ResLayer, self).__init__()
        self.l_size = linear_size
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout()
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        y = self.w1(x)
        y = self.nonlin1(y)
        y = self.dropout1(y)
        y = self.w2(y)
        y = self.nonlin2(y)
        out = x + y  # skip connection
        return out
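
# ResLayer preserves feature width, so blocks can be stacked freely, e.g.:
#
#   block = ResLayer(64)
#   block(torch.randn(2, 64)).shape  # torch.Size([2, 64])
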

class ResidualFCNet(nn.Module):
    """Residual MLP mapping input features to a location embedding and
    per-class sigmoid predictions."""

    def __init__(self, num_inputs, num_classes, num_filts, depth=4):
        super(ResidualFCNet, self).__init__()
        self.inc_bias = False
        self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
        layers = []
        layers.append(nn.Linear(num_inputs, num_filts))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth):
            layers.append(ResLayer(num_filts))
        self.feats = nn.Sequential(*layers)

    def forward(self, x, class_of_interest=None, return_feats=False):
        loc_emb = self.feats(x)
        if return_feats:
            return loc_emb
        if class_of_interest is None:
            class_pred = self.class_emb(loc_emb)
        else:
            class_pred = self.eval_single_class(loc_emb, class_of_interest)
        # Sigmoid (not softmax): classes are scored as independent multi-label outputs.
        return torch.sigmoid(class_pred)

    def eval_single_class(self, x, class_of_interest):
        # Score one class by projecting onto a single row of the classifier
        # weights; the row is 1-D, so no transpose is needed.
        if self.inc_bias:
            return torch.matmul(x, self.class_emb.weight[class_of_interest, :]) + self.class_emb.bias[class_of_interest]
        else:
            return torch.matmul(x, self.class_emb.weight[class_of_interest, :])
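
# Minimal forward-pass sketch (dimensions here are arbitrary assumptions):
#
#   net = ResidualFCNet(num_inputs=4, num_classes=10, num_filts=32, depth=2)
#   x = torch.randn(8, 4)
#   net(x).shape                       # torch.Size([8, 10])
#   net(x, class_of_interest=3).shape  # torch.Size([8])
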

class LinNet(nn.Module):
    """Linear baseline: a single bias-free linear layer over the raw inputs."""

    def __init__(self, num_inputs, num_classes):
        super(LinNet, self).__init__()
        self.num_layers = 0
        self.inc_bias = False
        self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
        self.feats = nn.Identity()

    def forward(self, x, class_of_interest=None, return_feats=False):
        loc_emb = self.feats(x)
        if return_feats:
            return loc_emb
        if class_of_interest is None:
            class_pred = self.class_emb(loc_emb)
        else:
            class_pred = self.eval_single_class(loc_emb, class_of_interest)
        return torch.sigmoid(class_pred)

    def eval_single_class(self, x, class_of_interest):
        if self.inc_bias:
            return torch.matmul(x, self.class_emb.weight[class_of_interest, :]) + self.class_emb.bias[class_of_interest]
        else:
            return torch.matmul(x, self.class_emb.weight[class_of_interest, :])
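
if __name__ == '__main__':
    # Quick smoke test for both architectures (a sketch; the dimensions below
    # are arbitrary assumptions, not values used in any experiment).
    for name in ['ResidualFCNet', 'LinNet']:
        params = {'model': name, 'input_dim': 4, 'num_classes': 10,
                  'num_filts': 32, 'depth': 2}
        model = get_model(params)
        x = torch.randn(8, params['input_dim'])
        probs = model(x)                        # (8, 10) per-class probabilities
        single = model(x, class_of_interest=3)  # (8,) probabilities for class 3
        print(name, tuple(probs.shape), tuple(single.shape))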