# -*- coding: utf-8 -*-
import copy

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models.convnext import convnext_tiny, ConvNeXt_Tiny_Weights

from AV.models.layers import *
class PGNet(nn.Module):
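    """Patch/global segmentation network for artery-vein-vessel prediction.

    The encoder is the first six modules of ConvNeXt-Tiny's ``features``
    (stem through stage 3); intermediate activations are cached with
    SaveFeatures hooks and fused back in by DBlock decoder stages, followed
    by a SegmentationHead. When ``use_global_semantic`` is set, a frozen
    (momentum) copy of the encoder embeds a global view and PGFusion merges
    it with the patch features. An optional centerness branch predicts
    auxiliary centerness maps.

    Note: this docstring is an editorial summary inferred from the code
    below, not part of the original source.
    """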
    def __init__(self, input_ch=3, resnet='convnext_tiny', num_classes=3, use_cuda=False,
                 pretrained=True, centerness=False, centerness_map_size=[128, 128],
                 use_global_semantic=False):
        super(PGNet, self).__init__()
        self.resnet = resnet
        base_model = convnext_tiny
        self.use_high_semantic = False
        # keep the first 6 modules of ConvNeXt-Tiny `features`
        # (stem, stage 1, downsample, stage 2, downsample, stage 3)
        cut = 6
        if pretrained:
            layers = list(base_model(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1).features)[:cut]
        else:
            layers = list(base_model().features)[:cut]
        base_layers = nn.Sequential(*layers)
        self.use_global_semantic = use_global_semantic
        # global momentum branch: a frozen copy of the encoder for global views
        if self.use_global_semantic:
            self.pg_fusion = PGFusion()
            self.base_layers_global_momentum = copy.deepcopy(base_layers)
            set_requires_grad(self.base_layers_global_momentum, requires_grad=False)
        # forward hooks that cache intermediate encoder features for the decoder
        self.stage = []
        self.stage.append(SaveFeatures(base_layers[0][1]))  # stem     c=96
        self.stage.append(SaveFeatures(base_layers[1][2]))  # stage 1  c=96
        self.stage.append(SaveFeatures(base_layers[3][2]))  # stage 2  c=192
        self.stage.append(SaveFeatures(base_layers[5][8]))  # stage 3  c=384
        # self.stage.append(SaveFeatures(base_layers[7][2]))  # stage 4  c=768
        # decoder blocks: upsample and fuse with the cached skip features
        self.up2 = DBlock(384, 192)
        self.up3 = DBlock(192, 96)
        self.up4 = DBlock(96, 96)
        # final convolutional layers: predict artery, vein and vessel
        self.seg_head = SegmentationHead(96, num_classes, 3, upsample=4)
        self.sn_unet = base_layers
        self.num_classes = num_classes
        self.bn_out = nn.BatchNorm2d(3)
        # self.av_cross = AV_Cross(block=4, kernel_size=1)
        # optional centerness branch
        self.centerness = centerness
        # the centerness branch is only defined for a 128x128 map size
        if self.centerness and centerness_map_size[0] == 128:
            # block 1: upsample the stride-4 features by 2x
            # (128x128 for a 256x256 input)
            self.cenBlock1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            )
            # bottleneck refinement, added back residually in forward()
            self.cenBlockMid = nn.Sequential(
                nn.Conv2d(96, 48, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(48),
                nn.Conv2d(48, 96, kernel_size=1, padding=0, bias=False),
            )
            self.cenBlockFinal = nn.Sequential(
                nn.BatchNorm2d(96),
                nn.ReLU(inplace=True),
                nn.Conv2d(96, 3, kernel_size=1, padding=0, bias=True),
                nn.Sigmoid(),
            )
    def forward(self, x, y=None):
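        """Run the patch branch (and, optionally, the frozen global branch).

        Args:
            x: patch input of shape (B, input_ch, H, W).
            y: optional global-view input consumed by the momentum encoder
               when ``use_global_semantic`` is True.

        Returns:
            (output, centerness_maps): segmentation logits and, if the
            centerness branch is enabled, the auxiliary centerness maps
            (otherwise None).

        This docstring is an editorial addition summarising the code below.
        """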
        x = self.sn_unet(x)
        global_rep = None
        if self.use_global_semantic:
            global_rep = self.base_layers_global_momentum(y)
            x = self.pg_fusion(x, global_rep)
        # normalise the encoder output to NCHW; ConvNeXt already returns NCHW.
        # The first branch detects channels-last (B, H, W, C) outputs such as
        # Swin's by the dim-2/dim-3 mismatch (heuristic: assumes square maps),
        # the second handles flattened (B, L, C) token sequences.
        if len(x.shape) == 4 and x.shape[2] != x.shape[3]:
            B, H, W, C = x.shape
            x = x.permute(0, 3, 1, 2).contiguous()
        elif len(x.shape) == 3:
            B, L, C = x.shape
            h = int(L ** 0.5)
            x = x.view(B, h, h, C)
            x = x.permute(0, 3, 1, 2).contiguous()
        high_out = x.clone()  # kept for parity with the original; currently unused
        if self.resnet == 'swin_t' or self.resnet == 'convnext_tiny':
            # decode from deep to shallow using the cached skip features
            feature = self.stage[::-1]
            skip = feature[1:]
            x = self.up2(x, skip[0].features)
            x = self.up3(x, skip[1].features)
            x = self.up4(x, skip[2].features)
        x_out = self.seg_head(x)
        ########################
        # baseline output: artery, vein and vessel
        output = x_out.clone()
        # av cross
        # output = self.av_cross(output)
        # output = F.relu(self.bn_out(output))
        # centerness branch
        centerness_maps = None
        if self.centerness:
            block1 = self.cenBlock1(self.stage[1].features)  # stride-4 features, upsampled x2
            _block1 = self.cenBlockMid(block1)               # bottleneck refinement
            block1 = block1 + _block1                        # residual connection
            blocks = torch.cat([block1], dim=1)
            centerness_maps = self.cenBlockFinal(blocks)
        return output, centerness_maps
    def forward_patch_rep(self, x):
        # encoder features for a local patch
        patch_rep = self.sn_unet(x)
        return patch_rep

    def forward_global_rep_momentum(self, x):
        # encoder features for the global view, from the frozen momentum copy
        global_rep = self.base_layers_global_momentum(x)
        return global_rep
    def close(self):
        # detach the SaveFeatures forward hooks
        for sf in self.stage: sf.remove()
# set requires_grad=False to avoid unnecessary gradient computation
def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
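

# Hypothetical helper, not part of the original file: the frozen
# `base_layers_global_momentum` copy is named like a momentum encoder, so a
# training loop *might* refresh it with an exponential moving average of the
# trainable encoder. This is a sketch of that common pattern only; the actual
# update rule (if any) lives outside this module.
@torch.no_grad()
def momentum_update(online_net, momentum_net, m=0.999):
    # momentum_param <- m * momentum_param + (1 - m) * online_param
    for p_o, p_m in zip(online_net.parameters(), momentum_net.parameters()):
        p_m.data.mul_(m).add_(p_o.data, alpha=1 - m)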


# ImageNet normalisation statistics (mean/std), shaped for NCHW broadcasting
pretrained_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False).view((1, 3, 1, 1))
pretrained_std = torch.tensor([0.229, 0.224, 0.225], requires_grad=False).view((1, 3, 1, 1))


if __name__ == '__main__':
    s = PGNet(input_ch=3, resnet='convnext_tiny', centerness=True, pretrained=False,
              use_global_semantic=False)
    x = torch.randn(2, 3, 256, 256)
    out, cen_maps = s(x)
    print(out.shape)
    print(cen_maps.shape)
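
    # Illustrative only (not in the original script): when `pretrained=True`,
    # inputs would normally be normalised with the ImageNet statistics defined
    # above before the forward pass.
    x_norm = (x - pretrained_mean) / pretrained_std
    out_n, cen_n = s(x_norm)
    print(out_n.shape, cen_n.shape)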