|
import torch |
|
from torch import nn |
|
from torchvision import models |
|
|
|
|
|
class DeePixBiS(nn.Module):
    """DenseNet-161 based model producing a pixel-wise map and a global score.

    The name suggests the DeePixBiS face anti-spoofing architecture
    (George & Marcel, 2019) -- NOTE(review): confirm against training code.

    The encoder is the first eight children of ``densenet161().features``
    (``conv0`` through ``transition2``), which emits 384 channels at 1/16 of
    the input resolution. A 1x1 convolution reduces that to a single-channel
    map squashed with a sigmoid, and a linear layer over the flattened map
    yields one sigmoid score per sample. The linear head is sized for a
    14x14 map, i.e. 224x224 inputs.

    The attribute names (``enc``, ``dec``, ``linear``) and the
    ``nn.Sequential`` layout of ``enc`` are kept unchanged so that existing
    state dicts still load.
    """

    def __init__(self, pretrained=True):
        """Build the model.

        Args:
            pretrained: ``True`` (or any truthy non-string value) to start
                from torchvision's default ImageNet DenseNet-161 weights;
                a falsy value for random initialisation; or a
                ``DenseNet161_Weights`` member / weight-name string, which
                is forwarded unchanged to ``torchvision``.
        """
        super().__init__()

        if isinstance(pretrained, (str, models.DenseNet161_Weights)):
            weights = pretrained
        elif pretrained:
            # torchvision's ``weights=`` argument rejects a bare ``True``
            # (WeightsEnum.verify raises TypeError), so map it explicitly.
            weights = models.DenseNet161_Weights.DEFAULT
        else:
            weights = None

        dense = models.densenet161(weights=weights)
        features = list(dense.features.children())

        # conv0, norm0, relu0, pool0, denseblock1, transition1,
        # denseblock2, transition2 -> 384 channels at stride 16.
        self.enc = nn.Sequential(*features[:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)
        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """Run the model.

        Args:
            x: image batch fed to the DenseNet-161 prefix; assumed to be
                ``(B, 3, 224, 224)`` so the map is 14x14 -- other sizes
                raise ``ValueError``.

        Returns:
            ``(out_map, out)``: ``out_map`` is the sigmoid map of shape
            ``(B, 1, H', W')``; ``out`` is the per-sample sigmoid score of
            shape ``(B,)``.

        Raises:
            ValueError: if the map does not have exactly
                ``self.linear.in_features`` elements per sample.
        """
        enc = self.enc(x)
        dec = self.dec(enc)
        out_map = torch.sigmoid(dec)

        # Flatten per sample instead of view(-1, 196): a wrongly sized input
        # must not be silently re-split across the batch dimension.
        flat = torch.flatten(out_map, start_dim=1)
        if flat.shape[1] != self.linear.in_features:
            raise ValueError(
                f"expected a map with {self.linear.in_features} elements per "
                f"sample (14x14, i.e. 224x224 input), got map shape "
                f"{tuple(out_map.shape[1:])}"
            )

        out = torch.sigmoid(self.linear(flat))
        out = torch.flatten(out)

        return out_map, out
|
|