| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import os |
| import clip |
| import sys |
| import numpy as np |
| from models.seg_post_model.models import SegModel |
|
|
| from torchvision.ops import roi_align |
|
|
|
|
class Counting_with_SD_features_loca(nn.Module):
    """Counting model pairing an ROI exemplar adapter with an SD-feature regressor.

    NOTE(review): ``scale_factor`` is accepted but never stored or used in
    this visible code — confirm whether any caller relies on it.
    """

    def __init__(self, scale_factor):
        super().__init__()
        # Pools exemplar-box features into prototype embeddings.
        self.adapter = adapter_roi_loca()
        # Decodes fused attention/UNet features into a density map.
        self.regressor = regressor_with_SD_features()
|
|
|
|
class Counting_with_SD_features_dino_vit_c3(nn.Module):
    """Variant that routes fused SD features through a ViT-based segmenter.

    NOTE(review): ``scale_factor`` and ``vit`` are accepted but never stored
    or used in this visible code — confirm against callers.
    """

    def __init__(self, scale_factor, vit=None):
        super().__init__()
        # Pools exemplar-box features into prototype embeddings.
        self.adapter = adapter_roi_loca()
        # Segmentation-style head backed by an external SegModel/ViT.
        self.regressor = regressor_with_SD_features_seg_vit_c3()
|
|
class Counting_with_SD_features_track(nn.Module):
    """Tracking variant: segmentation head plus instance-association embeddings.

    NOTE(review): ``scale_factor`` and ``vit`` are accepted but never stored
    or used in this visible code — confirm against callers.
    """

    def __init__(self, scale_factor, vit=None):
        super().__init__()
        # Pools exemplar-box features into prototype embeddings.
        self.adapter = adapter_roi_loca()
        # Segmentation + cross-frame association head.
        self.regressor = regressor_with_SD_features_tra()
|
|
|
|
class adapter_roi_loca(nn.Module):
    """Pools exemplar-box regions of a feature map into 768-d embeddings.

    The input feature map is resized to 512x512, exemplar boxes are cropped
    with ``roi_align`` (3x3 output, ``spatial_scale=1/8``), optionally
    averaged across each image's exemplars, and projected 256*3*3 -> 768.
    """

    def __init__(self, pool_size=(3, 3)):
        """
        Args:
            pool_size: Stored for interface compatibility only; the ROI
                output size is currently hard-coded to 3 in the forward
                passes. (Tuple default replaces the previous mutable-list
                default ``[3, 3]``; callers may still pass a list.)
        """
        super(adapter_roi_loca, self).__init__()
        self.pool_size = pool_size
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        # NOTE(review): ``self.pool`` is never used by either forward pass;
        # kept so the module's attribute surface is unchanged (MaxPool2d has
        # no parameters, so checkpoints are unaffected either way).
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(256 * 3 * 3, 768)
        self.initialize_weights()

    def forward(self, x, boxes):
        """Average ROI features over each image's exemplar boxes and embed.

        Args:
            x: (bs, 256, h, w) feature map; bilinearly resized to 512x512
                when needed.
            boxes: (bs, num_boxes, 4) exemplar boxes — presumably x1,y1,x2,y2
                in a frame 8x the 512 feature grid given ``spatial_scale=1/8``;
                TODO confirm against the caller.

        Returns:
            (bs, 768) tensor with one pooled embedding per image.
        """
        num_of_boxes = boxes.shape[1]
        bs, _, h, w = x.shape
        if h != 512 or w != 512:
            x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
        if bs == 1:
            # (K, 5) box format: prepend the batch-index column expected by
            # roi_align (cat type-promotes the int64 indices to the box dtype).
            boxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                boxes.flatten(0, 1),
            ], dim=1)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            # Average the exemplar ROIs of the single image.
            rois = torch.mean(rois, dim=0, keepdim=True)
        else:
            # List-of-tensors box format: one (num_of_boxes, 4) tensor per
            # image. (Previous single-element torch.cat wrapper was a no-op.)
            boxes = boxes.flatten(0, 1).split(num_of_boxes, dim=0)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            # Regroup per image, then average each image's exemplars.
            rois = rois.split(num_of_boxes, dim=0)
            rois = torch.stack(rois, dim=0)
            rois = torch.mean(rois, dim=1, keepdim=False)
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def forward_boxes(self, x, boxes):
        """Embed each exemplar box separately (no averaging).

        Only implemented for batch size 1.

        Args:
            x: (1, 256, h, w) feature map; resized to 512x512 when needed.
            boxes: (1, num_boxes, 4) exemplar boxes (same convention as
                ``forward``).

        Returns:
            (num_boxes, 768) tensor, one embedding per box.

        Raises:
            NotImplementedError: if the batch size is not 1.
        """
        num_of_boxes = boxes.shape[1]
        bs, _, h, w = x.shape
        if h != 512 or w != 512:
            x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
        if bs == 1:
            boxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                boxes.flatten(0, 1),
            ], dim=1)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
        else:
            raise NotImplementedError
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def initialize_weights(self):
        """Xavier-normal init for all Conv2d/Linear weights; biases zeroed."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
| class regressor_with_SD_features(nn.Module): |
| def __init__(self): |
| super(regressor_with_SD_features, self).__init__() |
| self.layer1 = nn.Sequential( |
| nn.Conv2d(324, 256, kernel_size=1, stride=1), |
| nn.LeakyReLU(), |
| nn.LayerNorm((64, 64)) |
| ) |
| self.layer2 = nn.Sequential( |
| nn.Conv2d(256, 128, kernel_size=3, padding=1), |
| nn.LeakyReLU(), |
| nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1), |
| ) |
| self.layer3 = nn.Sequential( |
| nn.Conv2d(128, 64, kernel_size=3, padding=1), |
| nn.ReLU(), |
| nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1), |
| ) |
| self.layer4 = nn.Sequential( |
| nn.Conv2d(64, 32, kernel_size=3, padding=1), |
| nn.LeakyReLU(), |
| nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1), |
| ) |
| self.conv = nn.Sequential( |
| nn.Conv2d(32, 1, kernel_size=1), |
| nn.ReLU() |
| ) |
| self.norm = nn.LayerNorm(normalized_shape=(64, 64)) |
| self.initialize_weights() |
|
|
| def forward(self, attn_stack, feature_list): |
| attn_stack = self.norm(attn_stack) |
| unet_feature = feature_list[-1] |
| attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) |
| unet_feature = unet_feature * attn_stack_mean |
| unet_feature = torch.cat([unet_feature, attn_stack], dim=1) |
| x = self.layer1(unet_feature) |
| x = self.layer2(x) |
| x = self.layer3(x) |
| x = self.layer4(x) |
| out = self.conv(x) |
| return out / 100 |
| |
| def initialize_weights(self): |
| for m in self.modules(): |
| if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
| nn.init.xavier_normal_(m.weight) |
| if m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
|
|
| from models.enc_model.unet_parts import * |
|
|
|
|
class regressor_with_SD_features_seg_vit_c3(nn.Module):
    """Fuses SD attention/UNet features and segments them with an external ViT.

    The 64x64 fused feature is projected to 3 channels, upsampled to 512x512
    and handed (together with the raw image) to ``SegModel.eval``.
    """

    def __init__(self, n_channels=3, n_classes=2, bilinear=False):
        """
        Args:
            n_channels: Input channels of the fusion conv. The default 3
                matches forward(): 1 UNet-mean channel + 2 selected attention
                channels.
            n_classes: Stored only; presumably the segmentation class count —
                TODO confirm against SegModel.
            bilinear: Stored only; unused in this visible code.
        """
        super(regressor_with_SD_features_seg_vit_c3, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
        # External segmentation wrapper; its underlying network is exposed as
        # ``self.vit`` so outer code can reach it directly.
        self.vit_model = SegModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
        self.vit = self.vit_model.net
        # NOTE(review): initialize_weights() is defined below but never called
        # here (unlike sibling classes) — possibly deliberate so SegModel
        # weights are not clobbered; confirm.

    def forward(self, img, attn_stack, feature_list):
        """Segment ``img`` guided by fused SD features.

        Args:
            img: input image tensor; squeezed and converted to numpy for
                ``SegModel.eval``.
            attn_stack: attention maps; only channels 1 and 3 are kept.
            feature_list: sequence of UNet features; only the last is used.

        Returns:
            Label tensor of shape (1, *SegModel output*) on ``img``'s device.
        """
        # Keep only attention channels 1 and 3.
        attn_stack = attn_stack[:, [1,3], ...]
        attn_stack = self.norm(attn_stack)
        unet_feature = feature_list[-1]
        unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
        # 1 UNet-mean + 2 attention channels = 3, matching ``inc_0``.
        x = torch.cat([unet_feature_mean, attn_stack], dim=1)
        if x.shape[-1] != 512:
            x = F.interpolate(x, size=(512, 512), mode="bilinear")
        x = self.inc_0(x)
        # SegModel.eval consumes numpy arrays, so no gradients flow through
        # this call.
        out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
        if out.dtype == np.uint16:
            # Widen to int32 (torch has no uint16 dtype): the previous int16
            # cast silently overflowed for label ids >= 32768.
            out = out.astype(np.int32)
        out = torch.from_numpy(out).unsqueeze(0).to(x.device)
        return out

    def initialize_weights(self):
        """Xavier-normal init for Conv2d/Linear weights; biases zeroed.

        Not invoked by __init__ — see the note there.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
|
class regressor_with_SD_features_tra(nn.Module):
    """Segmentation plus cross-frame instance association over SD features.

    ``forward_seg`` fuses attention/UNet features and delegates pixel
    labeling to an external ``SegModel``; ``forward`` turns per-instance
    attention maps from two frames into 320-d association embeddings.
    """

    def __init__(self, n_channels=2, n_classes=2, bilinear=False):
        """
        Args:
            n_channels: Input channels of ``inc_1`` — two concatenated
                single-channel instance attention maps by default.
            n_classes: Stored only; presumably the segmentation class count —
                TODO confirm against SegModel.
            bilinear: Stored only; unused in this visible code.
        """
        super(regressor_with_SD_features_tra, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        # Fusion conv for forward_seg (1 UNet-mean + 2 attention channels).
        self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        # External segmentation wrapper; underlying net exposed as ``self.vit``.
        self.vit_model = SegModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
        self.vit = self.vit_model.net
        # Association head: attention pair -> 1 map -> 320-d vector.
        self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
        self.mlp = nn.Linear(64 * 64, 320)
        # NOTE(review): initialize_weights() is defined below but never called
        # here — possibly deliberate so SegModel weights are not clobbered;
        # confirm.

    def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
        """Segment ``img`` using fused SD features.

        ``mask`` and ``training`` are accepted but unused in this visible code.

        Returns:
            (labels, 0., feat): label tensor on ``img``'s device, a constant
            0. placeholder (presumably a loss slot — confirm with callers),
            and the 3-channel fused feature handed to the segmenter.
        """
        # Keep only attention channels 1 and 3.
        attn_stack = attn_stack[:, [1,3], ...]
        attn_stack = self.norm(attn_stack)
        unet_feature = feature_list[-1]
        unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
        x = torch.cat([unet_feature_mean, attn_stack], dim=1)
        if x.shape[-1] != 512:
            x = F.interpolate(x, size=(512, 512), mode="bilinear")
        x = self.inc_0(x)
        feat = x
        # SegModel.eval consumes numpy arrays, so no gradients flow here.
        out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
        if out.dtype == np.uint16:
            # Widen to int32 (torch has no uint16 dtype): the previous int16
            # cast silently overflowed for label ids >= 32768.
            out = out.astype(np.int32)
        out = torch.from_numpy(out).unsqueeze(0).to(x.device)
        return out, 0., feat

    def forward(self, attn_prev, feature_list_prev, attn_after, feature_list_after):
        """Embed each instance's (prev, after) attention pair into 320-d.

        ``feature_list_prev`` / ``feature_list_after`` are accepted but unused
        in this visible code. Expects attention maps of shape
        (n_instances, 1, 64, 64) so the channel concat matches ``inc_1``'s
        2 input channels — TODO confirm against callers.

        Returns:
            (1, n_instances, 320) association embeddings.
        """
        assert attn_prev.shape == attn_after.shape, "attn_prev and attn_after must have the same shape"
        n_instances = attn_prev.shape[0]
        attn_prev = self.norm(attn_prev)
        attn_after = self.norm(attn_after)
        # Pair the two frames channel-wise, collapse to one map per instance,
        # then flatten each 64x64 map into the MLP.
        x = torch.cat([attn_prev, attn_after], dim=1)
        x = self.inc_1(x)
        x = x.view(1, n_instances, -1)
        x = self.mlp(x)
        return x

    def initialize_weights(self):
        """Xavier-normal init for Conv2d/Linear weights; biases zeroed.

        Not invoked by __init__ — see the note there.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
|