from __future__ import annotations from typing import Union from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation.utils import GenerateOutput from .configuration_m3d_lamed import LamedConfig from abc import ABC, abstractmethod from torch import Tensor import math from typing import Any, Dict, List import torch import torch.nn as nn from typing import Optional, Tuple, Type from monai.networks.blocks import PatchEmbed import numpy as np import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange from collections.abc import Sequence from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock from monai.networks.nets import ViT class BinaryDiceLoss(nn.Module): def __init__(self, smooth=1, p=2, reduction='mean'): super(BinaryDiceLoss, self).__init__() self.smooth = smooth self.p = p self.reduction = reduction def forward(self, predict, target): predict = torch.sigmoid(predict) target_ = target.clone().float() target_[target == -1] = 0 assert predict.shape[0] == target.shape[0], "predict & target batch size don't match\n" + str(predict.shape) + '\n' + str(target.shape[0]) predict = predict.contiguous().view(predict.shape[0], -1) target_ = target_.contiguous().view(target_.shape[0], -1) num = torch.sum(torch.mul(predict, target_), dim=1) den = torch.sum(predict, dim=1) + torch.sum(target_, dim=1) + self.smooth dice_score = 2*num / den dice_loss = 1 - dice_score # dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0] dice_loss_avg = dice_loss.sum() / dice_loss.shape[0] return dice_loss_avg class BCELoss(nn.Module): def __init__(self): super(BCELoss, self).__init__() self.criterion = nn.BCEWithLogitsLoss() def forward(self, predict, target): assert predict.shape == target.shape, 'predict & target shape do not match\n' + str(predict.shape) + '\n' + str(target.shape) target_ = target.clone() target_[target == -1] = 0 ce_loss = self.criterion(predict, target_.float()) return ce_loss class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) class TwoWayTransformer(nn.Module): def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, ) -> None: """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer embedding_dim (int): the channel dimension for the input embeddings num_heads (int): the number of heads for multihead attention. Must divide embedding_dim mlp_dim (int): the channel dimension internal to the MLP block activation (nn.Module): the activation to use in the MLP block """ super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() for i in range(depth): self.layers.append( TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), ) ) self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, ) -> Tuple[Tensor, Tensor]: """ Args: image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding. point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points. Returns: torch.Tensor: the processed point_embedding torch.Tensor: the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C bs, c, h, w, d = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) # Prepare queries queries = point_embedding keys = image_embedding # Apply transformer blocks and final layernorm for layer in self.layers: queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, ) # Apply the final attention layer from the points to the image q = queries + point_embedding k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm_final_attn(queries) return queries, keys class TwoWayAttentionBlock(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. Arguments: embedding_dim (int): the channel dimension of the embeddings num_heads (int): the number of heads in the attention layers mlp_dim (int): the hidden dimension of the mlp block activation (nn.Module): the activation of the mlp block skip_first_layer_pe (bool): skip the PE on the first layer """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.skip_first_layer_pe = skip_first_layer_pe def forward( self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple[Tensor, Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) # Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) # MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) # Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Attention _, _, _, c_per_head = q.shape attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / math.sqrt(c_per_head) attn = torch.softmax(attn, dim=-1) # Get output out = attn @ v out = self._recombine_heads(out) out = self.out_proj(out) return out # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 1, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. """ super().__init__() self.img_size = img_size # self.patch_embed = PatchEmbed( # kernel_size=(patch_size, patch_size), # stride=(patch_size, patch_size), # in_chans=in_chans, # embed_dim=embed_dim, # ) self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, spatial_dims=3, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, img_size // patch_size, embed_dim) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) print('patch embedded shape: ', x.shape) # embedded: [8, 768, 6, 6, 6] if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = self.neck(x.permute(0, 3, 1, 2)) return x class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention2( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class Attention2(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) return x def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn class PromptEncoder(nn.Module): def __init__( self, embed_dim: int, image_embedding_size: Tuple[int, int, int], input_image_size: Tuple[int, int, int], mask_in_chans: int, activation: Type[nn.Module] = nn.GELU, ) -> None: """ Encodes prompts for input to SAM's mask decoder. Arguments: embed_dim (int): The prompts' embedding dimension image_embedding_size (tuple(int, int)): The spatial size of the image embedding, as (H, W). input_image_size (int): The padded size of the image as input to the image encoder, as (H, W). mask_in_chans (int): The number of hidden channels used for encoding input masks. activation (nn.Module): The activation to use when encoding input masks. """ super().__init__() self.embed_dim = embed_dim self.input_image_size = input_image_size self.image_embedding_size = image_embedding_size self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1, embed_dim) self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2]) self.mask_downscaling = nn.Sequential( nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans // 4), activation(), nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans), activation(), nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), ) self.no_mask_embed = nn.Embedding(1, embed_dim) def get_dense_pe(self) -> torch.Tensor: """ Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the image encoding. Returns: torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w) """ return self.pe_layer(self.image_embedding_size).unsqueeze(0) def _embed_points( self, points: torch.Tensor, labels: torch.Tensor, pad: bool, ) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel if pad: padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1) point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] += self.not_a_point_embed.weight point_embedding[labels == 0] += self.point_embeddings[0].weight point_embedding[labels == 1] += self.point_embeddings[1].weight return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: """Embeds box prompts.""" boxes = boxes + 0.5 # Shift to center of pixel coords = boxes.reshape(-1, 2, 3) corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) corner_embedding[:, 0, :] += self.point_embeddings[2].weight corner_embedding[:, 1, :] += self.point_embeddings[3].weight return corner_embedding def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: """Embeds mask inputs.""" mask_embedding = self.mask_downscaling(masks) return mask_embedding def _get_batch_size( self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], text_embedding: Optional[torch.Tensor], ) -> int: """ Gets the batch size of the output given the batch size of the input prompts. """ if points is not None: return points[0].shape[0] elif boxes is not None: return boxes.shape[0] elif masks is not None: return masks.shape[0] elif text_embedding is not None: return text_embedding.shape[0] else: return 1 def _get_device(self) -> torch.device: return self.point_embeddings[0].weight.device def forward( self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], text_embedding: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Embeds different types of prompts, returning both sparse and dense embeddings. Arguments: points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates and labels to embed. boxes (torch.Tensor or none): boxes to embed masks (torch.Tensor or none): masks to embed text: test prompt (B, 768) Returns: torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined by the number of input points and boxes. torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ # print('prompt encoder here...') bs = self._get_batch_size(points, boxes, masks, text_embedding) sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device(), dtype=self.point_embeddings[0].weight.dtype) # print('sparse_embeddings ', sparse_embeddings.shape) if points is not None: coords, labels = points point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) if boxes is not None: box_embeddings = self._embed_boxes(boxes) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) if text_embedding is not None: sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1) # print('box_embeddings ', box_embeddings.shape) # print('sparse_embeddings after box/point/text', sparse_embeddings.shape) if masks is not None: dense_embeddings = self._embed_masks(masks) else: dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand( bs, -1, int(self.image_embedding_size[0]), int(self.image_embedding_size[1]), int(self.image_embedding_size[2]) ) return sparse_embeddings, dense_embeddings class PositionEmbeddingRandom(nn.Module): """ Positional encoding using random spatial frequencies. """ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: super().__init__() if scale is None or scale <= 0.0: scale = 1.0 self.register_buffer( "positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats)), ) def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 coords = coords @ self.positional_encoding_gaussian_matrix coords = 2 * np.pi * coords # outputs d_1 x ... x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: """Generate positional encoding for a grid of the specified size.""" h, w, d = size device: Any = self.positional_encoding_gaussian_matrix.device dtype = self.positional_encoding_gaussian_matrix.dtype grid = torch.ones((h, w, d), device=device, dtype=dtype) y_embed = grid.cumsum(dim=0) - 0.5 x_embed = grid.cumsum(dim=1) - 0.5 z_embed = grid.cumsum(dim=2) - 0.5 y_embed = y_embed / h x_embed = x_embed / w z_embed = z_embed / d pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) return pe.permute(3, 0, 1, 2) # C x H x W x D def forward_with_coords( self, coords_input: torch.Tensor, image_size: Tuple[int, int] ) -> torch.Tensor: """Positionally encode points that are not normalized to [0,1].""" coords = coords_input.clone() coords[:, :, 0] = coords[:, :, 0] / image_size[1] coords[:, :, 1] = coords[:, :, 1] / image_size[0] coords[:, :, 2] = coords[:, :, 2] / image_size[2] return self._pe_encoding(coords.to(torch.float)) # B x N x C class MaskDecoder(nn.Module): def __init__( self, *, image_encoder_type: str, transformer_dim: int, transformer: nn.Module, num_multimask_outputs: int = 3, activation: Type[nn.Module] = nn.GELU, iou_head_depth: int = 3, iou_head_hidden_dim: int = 256, image_size, patch_size, ) -> None: """ Predicts masks given an image and prompt embeddings, using a transformer architecture. Arguments: transformer_dim (int): the channel dimension of the transformer transformer (nn.Module): the transformer used to predict masks num_multimask_outputs (int): the number of masks to predict when disambiguating masks activation (nn.Module): the type of activation to use when upscaling masks iou_head_depth (int): the depth of the MLP used to predict mask quality iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality """ super().__init__() self.transformer_dim = transformer_dim self.transformer = transformer self.num_multimask_outputs = num_multimask_outputs self.iou_token = nn.Embedding(1, transformer_dim) self.num_mask_tokens = num_multimask_outputs + 1 self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) if image_encoder_type == 'swin_vit': self.feat_shape = image_size / patch_size self.output_upscaling = nn.Sequential( nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), nn.LayerNorm( (transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), # swin activation(), nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), # swin # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1), # vit activation(), ) else: self.feat_shape = image_size / patch_size * 2 self.output_upscaling = nn.Sequential( nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), nn.LayerNorm( (transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), # vit activation(), nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1), activation(), ) self.output_hypernetworks_mlps = nn.ModuleList( [ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens) ] ) self.iou_prediction_head = MLP( transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth ) self.txt_align_upscaled_embedding = nn.Linear(768, 96) def forward( self, image_embeddings: torch.Tensor, text_embedding: Optional[torch.Tensor], image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, multimask_output: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. Arguments: image_embeddings (torch.Tensor): the embeddings from the image encoder image_pe (torch.Tensor): positional encoding with the shape of image_embeddings sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs multimask_output (bool): Whether to return multiple masks or a single mask. Returns: torch.Tensor: batched predicted masks torch.Tensor: batched predictions of mask quality """ # print('--------------decoder here--------------') masks, iou_pred = self.predict_masks( image_embeddings=image_embeddings, text_embedding=text_embedding, image_pe=image_pe, sparse_prompt_embeddings=sparse_prompt_embeddings, dense_prompt_embeddings=dense_prompt_embeddings, ) # Select the correct mask or masks for output if multimask_output: mask_slice = slice(1, None) else: mask_slice = slice(0, 1) masks = masks[:, mask_slice, :, :, :] iou_pred = iou_pred[:, mask_slice] # Prepare output return masks, iou_pred def predict_masks( self, image_embeddings: torch.Tensor, text_embedding: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predicts masks. See 'forward' for more details.""" # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # [2, 7=(5+2), 256] # Expand per-image data in batch direction to be per-mask if image_embeddings.shape[0] != tokens.shape[0]: src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) else: src = image_embeddings src = src + dense_prompt_embeddings pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w, d = src.shape # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :] mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w, d) # print('src ', src.shape) # vit:[B, 768, 12, 12, 6], swin: [B, 6, 6, 3] upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [] for i in range(self.num_mask_tokens): hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w, d = upscaled_embedding.shape # print('hyper_in ', hyper_in.shape) # [2, 4, 96] # print('upscaled_embedding ', upscaled_embedding.shape) # [2, 96, 24, 24, 12]* masks = (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d) # print('masks here ', masks.shape) # [2, 4, 24, 24, 12] if text_embedding is not None: # text_embedding: B x 768, upscaled_embedding: B x c x h x w x d => B x 1 x h x w x d text_embedding_down = self.txt_align_upscaled_embedding(text_embedding).unsqueeze(dim=1) upscaled_embedding = upscaled_embedding.view(b, c, h * w * d) # print('text_embedding_down ', text_embedding_down.shape) # [2, 1, 96] # text_embedding_norm = F.normalize(text_embedding_down, dim=-1) # upscaled_embedding_norm = F.normalize(upscaled_embedding, dim=1) # sim = (text_embedding_norm @ upscaled_embedding_norm).view(b, -1, h, w, d) # print(text_embedding_down.shape, upscaled_embedding.shape) sim = (text_embedding_down @ upscaled_embedding).view(b, -1, h, w, d) # print('sim ', sim.shape) # [B, 1, 24, 24, 12] sim = sim.repeat(1, masks.shape[1], 1, 1, 1) # print('sim after', sim.shape) # [B, 4, 24, 24, 12] masks = masks + sim # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) return masks, iou_pred # Lightly adapted from # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa class MLP(nn.Module): def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, ) -> None: super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList( nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) ) self.sigmoid_output = sigmoid_output def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) if self.sigmoid_output: x = F.sigmoid(x) return x class Sam(nn.Module): mask_threshold: float = 0.0 image_format: str = "RGB" def __init__( self, image_encoder: ImageEncoderViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: List[float] = [123.675, 116.28, 103.53], pixel_std: List[float] = [58.395, 57.12, 57.375], ) -> None: """ SAM predicts object masks from an image and input prompts. Arguments: image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for efficient mask prediction. prompt_encoder (PromptEncoder): Encodes various types of input prompts. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. pixel_mean (list(float)): Mean values for normalizing pixels in the input image. pixel_std (list(float)): Std values for normalizing pixels in the input image. """ super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) @property def device(self) -> Any: return self.pixel_mean.device @torch.no_grad() def forward( self, batched_input: List[Dict[str, Any]], multimask_output: bool, ) -> List[Dict[str, torch.Tensor]]: """ Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SamPredictor is recommended over calling the model directly. Arguments: batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt key can be excluded if it is not present. 'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model. 'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W). 'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already transformed to the input frame of the model. 'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN. 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of the model. 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW. multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single mask. Returns: (list(dict)): A list over input images, where each element is as dictionary with the following keys. 'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original size of the image. 'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC. 'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed as mask input to subsequent iterations of prediction. """ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) image_embeddings = self.image_encoder(input_images) outputs = [] for image_record, curr_embedding in zip(batched_input, image_embeddings): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) else: points = None sparse_embeddings, dense_embeddings = self.prompt_encoder( points=points, boxes=image_record.get("boxes", None), masks=image_record.get("mask_inputs", None), ) low_res_masks, iou_predictions = self.mask_decoder( image_embeddings=curr_embedding.unsqueeze(0), image_pe=self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, ) masks = self.postprocess_masks( low_res_masks, input_size=image_record["image"].shape[-2:], original_size=image_record["original_size"], ) masks = masks > self.mask_threshold outputs.append( { "masks": masks, "iou_predictions": iou_predictions, "low_res_logits": low_res_masks, } ) return outputs def postprocess_masks( self, masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], ) -> torch.Tensor: """ Remove padding and upscale masks to the original image size. Arguments: masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format. input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format. Used to remove padding. original_size (tuple(int, int)): The original size of the image before resizing for input to the model, in (H, W) format. Returns: (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size. """ masks = F.interpolate( masks, (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", align_corners=False, ) masks = masks[..., : input_size[0], : input_size[1]] masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) return masks def preprocess(self, x: torch.Tensor) -> torch.Tensor: """Normalize pixel values and pad to a square input.""" # Normalize colors # TODO x = (x - self.pixel_mean) / self.pixel_std # Pad h, w = x.shape[-2:] padh = self.image_encoder.img_size - h padw = self.image_encoder.img_size - w x = F.pad(x, (0, padw, 0, padh)) return x """ Examples:: # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) """ def build_sam_vit_3d(args, checkpoint=None): print('build_sam_vit_3d...') return _build_sam( image_encoder_type='vit', embed_dim=768, patch_size=args.patch_size, checkpoint=checkpoint, image_size=args.image_size, ) sam_model_registry = { "vit": build_sam_vit_3d, } def _build_sam( image_encoder_type, embed_dim, patch_size, checkpoint, image_size, ): mlp_dim = 3072 num_layers = 12 num_heads = 12 pos_embed = 'perceptron' dropout_rate = 0.0 image_encoder = ViT( in_channels=1, img_size=image_size, patch_size=patch_size, hidden_size=embed_dim, mlp_dim=mlp_dim, num_layers=num_layers, num_heads=num_heads, pos_embed=pos_embed, classification=False, dropout_rate=dropout_rate, ) image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))] if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f, map_location='cpu')['state_dict'] encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k} image_encoder.load_state_dict(encoder_dict) print(f'===> image_encoder.load_param: {checkpoint}') sam = Sam( image_encoder=image_encoder, prompt_encoder=PromptEncoder( embed_dim=embed_dim, image_embedding_size=image_embedding_size, input_image_size=image_size, mask_in_chans=16, ), mask_decoder=MaskDecoder( image_encoder_type=image_encoder_type, num_multimask_outputs=3, transformer=TwoWayTransformer( depth=2, embedding_dim=embed_dim, mlp_dim=2048, num_heads=8, ), transformer_dim=embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, image_size=np.array(image_size), patch_size=np.array(patch_size), ), pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], ) sam.eval() return sam class SegVol(nn.Module): def __init__(self, image_encoder, mask_decoder, prompt_encoder, roi_size, patch_size, ): super().__init__() self.image_encoder = image_encoder self.mask_decoder = mask_decoder self.prompt_encoder = prompt_encoder self.feat_shape = np.array(roi_size)/np.array(patch_size) def forward(self, image, text_emb=None, text=None, boxes=None, points=None): bs = image.shape[0] img_shape = (image.shape[2], image.shape[3], image.shape[4]) image_embedding, _ = self.image_encoder(image) image_embedding = image_embedding.transpose(1, 2).view(bs, -1, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2])) logits = self.forward_decoder(image_embedding, img_shape, text_emb=text_emb, text=text, boxes=boxes, points=points) return logits def forward_decoder(self, image_embedding, img_shape, text_emb=None, text=None, boxes=None, points=None): text_embedding = text_emb sparse_embeddings, dense_embeddings = self.prompt_encoder( points=None, boxes=None, masks=None, text_embedding=text_embedding, ) dense_pe = self.prompt_encoder.get_dense_pe() low_res_masks, _ = self.mask_decoder( image_embeddings=image_embedding, text_embedding = text_embedding, image_pe=dense_pe, sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=False, ) logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False) return logits def build_segmentation_module(config, **kwargs): segmentation_module = getattr(config, 'segmentation_module') if 'segvol' in segmentation_module.lower(): sam_model = sam_model_registry['vit'](args=config, checkpoint=None) seg_model = SegVol( image_encoder=sam_model.image_encoder, mask_decoder=sam_model.mask_decoder, prompt_encoder=sam_model.prompt_encoder, roi_size=config.image_size, patch_size=config.patch_size, ) return seg_model else: raise ValueError(f'Unknown segmentation module: {segmentation_module}') class IdentityMap(nn.Module): def __init__(self): super().__init__() def forward(self, x, *args, **kwargs): return x @property def config(self): return {"mm_projector_type": 'identity'} class SpatialPoolingProjector(nn.Module): def __init__(self, image_size, patch_size, in_dim, out_dim, layer_type, layer_num, pooling_type='spatial', pooling_size=2): super().__init__() self.in_dim = in_dim self.pooling_size = pooling_size self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)] self.num_patches_post = [num // pooling_size for num in self.num_patches_pre] if layer_type == 'linear': depth = int(layer_num) modules = [nn.Linear(in_dim, out_dim)] for _ in range(1, depth): modules.append(nn.Linear(out_dim, out_dim)) self.projector = nn.Sequential(*modules) elif layer_type == 'mlp': depth = int(layer_num) modules = [nn.Linear(in_dim, out_dim)] for _ in range(1, depth): modules.append(nn.GELU()) modules.append(nn.Linear(out_dim, out_dim)) self.projector = nn.Sequential(*modules) else: print("Projector error!") self.pooling_type = pooling_type def forward(self, x): B = x.shape[0] # B*N*D if self.pooling_type == 'spatial': to_3d = Rearrange("b (p1 p2 p3) d -> b d p1 p2 p3", b=B, d=self.in_dim, p1=self.num_patches_pre[0], p2=self.num_patches_pre[1], p3=self.num_patches_pre[2]) x = to_3d(x) x = F.avg_pool3d(x, kernel_size=self.pooling_size, stride=self.pooling_size) to_seq = Rearrange("b d p1 p2 p3 -> b (p1 p2 p3) d", b=B, d=self.in_dim, p1=self.num_patches_post[0], p2=self.num_patches_post[1], p3=self.num_patches_post[2]) x = to_seq(x) elif self.pooling_type == 'sequence': x = x.permute(0, 2, 1) #b d n x = F.avg_pool1d(x, kernel_size=self.pooling_size**3, stride=self.pooling_size**3) x = x.permute(0, 2, 1) #b n d x = rearrange(x, "b n d -> (b n) d") x = self.projector(x) x = rearrange(x, "(b n) d -> b n d", b=B) return x @property def proj_out_num(self): num = 1 for n in self.num_patches_post: num *= n return num class Minigpt(nn.Module): def __init__(self, config=None): super(Minigpt, self).__init__() # c*4 is the input size, and c is the output size for the linear layer inc, ouc = config.mm_hidden_size, config.hidden_size self.linear = nn.Linear(inc * 4, ouc) def forward(self, x): # x is the input tensor with shape [b, num_tokens, c] b, num_tokens, c = x.shape # Check if num_tokens is divisible by 4 if num_tokens % 4 != 0: raise ValueError("num_tokens must be divisible by 4") # Reshape x to [b, num_tokens/4, c*4] x = x.view(b, num_tokens // 4, c * 4) # Apply the linear transformation x = self.linear(x) return x class Vanilla(nn.Module): def __init__(self, config=None): super(Vanilla, self).__init__() # c*4 is the input size, and c is the output size for the linear layer inc, ouc = config.mm_hidden_size, config.hidden_size self.linear = nn.Linear(inc * 4, ouc) def forward(self, x): b, num_tokens, c = x.shape # Check if num_tokens is divisible by 4 if num_tokens % 4 != 0: raise ValueError("num_tokens must be divisible by 4") # First, reshape to [b, num_tokens//4, 4, c] x = x.view(b, num_tokens // 4, 4, c) # Then, permute to interleave the tokens x = x.permute(0, 1, 3, 2).contiguous() # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens x = x.view(b, num_tokens // 4, c * 4) # Apply the linear transformation x = self.linear(x) return x def build_mm_projector(config, delay_load=False, **kwargs): projector_type = getattr(config, 'mm_projector_type') if projector_type == 'linear': return nn.Linear(config.mm_hidden_size, config.hidden_size) elif projector_type == 'spp': return SpatialPoolingProjector(image_size=config.image_size, patch_size=config.patch_size, in_dim=config.mm_hidden_size, out_dim=config.hidden_size, layer_type=config.proj_layer_type, layer_num=config.proj_layer_num, pooling_type=config.proj_pooling_type, pooling_size=config.proj_pooling_size) elif projector_type == 'identity': return IdentityMap() else: raise ValueError(f'Unknown projector type: {projector_type}') class myViT(nn.Module): """ Vision Transformer (ViT), based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " ViT supports Torchscript but only works for Pytorch after 1.8. """ def __init__( self, in_channels: int, img_size: Sequence[int] | int, patch_size: Sequence[int] | int, hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, pos_embed: str = "conv", classification: bool = False, num_classes: int = 2, dropout_rate: float = 0.0, spatial_dims: int = 3, post_activation="Tanh", qkv_bias: bool = False, save_attn: bool = False, ) -> None: """ Args: in_channels (int): dimension of input channels. img_size (Union[Sequence[int], int]): dimension of input image. patch_size (Union[Sequence[int], int]): dimension of patch size. hidden_size (int, optional): dimension of hidden layer. Defaults to 768. mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072. num_layers (int, optional): number of transformer blocks. Defaults to 12. num_heads (int, optional): number of attention heads. Defaults to 12. pos_embed (str, optional): position embedding layer type. Defaults to "conv". classification (bool, optional): bool argument to determine if classification is used. Defaults to False. num_classes (int, optional): number of classes if classification is used. Defaults to 2. dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. spatial_dims (int, optional): number of spatial dimensions. Defaults to 3. post_activation (str, optional): add a final acivation function to the classification head when `classification` is True. Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function. qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False. Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') # for 3-channel with image size of (128,128,128), 24 layers and classification backbone >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) # for 3-channel with image size of (224,224), 12 layers and classification backbone >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2) """ super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: raise ValueError("hidden_size should be divisible by num_heads.") self.hidden_size = hidden_size self.classification = classification self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, img_size=img_size, patch_size=patch_size, hidden_size=hidden_size, num_heads=num_heads, pos_embed=pos_embed, dropout_rate=dropout_rate, spatial_dims=spatial_dims, ) self.blocks = nn.ModuleList( [ TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn) for i in range(num_layers) ] ) self.norm = nn.LayerNorm(hidden_size) if self.classification: self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) # if post_activation == "Tanh": # self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh()) # else: # self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore def forward(self, x): x = self.patch_embedding(x) if hasattr(self, "cls_token"): cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] for blk in self.blocks: x = blk(x) hidden_states_out.append(x) x = self.norm(x) # if hasattr(self, "classification_head"): # x = self.classification_head(x[:, 0]) return x, hidden_states_out class ViT3DTower(nn.Module): def __init__(self, config): super().__init__() self.config = config self.select_layer = config.vision_select_layer self.select_feature = config.vision_select_feature self.vision_tower = myViT( in_channels=self.config.image_channel, img_size=self.config.image_size, patch_size=self.config.patch_size, pos_embed="perceptron", spatial_dims=len(self.config.patch_size), classification=True, ) def forward(self, images): last_feature, hidden_states = self.vision_tower(images) if self.select_layer == -1: image_features = last_feature elif self.select_layer < -1: image_features = hidden_states[self.select_feature] else: raise ValueError(f'Unexpected select layer: {self.select_layer}') if self.select_feature == 'patch': image_features = image_features[:, 1:] elif self.select_feature == 'cls_patch': image_features = image_features else: raise ValueError(f'Unexpected select feature: {self.select_feature}') return image_features @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device @property def hidden_size(self): return self.vision_tower.hidden_size def build_vision_tower(config, **kwargs): vision_tower = getattr(config, 'vision_tower', None) if 'vit3d' in vision_tower.lower(): return ViT3DTower(config, **kwargs) else: raise ValueError(f'Unknown vision tower: {vision_tower}') class LamedMetaModel: def __init__(self, config): super(LamedMetaModel, self).__init__(config) self.config = config self.seg_enable = False if hasattr(config, "vision_tower"): self.vision_tower = build_vision_tower(config) self.mm_projector = build_mm_projector(config) if hasattr(config, "segmentation_module") and config.segmentation_module is not None: self.seg_enable = True self.seg_module = build_segmentation_module(config) self.seg_projector = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(inplace=True), nn.Linear(config.hidden_size, config.mm_hidden_size), nn.Dropout(0.1), ) self.dice_loss = BinaryDiceLoss() self.bce_loss = BCELoss() def get_vision_tower(self): vision_tower = getattr(self, 'vision_tower', None) return vision_tower def initialize_vision_modules(self, model_args): self.config.image_channel = model_args.image_channel self.config.image_size = model_args.image_size self.config.patch_size = model_args.patch_size self.config.vision_tower = model_args.vision_tower self.config.vision_select_layer = model_args.vision_select_layer self.config.vision_select_feature = model_args.vision_select_feature self.config.mm_projector_type = model_args.mm_projector_type self.config.proj_layer_type = model_args.proj_layer_type self.config.proj_layer_num = model_args.proj_layer_num self.config.proj_pooling_type = model_args.proj_pooling_type self.config.proj_pooling_size = model_args.proj_pooling_size # vision tower if self.get_vision_tower() is None: self.vision_tower = build_vision_tower(self.config) # If you have a more robust vision encoder, try freezing the vision tower by requires_grad_(False) if model_args.pretrain_vision_model is not None: vision_model_weights = torch.load(model_args.pretrain_vision_model, map_location='cpu') self.vision_tower.vision_tower.load_state_dict(vision_model_weights, strict=True) self.config.mm_hidden_size = self.vision_tower.hidden_size # mm_projector if getattr(self, 'mm_projector', None) is None: self.mm_projector = build_mm_projector(self.config) if model_args.pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=True) def initialize_seg_modules(self, model_args): self.config.segmentation_module = model_args.segmentation_module # segmentation_module if getattr(self, 'segmentation_module', None) is None: self.seg_module = build_segmentation_module(self.config) self.seg_projector = nn.Sequential( nn.Linear(self.config.hidden_size, self.config.hidden_size), nn.ReLU(inplace=True), nn.Linear(self.config.hidden_size, self.config.mm_hidden_size), nn.Dropout(0.1), ) self.seg_enable = True if model_args.pretrain_seg_module is not None: seg_module_weights = torch.load(model_args.pretrain_seg_module, map_location='cpu') self.seg_module.load_state_dict(seg_module_weights, strict=True) self.dice_loss = BinaryDiceLoss() self.bce_loss = BCELoss() class LamedMetaForCausalLM(ABC): @abstractmethod def get_model(self): pass def get_vision_tower(self): return self.get_model().get_vision_tower() def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) image_features = self.get_model().mm_projector(image_features) return image_features def prepare_inputs_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, images, ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, labels else: image_features = self.encode_images(images) inputs_embeds = self.get_model().embed_tokens(input_ids) inputs_embeds = torch.cat( (inputs_embeds[:, :1, :], image_features, inputs_embeds[:, (image_features.shape[1] + 1):, :]), dim=1) return None, position_ids, attention_mask, past_key_values, inputs_embeds, labels def initialize_vision_tokenizer(self, model_args, tokenizer): num_new_tokens = model_args.num_new_tokens self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False else: # we add 4 new tokens # if new tokens need input, please train input_embeddings for p in self.get_input_embeddings().parameters(): p.requires_grad = True # if new tokens need predict, please train output_embeddings for p in self.get_output_embeddings().parameters(): p.requires_grad = True if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] if input_embeddings.shape == embed_tokens_weight.shape: input_embeddings = embed_tokens_weight elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") class LamedLlamaModel(LamedMetaModel, LlamaModel): config_class = LamedConfig def __init__(self, config: LlamaConfig): super(LamedLlamaModel, self).__init__(config) class LamedLlamaForCausalLM(LamedMetaForCausalLM, LlamaForCausalLM): config_class = LamedConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = LamedLlamaModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, images: Optional[torch.FloatTensor] = None, input_ids: torch.LongTensor = None, labels: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, segs: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: input_ids_pre = input_ids if inputs_embeds is None: ( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels ) = self.prepare_inputs_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, images, ) try: seg_ids = torch.nonzero(torch.sum(segs, dim=(1, 2, 3, 4))).flatten().tolist() except: seg_ids = [] if self.get_model().seg_enable and seg_ids: outputs = super().forward( input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, return_dict=return_dict ) output_hidden_states = outputs.hidden_states last_hidden_state = output_hidden_states[-1] seg_token_mask = input_ids_pre[:, 1:] == self.config.seg_token_id seg_token_mask = torch.cat( [ seg_token_mask, torch.zeros((seg_token_mask.shape[0], 1), dtype=seg_token_mask.dtype).cuda(), ], dim=1, ) seg_prompts = [] for i in seg_ids: if torch.sum(seg_token_mask[i]) == 1: seg_token = last_hidden_state[i][seg_token_mask[i]] seg_prompt = self.get_model().seg_projector(seg_token) elif torch.sum(seg_token_mask[i]) > 1: seg_tokens = last_hidden_state[i][seg_token_mask[i]] seg_token = torch.mean(seg_tokens, dim=0, keepdim=True) seg_prompt = self.get_model().seg_projector(seg_token) else: seg_prompt = torch.zeros([1, self.config.mm_hidden_size], dtype=last_hidden_state.dtype, device=last_hidden_state.device) seg_prompts.append(seg_prompt) seg_prompts = torch.cat(seg_prompts, dim=0) logits = self.get_model().seg_module(images[seg_ids], text_emb=seg_prompts) loss_dice = self.get_model().dice_loss(logits, segs[seg_ids]) loss_bce = self.get_model().bce_loss(logits, segs[seg_ids]) seg_loss = loss_dice + loss_bce outputs.loss = outputs.loss + seg_loss return outputs else: return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) @torch.no_grad() def generate( self, images: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None, seg_enable: bool = False, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor, Any]: position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: ( inputs, position_ids, attention_mask, _, inputs_embeds, _ ) = self.prepare_inputs_for_multimodal( inputs, position_ids, attention_mask, None, None, images, ) else: inputs_embeds = self.get_model().embed_tokens(inputs) if seg_enable: outputs = super().generate( inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict_in_generate=True, **kwargs ) output_hidden_states = outputs.hidden_states output_ids = outputs.sequences seg_token_mask = output_ids[:, 1:] == self.config.seg_token_id last_tensors = [tuple[-1] for tuple in output_hidden_states] last_hidden_state = torch.cat(last_tensors[1:], dim=1) seg_prompts = [] noseg_ids = [] for i in range(len(seg_token_mask)): if torch.sum(seg_token_mask[i]) == 1: seg_token = last_hidden_state[i][seg_token_mask[i]] seg_prompt = self.get_model().seg_projector(seg_token) elif torch.sum(seg_token_mask[i]) > 1: seg_tokens = last_hidden_state[i][seg_token_mask[i]] seg_token = torch.mean(seg_tokens, dim=0, keepdim=True) seg_prompt = self.get_model().seg_projector(seg_token) else: noseg_ids.append(i) seg_prompt = torch.zeros([1, self.config.mm_hidden_size], dtype=last_hidden_state.dtype, device=last_hidden_state.device) seg_prompts.append(seg_prompt) seg_prompts = torch.cat(seg_prompts, dim=0) logits = self.get_model().seg_module(images, seg_prompts) logits[noseg_ids] = -torch.inf return output_ids, logits else: output_ids = super().generate( inputs_embeds=inputs_embeds, **kwargs ) return output_ids def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): images = kwargs.pop("images", None) inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs ) if images is not None: inputs['images'] = images return inputs