print("Importing external...") import torch from torch import nn import torch.nn.functional as F from timm.models.efficientvit_mit import ( ConvNormAct, FusedMBConv, MBConv, ResidualBlock, efficientvit_l1, ) from timm.layers import GELUTanh def val2list(x: list or tuple or any, repeat_time=1): if isinstance(x, (list, tuple)): return list(x) return [x for _ in range(repeat_time)] def resize( x: torch.Tensor, size: any or None = None, scale_factor: list[float] or None = None, mode: str = "bicubic", align_corners: bool or None = False, ) -> torch.Tensor: if mode in {"bilinear", "bicubic"}: return F.interpolate( x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, ) elif mode in {"nearest", "area"}: return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) else: raise NotImplementedError(f"resize(mode={mode}) not implemented.") class UpSampleLayer(nn.Module): def __init__( self, mode="bicubic", size: int or tuple[int, int] or list[int] or None = None, factor=2, align_corners=False, ): super(UpSampleLayer, self).__init__() self.mode = mode self.size = val2list(size, 2) if size is not None else None self.factor = None if self.size is not None else factor self.align_corners = align_corners def forward(self, x: torch.Tensor) -> torch.Tensor: if ( self.size is not None and tuple(x.shape[-2:]) == self.size ) or self.factor == 1: return x return resize(x, self.size, self.factor, self.mode, self.align_corners) class DAGBlock(nn.Module): def __init__( self, inputs: dict[str, nn.Module], merge: str, post_input: nn.Module or None, middle: nn.Module, outputs: dict[str, nn.Module], ): super(DAGBlock, self).__init__() self.input_keys = list(inputs.keys()) self.input_ops = nn.ModuleList(list(inputs.values())) self.merge = merge self.post_input = post_input self.middle = middle self.output_keys = list(outputs.keys()) self.output_ops = nn.ModuleList(list(outputs.values())) def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: feat = [ op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops) ] if self.merge == "add": feat = list_sum(feat) elif self.merge == "cat": feat = torch.concat(feat, dim=1) else: raise NotImplementedError if self.post_input is not None: feat = self.post_input(feat) feat = self.middle(feat) for key, op in zip(self.output_keys, self.output_ops): feature_dict[key] = op(feat) return feature_dict def list_sum(x: list) -> any: return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) class SegHead(nn.Module): def __init__( self, fid_list: list[str], in_channel_list: list[int], stride_list: list[int], head_stride: int, head_width: int, head_depth: int, expand_ratio: float, middle_op: str, final_expand: float or None, n_classes: int, dropout=0, norm="bn2d", act_func="hswish", ): super(SegHead, self).__init__() # exceptions to adapt effvit to timm if act_func == "gelu": act_func = GELUTanh else: raise ValueError(f"act_func {act_func} not supported") if norm == "bn2d": norm_layer = nn.BatchNorm2d else: raise ValueError(f"norm {norm} not supported") inputs = {} for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list): factor = stride // head_stride if factor == 1: inputs[fid] = ConvNormAct( in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_func ) else: inputs[fid] = nn.Sequential( ConvNormAct( in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_func, ), UpSampleLayer(factor=factor), ) self.in_keys = inputs.keys() self.in_ops = nn.ModuleList(inputs.values()) middle = [] 
        for _ in range(head_depth):
            if middle_op == "mbconv":
                block = MBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_func, act_func, None),
                )
            elif middle_op == "fmbconv":
                block = FusedMBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_func, None),
                )
            else:
                raise NotImplementedError(f"middle_op {middle_op} not supported")
            middle.append(ResidualBlock(block, nn.Identity()))
        self.middle = nn.Sequential(*middle)

        # optional channel expansion followed by the final 1x1 classifier conv
        out_layers = []
        if final_expand is not None:
            out_layers.append(
                ConvNormAct(
                    head_width,
                    head_width * final_expand,
                    1,
                    norm_layer=norm_layer,
                    act_layer=act_func,
                )
            )
        out_layers.append(
            ConvNormAct(
                head_width * (final_expand or 1),
                n_classes,
                1,
                bias=True,
                dropout=dropout,
                norm_layer=None,
                act_layer=None,
            )
        )
        self.out_layer = nn.Sequential(*out_layers)

    def forward(self, feature_map_list: list[torch.Tensor]) -> torch.Tensor:
        t_feat_maps = [
            self.in_ops[ind](feature_map_list[ind])
            for ind in range(len(feature_map_list))
        ]
        t_feat_map = list_sum(t_feat_maps)
        t_feat_map = self.middle(t_feat_map)
        out = self.out_layer(t_feat_map)
        return out


class EfficientViT_l1_r224(nn.Module):
    """EfficientViT-L1 backbone (via timm) with an EfficientViT-style segmentation head."""

    def __init__(
        self,
        out_channels: int,
        out_ds_factor: int = 1,
        decoder_size: str = "small",
        pretrained: bool = False,
        use_norm_params: bool = False,
    ):
        super().__init__()

        if decoder_size == "small":
            head_width = 32
            head_depth = 1
            middle_op = "mbconv"
        elif decoder_size == "medium":
            head_width = 64
            head_depth = 3
            middle_op = "mbconv"
        elif decoder_size == "large":
            head_width = 256
            head_depth = 3
            middle_op = "fmbconv"
        else:
            raise ValueError(f"decoder_size {decoder_size} not supported")

        self.bbone = efficientvit_l1(
            num_classes=0, features_only=True, pretrained=pretrained
        )
        self.head = SegHead(
            fid_list=["stage4", "stage3", "stage2"],
            in_channel_list=[512, 256, 128],
            stride_list=[32, 16, 8],
            head_stride=out_ds_factor,
            head_width=head_width,
            head_depth=head_depth,
            expand_ratio=4,
            middle_op=middle_op,
            final_expand=8,
            n_classes=out_channels,
            act_func="gelu",
        )

        # [optional] freeze the affine parameters of all normalization layers
        if not use_norm_params:
            for module in self.modules():
                if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):
                    module.weight.requires_grad_(False)
                    module.bias.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # timm features_only returns features in increasing stride order;
        # the head expects [stage4, stage3, stage2]
        feat = self.bbone(x)
        out = self.head([feat[3], feat[2], feat[1]])
        return out
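

# Minimal usage sketch (illustrative, not part of the original module): instantiate the
# model with a hypothetical 19-class output and run a forward pass on a random 224x224
# batch. With out_ds_factor=1 the head upsamples the stride-8/16/32 features back to the
# input resolution, so the logits should come out at the input size.
if __name__ == "__main__":
    model = EfficientViT_l1_r224(out_channels=19, out_ds_factor=1, decoder_size="small")
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224)  # random stand-in for a 224x224 RGB batch
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 19, 224, 224])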