RIP-AV-su-lab / AV /models /network.py
weidai00's picture
Upload 72 files
6c0075d verified
# -*- coding: utf-8 -*-
import torchvision.models
from torch import nn
import torch
import torch.nn.functional as F
from AV.models.layers import *
from torchvision.models.convnext import convnext_tiny, ConvNeXt_Tiny_Weights
import numpy as np
import math
from torchvision import models
import copy
class PGNet(nn.Module):
def __init__(self, input_ch=3, resnet='convnext_tiny', num_classes=3, use_cuda=False, pretrained=True,centerness=False, centerness_map_size=[128,128],use_global_semantic=False):
super(PGNet, self).__init__()
self.resnet = resnet
base_model = convnext_tiny
# layers = list(base_model(pretrained=pretrained,num_classes=num_classes,input_ch=input_ch).children())[:cut]
self.use_high_semantic = False
cut = 6
if pretrained:
layers = list(base_model(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1).features)[:cut]
else:
layers = list(base_model().features)[:cut]
base_layers = nn.Sequential(*layers)
self.use_global_semantic = use_global_semantic
### global momentum
if self.use_global_semantic:
self.pg_fusion = PGFusion()
self.base_layers_global_momentum = copy.deepcopy(base_layers)
set_requires_grad(self.base_layers_global_momentum,requires_grad=False)
# self.stage = [SaveFeatures(base_layers[0][1])] # stage 1 c=96
self.stage = []
self.stage.append(SaveFeatures(base_layers[0][1])) # stem c=96
self.stage.append(SaveFeatures(base_layers[1][2])) # stage 1 c=96
self.stage.append(SaveFeatures(base_layers[3][2])) # stage 2 c=192
self.stage.append(SaveFeatures(base_layers[5][8])) # stage 3 c=384
# self.stage.append(SaveFeatures(base_layers[7][2])) # stage 5 c=768
self.up2 = DBlock(384, 192)
self.up3 = DBlock(192, 96)
self.up4 = DBlock(96, 96)
# final convolutional layers
# predict artery, vein and vessel
self.seg_head = SegmentationHead(96, num_classes, 3, upsample=4)
self.sn_unet = base_layers
self.num_classes = num_classes
self.bn_out = nn.BatchNorm2d(3)
#self.av_cross = AV_Cross(block=4,kernel_size=1)
# use centerness block
self.centerness = centerness
if self.centerness and centerness_map_size[0] == 128:
# block 1
self.cenBlock1 = [
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
]
self.cenBlock1 = nn.Sequential(*self.cenBlock1)
# centerness block
self.cenBlockMid = [
nn.Conv2d(96, 48, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(48),
# nn.Conv2d(48, 48, kernel_size=3, padding=3, bias=False),
# nn.BatchNorm2d(48),
nn.Conv2d(48, 96, kernel_size=1, padding=0, bias=False),
]
self.cenBlockMid = nn.Sequential(*self.cenBlockMid)
self.cenBlockFinal = [
nn.BatchNorm2d(96),
nn.ReLU(inplace=True),
nn.Conv2d(96, 3, kernel_size=1, padding=0, bias=True),
nn.Sigmoid()
]
self.cenBlockFinal = nn.Sequential(*self.cenBlockFinal)
def forward(self, x,y=None):
x = self.sn_unet(x)
global_rep = None
if self.use_global_semantic:
global_rep = self.base_layers_global_momentum(y)
x = self.pg_fusion(x,global_rep)
if len(x.shape) == 4 and x.shape[2] != x.shape[3]:
B, H, W, C = x.shape
x = x.permute(0, 3, 1, 2).contiguous()
elif len(x.shape) == 3:
B, L, C = x.shape
h = int(L ** 0.5)
x = x.view(B, h, h, C)
x = x.permute(0, 3, 1, 2).contiguous()
else:
x = x
if self.use_high_semantic:
high_out = x.clone()
else:
high_out = x.clone()
if self.resnet == 'swin_t' or self.resnet == 'convnext_tiny':
# feature = self.stage[1:]
feature = self.stage[::-1]
# head = feature[0]
skip = feature[1:]
# x = self.up1(x,skip[0].features)
x = self.up2(x, skip[0].features)
x = self.up3(x, skip[1].features)
x = self.up4(x, skip[2].features)
x_out = self.seg_head(x)
########################
# baseline output
# artery, vein and vessel
output = x_out.clone()
#av cross
#output = self.av_cross(output)
#output = F.relu(self.bn_out(output))
# use centerness block
centerness_maps = None
if self.centerness:
block1 = self.cenBlock1(self.stage[1].features) # [96,64]
_block1 = self.cenBlockMid(block1) # [96,64]
block1 = block1 + _block1
blocks = [block1]
blocks = torch.cat(blocks, dim=1)
# print("blocks", blocks.shape)
centerness_maps = self.cenBlockFinal(blocks)
# print("maps:", centerness_maps.shape)
return output, centerness_maps
def forward_patch_rep(self, x):
patch_rep = self.sn_unet(x)
return patch_rep
def forward_global_rep_momentum(self, x):
global_rep = self.base_layers_global_momentum(x)
return global_rep
def close(self):
for sf in self.stage: sf.remove()
def close(self):
for sf in self.stage: sf.remove()
# set requies_grad=Fasle to avoid computation
def set_requires_grad(nets, requires_grad=False):
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
pretrained_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False).view((1, 3, 1, 1))
pretrained_std = torch.tensor([0.229, 0.224, 0.225], requires_grad=False).view((1, 3, 1, 1))
if __name__ == '__main__':
s = PGNet(input_ch=3, resnet='convnext_tiny',centerness=True, pretrained=False,use_global_semantic=False)
x = torch.randn(2, 3, 256, 256)
y,Y2 = s(x)
print(y.shape)
print(Y2.shape)
# pt = torch.load(r'F:\dw\MICCAI2023-STS-2D\segmentation\log\2023_07_25_18_10_10\G_0.pkl')
# print(pt)
# import torchvision.models as models
# m = models.vit_b_16(pretrained=False)
# print(m)
# m = resnet18()
# m_list = list(m.children())
# def hook(module, input, output):
# print('fafafafgafa')
# print(input[0].shape)
# print(output[0].shape)
# m_list[0].register_forward_hook(hook)
#
#
# y = m(x)