import math
import types
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import clip

from .lseg_blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


class depthwise_clipseg_conv(nn.Module):
    def __init__(self):
        super(depthwise_clipseg_conv, self).__init__()
        self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def depthwise_clipseg(self, x, channels):
        # Apply the shared single-channel conv to every channel independently.
        x = torch.cat(
            [self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], dim=1
        )
        return x

    def forward(self, x):
        channels = x.shape[1]
        out = self.depthwise_clipseg(x, channels)
        return out


class depthwise_conv(nn.Module):
    def __init__(self, kernel_size=3, stride=1, padding=1):
        super(depthwise_conv, self).__init__()
        self.depthwise = nn.Conv2d(
            1, 1, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        # Support for a 4D NCHW tensor: fold channels into the batch dimension,
        # convolve each channel with the shared kernel, then unfold.
        C, H, W = x.shape[1:]
        x = x.reshape(-1, 1, H, W)
        x = self.depthwise(x)
        x = x.view(-1, C, H, W)
        return x


class depthwise_block(nn.Module):
    def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
        super(depthwise_block, self).__init__()
        self.depthwise = depthwise_conv(
            kernel_size=kernel_size, stride=stride, padding=padding
        )
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()

    def forward(self, x, act=True):
        x = self.depthwise(x)
        if act:
            x = self.activation(x)
        return x


class bottleneck_block(nn.Module):
    def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
        super(bottleneck_block, self).__init__()
        self.depthwise = depthwise_conv(
            kernel_size=kernel_size, stride=stride, padding=padding
        )
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()

    def forward(self, x, act=True):
        # Residual connection against the channel-wise max of the input.
        sum_layer = x.max(dim=1, keepdim=True)[0]
        x = self.depthwise(x)
        x = x + sum_layer
        if act:
            x = self.activation(x)
        return x


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device("cpu"))

        # Training checkpoints wrap the weights; unwrap to the bare state dict.
        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        activation=nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class LSeg(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="clip_vitl16_384",
        readout="project",
        channels_last=False,
        use_bn=False,
        **kwargs,
    ):
        super(LSeg, self).__init__()

        self.channels_last = channels_last

        # Transformer layers at which to hook intermediate activations.
        hooks = {
            "clip_vitl16_384": [5, 11, 17, 23],
            "clipRN50x16_vitl16_384": [5, 11, 17, 23],
            "clip_vitb32_384": [2, 5, 8, 11],
        }

        # Instantiate backbone and reassemble blocks
        self.clip_pretrained, self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        # Fixed (non-learnable) temperature; .exp() turns the Parameter into a
        # plain tensor, initialized as in CLIP: exp(log(1 / 0.07)).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp()
        if backbone in ["clipRN50x16_vitl16_384"]:
            self.out_c = 768
        else:
            self.out_c = 512
        self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1)

        self.arch_option = kwargs["arch_option"]
        if self.arch_option == 1:
            self.scratch.head_block = bottleneck_block(activation=kwargs["activation"])
            self.block_depth = kwargs["block_depth"]
        elif self.arch_option == 2:
            self.scratch.head_block = depthwise_block(activation=kwargs["activation"])
            self.block_depth = kwargs["block_depth"]

        self.scratch.output_conv = head

        # self.labels is set by the subclass (LSegNet) before this runs.
        self.text = clip.tokenize(self.labels)

    def forward(self, x, labelset=""):
        if labelset == "":
            text = self.text
        else:
            text = clip.tokenize(labelset)

        if self.channels_last:
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        text = text.to(x.device)
        self.logit_scale = self.logit_scale.to(x.device)
        text_features = self.clip_pretrained.encode_text(text)

        # Project dense image features into the CLIP embedding space.
        image_features = self.scratch.head1(path_1)

        imshape = image_features.shape
        image_features = image_features.permute(0, 2, 3, 1).reshape(-1, self.out_c)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarity between every pixel embedding and every
        # label embedding.
        logits_per_image = self.logit_scale * image_features.half() @ text_features.t()

        out = (
            logits_per_image.float()
            .view(imshape[0], imshape[2], imshape[3], -1)
            .permute(0, 3, 1, 2)
        )

        if self.arch_option in [1, 2]:
            for _ in range(self.block_depth - 1):
                out = self.scratch.head_block(out)
            out = self.scratch.head_block(out, False)

        out = self.scratch.output_conv(out)

        return out


class LSegNet(LSeg):
    """Network for semantic segmentation."""

    def __init__(self, labels, path=None, scale_factor=0.5, crop_size=480, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256
        kwargs["use_bn"] = True

        self.crop_size = crop_size
        self.scale_factor = scale_factor
        self.labels = labels

        head = nn.Sequential(
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)
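

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# It assumes the module is imported as part of its package so the relative
# import above resolves, and that a CUDA device is available, since pixel and
# text features are matched in fp16. The label set and input size below are
# hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    labels = ["other", "cat", "grass"]  # hypothetical label set
    net = (
        LSegNet(
            labels=labels,
            backbone="clip_vitl16_384",
            features=256,
            arch_option=0,  # 0 skips the optional bottleneck/depthwise head blocks
            block_depth=0,
            activation="relu",
        )
        .cuda()
        .eval()
    )
    with torch.no_grad():
        x = torch.randn(1, 3, 480, 480, device="cuda")  # dummy NCHW batch
        out = net(x)  # (1, len(labels), H, W) per-pixel label logits
    print(out.shape)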