import torch
from torch import nn
from torchvision import models


class DeePixBiS(nn.Module):
    """DenseNet-161-based model producing a per-pixel score map and a scalar score.

    The encoder is the first eight child modules of torchvision's DenseNet-161
    ``features`` stack. A 1x1 convolution projects its 384-channel output to a
    single channel, and a sigmoid turns that into a score map. A linear layer
    over the flattened 14x14 map, followed by another sigmoid, yields one score
    per sample.

    Args:
        pretrained: Which DenseNet-161 weights to load. ``True`` selects
            ``DenseNet161_Weights.DEFAULT``. A falsy value (``False``/``None``)
            builds randomly initialised weights. Anything else (a
            ``DenseNet161_Weights`` member, or its name as a string such as
            ``"IMAGENET1K_V1"``) is forwarded unchanged to torchvision.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # torchvision's ``weights=`` argument rejects booleans with a TypeError,
        # so translate the legacy-style ``True`` flag into the default weights enum.
        if pretrained is True:
            weights = models.DenseNet161_Weights.DEFAULT
        elif pretrained:
            weights = pretrained
        else:
            weights = None
        dense = models.densenet161(weights=weights)
        features = list(dense.features.children())
        # Keep only the first eight feature modules; the 1x1 conv below expects
        # their output to have 384 channels.
        self.enc = nn.Sequential(*features[:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)
        # The head expects a 14x14 score map per sample.
        # NOTE(review): presumably this corresponds to 224x224 inputs — confirm.
        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """Compute the score map and the per-sample score.

        Args:
            x: Batched input for the DenseNet encoder (presumably
                ``(batch, 3, H, W)`` images — confirm against callers).

        Returns:
            A tuple ``(out_map, out)``:

            - ``out_map`` is the sigmoid-activated single-channel map produced
              by the 1x1 convolution.
            - ``out`` is a 1-D tensor with one sigmoid score per sample.
        """
        enc = self.enc(x)
        dec = self.dec(enc)
        out_map = torch.sigmoid(dec)
        # Flatten per sample (the batch dim is kept) instead of view(-1, 196).
        # With the old view, a map of the wrong spatial size would be regrouped
        # across batch elements silently; now the linear layer rejects it.
        out = self.linear(torch.flatten(out_map, start_dim=1))
        out = torch.sigmoid(out)
        # (batch, 1) -> (batch,)
        out = torch.flatten(out)
        return out_map, out