# © Recursion Pharmaceuticals 2024
import timm.models.vision_transformer as vit
import torch


def generate_2d_sincos_pos_embeddings(
    embedding_dim: int,
    length: int,
    scale: float = 10000.0,
    use_class_token: bool = True,
    num_modality: int = 1,
) -> torch.nn.Parameter:
    """
    Generate 2-dimensional sin/cos positional embeddings.

    Parameters
    ----------
    embedding_dim : int
        embedding dimension used in the ViT
    length : int
        number of tokens along the height or width of the image after patching (assuming a square image)
    scale : float
        scale for the sin/cos functions
    use_class_token : bool
        if True, prepend a zero vector to be added to the class token; if False, no extra row is added
    num_modality : int
        number of channel modalities; the spatial sin/cos grid is repeated once per modality
        (num_modality=1 corresponds to a standard single-modality ViT)

    Returns
    -------
    positional_encoding : torch.nn.Parameter
        no-grad positional encoding to add to the ViT patch encodings, of shape
        [1, num_modality*length*length, embedding_dim], or
        [1, 1 + num_modality*length*length, embedding_dim] when use_class_token is True
    """
    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(
        linear_positions, linear_positions, indexing="ij"
    )
    positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
    positional_weights = (
        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    )
    positional_weights = 1.0 / (scale**positional_weights)

    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)
    positional_encoding = torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]

    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)

    if use_class_token:
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)

    positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
    return positional_encoding


class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # in_chans is used by self.proj, which we override anyway
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        # channel-agnostic MAE has a single projection for all chans
        self.proj = torch.nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_chans = x.shape[1]
        x = torch.stack(
            [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
        )  # single projection applied to every channel
        x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
        return x


class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))

        # TODO: upgrade timm to get access to register tokens
        # if self.vit_backbone.reg_token is not None:
        #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        # MAIN DIFFERENCE from timm: we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs;
        # this supports having CA-MAEs actually be channel-agnostic at inference time
        if self.no_embed_class:
            x = x + self.pos_embed[:, : x.shape[1]]
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + self.pos_embed[:, : x.shape[1]]
        return self.pos_drop(x)  # type: ignore[no-any-return]
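# Shape sketch (illustrative numbers only): with embedding_dim=384, length=16
# (a 256x256 image and a 16x16 patch grid), num_modality=6 and use_class_token=True,
# generate_2d_sincos_pos_embeddings returns a [1, 1 + 6*16*16, 384] no-grad Parameter:
# one zero row for the cls token followed by the 16x16 spatial sin/cos grid repeated
# once per channel. ChannelAgnosticPatchEmbed maps a [B, C, 256, 256] batch (C <= 6)
# to [B, C*16*16, 384] tokens in the same channel-major order, so
# ChannelAgnosticViT._pos_embed only needs to slice pos_embed[:, : num_tokens].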
def channel_agnostic_vit(
    vit_backbone: vit.VisionTransformer, max_in_chans: int
) -> vit.VisionTransformer:
    # replace patch embedding with channel-agnostic version
    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
        img_size=vit_backbone.patch_embed.img_size[0],
        patch_size=vit_backbone.patch_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )

    # replace positional embedding with channel-agnostic version
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=vit_backbone.patch_embed.grid_size[0],
        use_class_token=vit_backbone.cls_token is not None,
        num_modality=max_in_chans,
    )

    # change the class to ChannelAgnosticViT so that it actually uses the new _pos_embed
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone


def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Attaches fixed (no-grad) sin/cos positional embeddings to a pre-constructed ViT backbone.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm
    scale : float (default 10000.0)
        hyperparameter for the sin/cos positional embeddings; keeping it at 10,000 is recommended

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT, with fixed no-grad positional encodings attached to add to the patch encodings
    """
    # length: number of tokens along the height or width of the image after patching (assuming a square image)
    length = (
        vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
    )
    pos_embeddings = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim,
        length=length,
        scale=scale,
        use_class_token=vit_backbone.cls_token is not None,
    )
    # note: if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = pos_embeddings
    return vit_backbone
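# The constructors below are thin wrappers around timm ViT entrypoints with a shared
# set of defaults: 6 input channels, no classification head (num_classes=0), a class
# token, ParallelScalingBlock with qk-norm and no qkv bias, small init_values, and an
# img_size override; any of these can be changed via **kwargs. They are intended to be
# combined with sincos_positional_encoding_vit / channel_agnostic_vit above (see the
# usage sketch at the end of the module).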
def vit_small_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch16_224(**default_kwargs)


def vit_small_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch32_384(**default_kwargs)


def vit_base_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch8_224(**default_kwargs)


def vit_base_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch16_224(**default_kwargs)


def vit_base_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch32_384(**default_kwargs)


def vit_large_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        patch_size=8,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.VisionTransformer(**default_kwargs)


def vit_large_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_large_patch16_384(**default_kwargs)
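# Usage sketch (illustrative, not part of the original module): build a small
# channel-agnostic encoder from the pieces above and embed batches with different
# channel counts. Assumes a timm version compatible with the entrypoints used here
# (ParallelScalingBlock, qk_norm, etc.); tensor sizes are examples only.
if __name__ == "__main__":
    backbone = vit_small_patch16_256()  # 6-channel defaults from above
    encoder = channel_agnostic_vit(
        sincos_positional_encoding_vit(backbone), max_in_chans=6
    )
    encoder.eval()

    with torch.no_grad():
        three_chan = torch.randn(2, 3, 256, 256)
        six_chan = torch.randn(2, 6, 256, 256)
        # the same single-channel projection is reused for every channel, so only the
        # number of tokens (and hence the slice of pos_embed) changes between calls
        tokens_3 = encoder.forward_features(three_chan)  # [2, 1 + 3*16*16, 384]
        tokens_6 = encoder.forward_features(six_chan)  # [2, 1 + 6*16*16, 384]
    print(tokens_3.shape, tokens_6.shape)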