# fasd/Model.py
# Author: Alisher Amantay (commit 9067733, "first commit")
import torch
from torch import nn
from torchvision import models
class DeePixBiS(nn.Module):
    """DenseNet-161 encoder with a 1x1-conv pixel-map head and a linear score head.

    ``forward`` returns a per-pixel sigmoid map plus one sigmoid score per
    sample. The linear head is sized for a 14x14 map, which is what the first
    eight DenseNet-161 feature stages produce for a 224x224 input
    (overall stride 16).
    """

    def __init__(self, pretrained=True):
        """Build the encoder and the two heads.

        Args:
            pretrained: ``True`` loads torchvision's default pretrained weights
                for DenseNet-161; any falsy value builds a randomly initialised
                backbone. Any other truthy value (e.g. a
                ``DenseNet161_Weights`` member or its string name) is forwarded
                unchanged to ``models.densenet161(weights=...)``.
        """
        super().__init__()
        # torchvision's ``weights=`` argument rejects a bare bool (it expects a
        # WeightsEnum member, its name, or None), so translate True to DEFAULT.
        if not pretrained:
            weights = None
        elif pretrained is True:
            weights = models.DenseNet161_Weights.DEFAULT
        else:
            weights = pretrained
        dense = models.densenet161(weights=weights)
        features = list(dense.features.children())
        # Keep the first 8 feature children (in torchvision's layout: conv0,
        # norm0, relu0, pool0, denseblock1, transition1, denseblock2,
        # transition2); their output has the 384 channels ``dec`` expects.
        # NOTE: the attribute names enc/dec/linear are the state_dict keys;
        # keep them stable so existing checkpoints still load.
        self.enc = nn.Sequential(*features[:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)
        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """Compute the pixel map and the per-sample score.

        Args:
            x: input batch ``(N, C, H, W)`` accepted by the DenseNet-161 stem;
                the encoder output must be 14x14 (i.e. 224x224 input).

        Returns:
            ``(out_map, out)``: ``out_map`` is the ``(N, 1, h, w)`` sigmoid
            output of the 1x1 conv, ``out`` is the ``(N,)`` sigmoid score.

        Raises:
            ValueError: if the map does not have ``linear.in_features`` pixels
                per sample.
        """
        enc = self.enc(x)
        dec = self.dec(enc)
        out_map = torch.sigmoid(dec)
        # Flatten per sample instead of view(-1, 196): a wrongly sized map would
        # otherwise be silently re-split into extra "samples".
        flat = out_map.flatten(start_dim=1)
        if flat.shape[1] != self.linear.in_features:
            raise ValueError(
                f"expected a map with {self.linear.in_features} pixels per "
                f"sample (14x14, i.e. 224x224 input), got map shape "
                f"{tuple(out_map.shape[1:])}"
            )
        out = torch.sigmoid(self.linear(flat))
        out = torch.flatten(out)
        return out_map, out