Commit 20239f9
Parent(s): b507f8e

add initial files
 - .gitignore +40 -0
 - files/images/Laysan_Albatross_0050_870.jpg +0 -0
 - layers/__init__.py +2 -0
 - layers/independent_mlp.py +69 -0
 - layers/transformer_layers.py +54 -0
 - load_model.py +226 -0
 - models/__init__.py +4 -0
 - models/individual_landmark_convnext.py +110 -0
 - models/individual_landmark_resnet.py +141 -0
 - models/individual_landmark_vit.py +366 -0
 - models/vit_baseline.py +239 -0
 - requirements.txt +5 -1
 - utils/__init__.py +6 -0
 - utils/data_utils/__init__.py +5 -0
 - utils/data_utils/class_balanced_distributed_sampler.py +100 -0
 - utils/data_utils/class_balanced_sampler.py +31 -0
 - utils/data_utils/dataset_utils.py +161 -0
 - utils/data_utils/reversible_affine_transform.py +82 -0
 - utils/data_utils/transform_utils.py +118 -0
 - utils/get_landmark_coordinates.py +41 -0
 - utils/misc_utils.py +135 -0
 - utils/visualize_att_maps.py +135 -0
 
    	
.gitignore ADDED
@@ -0,0 +1,40 @@
+# editor settings
+.idea
+.vscode
+_darcs
+
+# compilation and distribution
+__pycache__
+_ext
+*.pyc
+*.pyd
+*.so
+*.dll
+*.egg-info/
+build/
+dist/
+wheels/
+
+# pytorch/python/numpy formats
+*.pth
+*.pkl
+*.npy
+*.ts
+*.pt
+
+# ipython/jupyter notebooks
+*.ipynb
+**/.ipynb_checkpoints/
+
+# Editor temporaries
+*.swn
+*.swo
+*.swp
+*~
+
+# Results temporary
+*.png
+*.txt
+*.tsv
+wandb/
+exps/
    	
files/images/Laysan_Albatross_0050_870.jpg ADDED
    	
layers/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .transformer_layers import *
+from .independent_mlp import *
    	
layers/independent_mlp.py ADDED
@@ -0,0 +1,69 @@
+# This file contains the implementation of the IndependentMLPs class
+import torch
+
+
+class IndependentMLPs(torch.nn.Module):
+    """
+    This class implements the MLP used for classification with the option to use an additional independent MLP layer
+    """
+
+    def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1, act_layer=True, out_dim=None, stack_dim=-1):
+        """
+
+        :param part_dim: Number of parts
+        :param latent_dim: Latent dimension
+        :param bias: Whether to use bias
+        :param num_lin_layers: Number of linear layers
+        :param act_layer: Whether to use activation layer
+        :param out_dim: Output dimension (default: None)
+        :param stack_dim: Dimension to stack the outputs (default: -1)
+        """
+
+        super().__init__()
+
+        self.bias = bias
+        self.latent_dim = latent_dim
+        if out_dim is None:
+            out_dim = latent_dim
+        self.out_dim = out_dim
+        self.part_dim = part_dim
+        self.stack_dim = stack_dim
+
+        layer_stack = torch.nn.ModuleList()
+        for i in range(part_dim):
+            layer_stack.append(torch.nn.Sequential())
+            for j in range(num_lin_layers):
+                layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(latent_dim, self.out_dim, bias=bias))
+                if act_layer:
+                    layer_stack[i].add_module(f"act_{j}", torch.nn.GELU())
+        self.feature_layers = layer_stack
+        self.reset_weights()
+
+    def __repr__(self):
+        return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}, bias={self.bias})"
+
+    def reset_weights(self):
+        """ Initialize the linear layers with a truncated normal distribution """
+        for layer in self.feature_layers:
+            for m in layer.modules():
+                if isinstance(m, torch.nn.Linear):
+                    # Initialize weights with a truncated normal distribution
+                    torch.nn.init.trunc_normal_(m.weight, std=0.02)
+                    if m.bias is not None:
+                        torch.nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        """ Input X has the dimensions batch x latent_dim x part_dim """
+
+        outputs = []
+        for i, layer in enumerate(self.feature_layers):
+            if self.stack_dim == -1:
+                in_ = x[..., i]
+            else:
+                in_ = x[:, i, ...]  # Select feature i
+            out = layer(in_)  # Apply MLP to feature i
+            outputs.append(out)
+
+        x = torch.stack(outputs, dim=self.stack_dim)  # Stack the outputs
+
+        return x
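
For reference, a minimal usage sketch of the IndependentMLPs class added above; the tensor sizes below are illustrative, not taken from the repository. Each part index gets its own linear layer, and the per-part outputs are re-stacked along stack_dim.

import torch
from layers.independent_mlp import IndependentMLPs

# Illustrative sizes only: 9 parts (e.g. 8 landmarks + background), 768-dim features.
mlp = IndependentMLPs(part_dim=9, latent_dim=768, bias=True, num_lin_layers=1, act_layer=False)
x = torch.randn(2, 768, 9)   # batch x latent_dim x part_dim
y = mlp(x)                   # one independent MLP applied per part
print(y.shape)               # torch.Size([2, 768, 9])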
    	
layers/transformer_layers.py ADDED
@@ -0,0 +1,54 @@
+# Attention block with the option to return the q, k, v tensors from the attention
+
+import torch
+from timm.models.vision_transformer import Attention, Block
+import torch.nn.functional as F
+from typing import Tuple
+
+
+class AttentionWQKVReturn(Attention):
+    """
+    Modifications:
+        - Return the qkv tensors from the attention
+    """
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, torch.stack((q, k, v), dim=0)
+
+
+class BlockWQKVReturn(Block):
+    """
+    Modifications:
+        - Use AttentionWQKVReturn instead of Attention
+        - Return the qkv tensors from the attention
+    """
+
+    def forward(self, x: torch.Tensor, return_qkv: bool = False) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+        # Note: this is copied from timm.models.vision_transformer.Block with modifications.
+        x_attn, qkv = self.attn(self.norm1(x))
+        x = x + self.drop_path1(self.ls1(x_attn))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+        if return_qkv:
+            return x, qkv
+        else:
+            return x
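
A minimal sketch of exercising the block above in isolation. BlockWQKVReturn.forward expects self.attn to return (output, stacked qkv), so this sketch assumes the attention module's class is rebound to AttentionWQKVReturn after construction; the dimensions are illustrative ViT-Base values, not values prescribed by the repository.

import torch
from layers.transformer_layers import AttentionWQKVReturn, BlockWQKVReturn

blk = BlockWQKVReturn(dim=768, num_heads=12)
blk.attn.__class__ = AttentionWQKVReturn  # make attn return (x, stacked qkv)

tokens = torch.randn(1, 197, 768)          # batch x tokens x dim
out, qkv = blk(tokens, return_qkv=True)
print(out.shape, qkv.shape)                # (1, 197, 768) and (3, 1, 12, 197, 64)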
    	
load_model.py ADDED
@@ -0,0 +1,226 @@
+import copy
+import os
+from pathlib import Path
+
+import torch
+from timm.models import create_model
+from torchvision.models import get_model
+
+from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
+from models.individual_landmark_resnet import IndividualLandmarkResNet
+from models.individual_landmark_convnext import IndividualLandmarkConvNext
+from models.individual_landmark_vit import IndividualLandmarkViT
+from utils import load_state_dict_pdisco
+
+
+def load_model_arch(args, num_cls):
+    """
+    Function to load the model
+    :param args: Arguments from the command line
+    :param num_cls: Number of classes in the dataset
+    :return:
+    """
+    if 'resnet' in args.model_arch:
+        num_layers_split = [int(s) for s in args.model_arch if s.isdigit()]
+        num_layers = int(''.join(map(str, num_layers_split)))
+        if num_layers >= 100:
+            timm_model_arch = args.model_arch + ".a1h_in1k"
+        else:
+            timm_model_arch = args.model_arch + ".a1_in1k"
+
+    if "resnet" in args.model_arch and args.use_torchvision_resnet_model:
+        weights = "DEFAULT" if args.pretrained_start_weights else None
+        base_model = get_model(args.model_arch, weights=weights)
+    elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model:
+        if args.eval_only:
+            base_model = create_model(
+                timm_model_arch,
+                pretrained=args.pretrained_start_weights,
+                num_classes=num_cls,
+                output_stride=args.output_stride,
+            )
+        else:
+            base_model = create_model(
+                timm_model_arch,
+                pretrained=args.pretrained_start_weights,
+                drop_path_rate=args.drop_path,
+                num_classes=num_cls,
+                output_stride=args.output_stride,
+            )
+
+    elif "convnext" in args.model_arch:
+        if args.eval_only:
+            base_model = create_model(
+                args.model_arch,
+                pretrained=args.pretrained_start_weights,
+                num_classes=num_cls,
+                output_stride=args.output_stride,
+            )
+        else:
+            base_model = create_model(
+                args.model_arch,
+                pretrained=args.pretrained_start_weights,
+                drop_path_rate=args.drop_path,
+                num_classes=num_cls,
+                output_stride=args.output_stride,
+            )
+    elif "vit" in args.model_arch:
+        if args.eval_only:
+            base_model = create_model(
+                args.model_arch,
+                pretrained=args.pretrained_start_weights,
+                img_size=args.image_size,
+            )
+        else:
+            base_model = create_model(
+                args.model_arch,
+                pretrained=args.pretrained_start_weights,
+                drop_path_rate=args.drop_path,
+                img_size=args.image_size,
+            )
+        vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
+        if args.image_size % vit_patch_size != 0:
+            raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
+    else:
+        raise ValueError('Model not supported.')
+
+    return base_model
+
+
+def init_pdisco_model(base_model, args, num_cls):
+    """
+    Function to initialize the model
+    :param base_model: Base model
+    :param args: Arguments from the command line
+    :param num_cls: Number of classes in the dataset
+    :return:
+    """
+    # Initialize the network
+    if 'convnext' in args.model_arch:
+        sl_channels = base_model.stages[-1].downsample[-1].in_channels
+        fl_channels = base_model.head.in_features
+        model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls,
+                                           sl_channels=sl_channels, fl_channels=fl_channels,
+                                           part_dropout=args.part_dropout, modulation_type=args.modulation_type,
+                                           gumbel_softmax=args.gumbel_softmax,
+                                           gumbel_softmax_temperature=args.gumbel_softmax_temperature,
+                                           gumbel_softmax_hard=args.gumbel_softmax_hard,
+                                           modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
+                                           noise_variance=args.noise_variance)
+    elif 'resnet' in args.model_arch:
+        sl_channels = base_model.layer4[0].conv1.in_channels
+        fl_channels = base_model.fc.in_features
+        model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls,
+                                         sl_channels=sl_channels, fl_channels=fl_channels,
+                                         use_torchvision_model=args.use_torchvision_resnet_model,
+                                         part_dropout=args.part_dropout, modulation_type=args.modulation_type,
+                                         gumbel_softmax=args.gumbel_softmax,
+                                         gumbel_softmax_temperature=args.gumbel_softmax_temperature,
+                                         gumbel_softmax_hard=args.gumbel_softmax_hard,
+                                         modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
+                                         noise_variance=args.noise_variance)
+    elif 'vit' in args.model_arch:
+        model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls,
+                                      part_dropout=args.part_dropout,
+                                      modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax,
+                                      gumbel_softmax_temperature=args.gumbel_softmax_temperature,
+                                      gumbel_softmax_hard=args.gumbel_softmax_hard,
+                                      modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
+                                      noise_variance=args.noise_variance)
+    else:
+        raise ValueError('Model not supported.')
+
+    return model
+
+
+def load_model_pdisco(args, num_cls):
+    """
+    Function to load the model
+    :param args: Arguments from the command line
+    :param num_cls: Number of classes in the dataset
+    :return:
+    """
+    base_model = load_model_arch(args, num_cls)
+    model = init_pdisco_model(base_model, args, num_cls)
+
+    return model
+
+
+def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
+    """
+    Function to load the PDiscoFormer model with ViT backbone
+    :param pretrained: Boolean flag to load the pretrained weights
+    :param model_dataset: Dataset for which the model is trained
+    :param k: Number of unsupervised landmarks the model is trained on
+    :param model_url: URL to load the model weights from
+    :param img_size: Image size
+    :param num_cls: Number of classes in the dataset
+    :return: PDiscoFormer model with ViT backbone
+    """
+    model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
+    if pretrained:
+        hub_dir = torch.hub.get_dir()
+        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")
+
+        Path(model_dir).mkdir(parents=True, exist_ok=True)
+        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
+        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
+        if 'model_state' in snapshot_data:
+            _, state_dict = load_state_dict_pdisco(snapshot_data)
+        else:
+            state_dict = copy.deepcopy(snapshot_data)
+        model.load_state_dict(state_dict, strict=True)
+    return model
+
+
+def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
+    """
+    Function to load the PDiscoNet model with ViT backbone
+    :param pretrained: Boolean flag to load the pretrained weights
+    :param model_dataset: Dataset for which the model is trained
+    :param k: Number of unsupervised landmarks the model is trained on
+    :param model_url: URL to load the model weights from
+    :param img_size: Image size
+    :param num_cls: Number of classes in the dataset
+    :return: PDiscoNet model with ViT backbone
+    """
+    model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
+    if pretrained:
+        hub_dir = torch.hub.get_dir()
+        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
+
+        Path(model_dir).mkdir(parents=True, exist_ok=True)
+        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
+        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
+        if 'model_state' in snapshot_data:
+            _, state_dict = load_state_dict_pdisco(snapshot_data)
+        else:
+            state_dict = copy.deepcopy(snapshot_data)
+        model.load_state_dict(state_dict, strict=True)
+    return model
+
+
+def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
+    """
+    Function to load the PDiscoNet model with ResNet-101 backbone
+    :param pretrained: Boolean flag to load the pretrained weights
+    :param model_dataset: Dataset for which the model is trained
+    :param k: Number of unsupervised landmarks the model is trained on
+    :param model_url: URL to load the model weights from
+    :param num_cls: Number of classes in the dataset
+    :return: PDiscoNet model with ResNet-101 backbone
+    """
+    model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
+    if pretrained:
+        hub_dir = torch.hub.get_dir()
+        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
+
+        Path(model_dir).mkdir(parents=True, exist_ok=True)
+        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
+        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
+        if 'model_state' in snapshot_data:
+            _, state_dict = load_state_dict_pdisco(snapshot_data)
+        else:
+            state_dict = copy.deepcopy(snapshot_data)
+        model.load_state_dict(state_dict, strict=True)
+    return model
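
For orientation, a minimal sketch of building the PDiscoFormer ViT through the loader added above without downloading a checkpoint; the argument values are simply the defaults from the function signature, and this is not an official usage snippet from the repository.

from load_model import pdiscoformer_vit

# pretrained=False skips the checkpoint download and only builds the architecture.
model = pdiscoformer_vit(pretrained=False, model_dataset="cub", k=8, img_size=224, num_cls=200)
model.eval()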
    	
models/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .individual_landmark_resnet import *
+from .individual_landmark_convnext import *
+from .vit_baseline import *
+from .individual_landmark_vit import *
    	
models/individual_landmark_convnext.py ADDED
@@ -0,0 +1,110 @@
+import torch
+from torch import Tensor
+from torch.nn import Parameter
+from typing import Any
+from layers.independent_mlp import IndependentMLPs
+
+
+# Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer
+class IndividualLandmarkConvNext(torch.nn.Module):
+    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
+                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
+                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
+                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
+                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
+        super().__init__()
+
+        self.num_landmarks = num_landmarks
+        self.num_classes = num_classes
+        self.noise_variance = noise_variance
+        self.stem = init_model.stem
+        self.stages = init_model.stages
+        self.feature_dim = sl_channels + fl_channels
+        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
+        self.gumbel_softmax = gumbel_softmax
+        self.gumbel_softmax_temperature = gumbel_softmax_temperature
+        self.gumbel_softmax_hard = gumbel_softmax_hard
+        self.modulation_type = modulation_type
+        if modulation_type == "layer_norm":
+            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
+        elif modulation_type == "original":
+            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
+        elif modulation_type == "parallel_mlp":
+            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
+                                              num_lin_layers=1, act_layer=True, bias=True)
+        elif modulation_type == "parallel_mlp_no_bias":
+            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
+                                              num_lin_layers=1, act_layer=True, bias=False)
+        elif modulation_type == "parallel_mlp_no_act":
+            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
+                                              num_lin_layers=1, act_layer=False, bias=True)
+        elif modulation_type == "parallel_mlp_no_act_no_bias":
+            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
+                                              num_lin_layers=1, act_layer=False, bias=False)
+        elif modulation_type == "none":
+            self.modulation = torch.nn.Identity()
+        else:
+            raise ValueError("modulation_type not implemented")
+        self.modulation_orth = modulation_orth
+        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
+        self.classifier_type = classifier_type
+        if classifier_type == "independent_mlp":
+            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
+                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
+                                                      bias=False, stack_dim=1)
+        elif classifier_type == "linear":
+            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
+                                                      bias=False)
+        else:
+            raise ValueError("classifier_type not implemented")
+
+    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]:
+        # Pretrained ConvNeXt part of the model
+        x = self.stem(x)
+        x = self.stages[0](x)
+        x = self.stages[1](x)
+        l3 = self.stages[2](x)
+        x = self.stages[3](l3)
+        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
+        x = torch.cat((x, l3), dim=1)
+
+        # Compute per landmark attention maps
+        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel
+        batch_size = x.shape[0]
+        ab = self.fc_landmarks(x)
+        b_sq = x.pow(2).sum(1, keepdim=True)
+        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
+        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
+                                                                          x.shape[-1]).contiguous()
+        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
+
+        dist = b_sq - 2 * ab + a_sq
+        maps = -dist
+
+        # Softmax so that the attention maps for each pixel add up to 1
+        if self.gumbel_softmax:
+            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
+                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
         
     | 
| 88 | 
         
            +
                    else:
         
     | 
| 89 | 
         
            +
                        maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                    # Use maps to get weighted average features per landmark
         
     | 
| 92 | 
         
            +
                    all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
         
     | 
| 93 | 
         
            +
                    if self.noise_variance > 0.0:
         
     | 
| 94 | 
         
            +
                        all_features += torch.randn_like(all_features,
         
     | 
| 95 | 
         
            +
                                                         device=all_features.device) * x.std().detach() * self.noise_variance
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                    # Modulate the features
         
     | 
| 98 | 
         
            +
                    if self.modulation_type == "original":
         
     | 
| 99 | 
         
            +
                        all_features_mod = all_features * self.modulation
         
     | 
| 100 | 
         
            +
                    else:
         
     | 
| 101 | 
         
            +
                        all_features_mod = self.modulation(all_features)
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                    # Classification based on the landmark features
         
     | 
| 104 | 
         
            +
                    scores = self.fc_class_landmarks(
         
     | 
| 105 | 
         
            +
                        self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
         
     | 
| 106 | 
         
            +
                                                                                                                       1).contiguous()
         
     | 
| 107 | 
         
            +
                    if self.modulation_orth:
         
     | 
| 108 | 
         
            +
                        return all_features_mod, maps, scores, dist
         
     | 
| 109 | 
         
            +
                    else:
         
     | 
| 110 | 
         
            +
                        return all_features, maps, scores, dist
         
     | 
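The attention maps above are negated squared Euclidean distances between each pixel's feature vector and the learned landmark prototypes (the weights of the 1x1 convolution fc_landmarks), computed via the expansion (b - a)^2 = b^2 - 2ab + a^2 so that the cross term comes from a single convolution instead of materialising per-pixel differences. A minimal, self-contained sketch of the same identity; the tensor sizes are illustrative, not taken from the model:

import torch

B, C, H, W, K = 2, 16, 7, 7, 4                    # illustrative sizes: batch, channels, spatial, landmarks
x = torch.randn(B, C, H, W)                       # stand-in for the fused backbone feature map
proto = torch.nn.Conv2d(C, K, 1, bias=False)      # 1x1 conv; its weights act as K landmark prototypes

ab = proto(x)                                     # cross term <x_pixel, w_k> for every pixel and prototype
b_sq = x.pow(2).sum(1, keepdim=True).expand(-1, K, -1, -1)
a_sq = proto.weight.pow(2).sum(dim=(1, 2, 3)).view(1, K, 1, 1)
dist = b_sq - 2 * ab + a_sq                       # squared distances, shape [B, K, H, W]

# Reference computation with explicit per-pixel differences
ref = (x.unsqueeze(1) - proto.weight.view(1, K, C, 1, 1)).pow(2).sum(2)
assert torch.allclose(dist, ref, atol=1e-4)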
    	
models/individual_landmark_resnet.py
ADDED
@@ -0,0 +1,141 @@
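The ResNet variant below fuses its last two stages before part assignment: the layer4 output (fl_channels) is bilinearly upsampled to the layer3 resolution and concatenated with the layer3 output (sl_channels), so feature_dim = sl_channels + fl_channels. A shape-only sketch of that fusion; the 448x448 input size and ResNet-50-style channel counts are assumptions for illustration:

import torch

l3 = torch.randn(1, 1024, 28, 28)   # layer3 output (sl_channels), stride 16 for an assumed 448x448 input
l4 = torch.randn(1, 2048, 14, 14)   # layer4 output (fl_channels), stride 32

l4_up = torch.nn.functional.interpolate(l4, size=l3.shape[-2:], mode='bilinear', align_corners=False)
fused = torch.cat((l4_up, l3), dim=1)
print(fused.shape)                  # torch.Size([1, 3072, 28, 28]) -> feature_dim = sl_channels + fl_channels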
# Modified from https://github.com/robertdvdk/part_detection/blob/main/nets.py
import torch
from torch import Tensor
from timm.models import create_model
from torchvision.models import get_model
from typing import Any

from layers.independent_mlp import IndependentMLPs


# Baseline model: a modified ResNet with reduced downsampling for a spatially larger feature tensor in the last layer
class IndividualLandmarkResNet(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048,
                 use_torchvision_model: bool = False, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()

        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.conv1 = init_model.conv1
        self.bn1 = init_model.bn1
        if use_torchvision_model:
            self.act1 = init_model.relu
        else:
            self.act1 = init_model.act1
        self.maxpool = init_model.maxpool
        self.layer1 = init_model.layer1
        self.layer2 = init_model.layer2
        self.layer3 = init_model.layer3
        self.layer4 = init_model.layer4
        self.feature_dim = sl_channels + fl_channels
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")

        self.modulation_orth = modulation_orth

        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any]:
        # Pretrained ResNet part of the model
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        l3 = self.layer3(x)
        x = self.layer4(l3)
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per-landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = ResNet feature maps, a = 1x1 convolution kernel (landmark prototypes)
        batch_size = x.shape[0]

        ab = self.fc_landmarks(x)
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax over the landmark dimension so the assignment weights at each pixel sum to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features (the last map is the background and is dropped)
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                            1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist


def pdisconet_resnet_torchvision_bb(backbone, num_cls=200, k=8, **kwargs):
    base_model = get_model(backbone)
    # use_torchvision_model=True so the torchvision attribute name (relu) is used for the first activation
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    modulation_type="original", use_torchvision_model=True)


def pdisconet_resnet_timm_bb(backbone, num_cls=200, k=8, output_stride=32, **kwargs):
    base_model = create_model(backbone, pretrained=True, output_stride=output_stride)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    modulation_type="original")
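A minimal usage sketch for the two factory functions above. The backbone name "resnet50", the input resolution, and the batch size are assumed values for illustration; the repository's training scripts may use different settings:

import torch
from models.individual_landmark_resnet import pdisconet_resnet_timm_bb

model = pdisconet_resnet_timm_bb("resnet50", num_cls=200, k=8)   # timm backbone with pretrained weights
model.eval()

with torch.no_grad():
    features, maps, scores, dist = model(torch.randn(1, 3, 224, 224))

print(maps.shape)    # [1, k + 1, H, W] part-assignment maps; the last channel is the background
print(scores.shape)  # [1, num_cls, k] per-part class logits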
    	
models/individual_landmark_vit.py
ADDED
@@ -0,0 +1,366 @@
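The ViT variant below reuses the same distance-based part assignment, but first converts the transformer's token sequence back into a spatial feature map: the prefix tokens (CLS and any register tokens) are stripped, and the remaining patch tokens are unflattened to an h x w grid before the 1x1 landmark convolution. A shape-only sketch of that step, assuming a ViT-B/16-like geometry (one CLS token, 14x14 patch grid, embed_dim 768):

import torch
import torch.nn as nn

B, embed_dim, num_prefix_tokens, h, w = 2, 768, 1, 14, 14          # assumed geometry
tokens = torch.randn(B, num_prefix_tokens + h * w, embed_dim)      # output of the transformer blocks

patch_tokens = tokens[:, num_prefix_tokens:, :]                    # drop CLS/register tokens
fmap = nn.Unflatten(1, (h, w))(patch_tokens)                       # [B, h, w, embed_dim]
fmap = fmap.permute(0, 3, 1, 2).contiguous()                       # [B, embed_dim, h, w] for fc_landmarks
print(fmap.shape)                                                  # torch.Size([2, 768, 14, 14])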
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from pathlib import Path
import os
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Union, Sequence, Optional, Dict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download

from timm.models import create_model
from timm.models.vision_transformer import Block, Attention
from utils.misc_utils import compute_attention

from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from layers.independent_mlp import IndependentMLPs

SAFETENSORS_SINGLE_FILE = "model.safetensors"


class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
                            pipeline_tag='image-classification',
                            repo_url='https://github.com/ananthu-aniraj/pdiscoformer'):

    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200,
                 part_dropout: float = 0.3, return_transformer_qkv: bool = False,
                 modulation_type: str = "original", gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 modulation_orth: bool = False, classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token

        self.feature_dim = init_model.embed_dim
        self.patch_embed = init_model.patch_embed
        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm
        self.return_transformer_qkv = return_transformer_qkv
        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

        self.unflatten = nn.Unflatten(1, (self.h_fmap, self.w_fmap))
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")
        self.convert_blocks_and_attention()
        self._init_weights()

    def _init_weights_head(self):
        # Initialize classifier weights with a truncated normal distribution
        if self.classifier_type == "independent_mlp":
            self.fc_class_landmarks.reset_weights()
        else:
            torch.nn.init.trunc_normal_(self.fc_class_landmarks.weight, std=0.02)
            if self.fc_class_landmarks.bias is not None:
                torch.nn.init.zeros_(self.fc_class_landmarks.bias)

    def _init_weights(self):
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        # Swap in block/attention variants that can also return the QKV tensors
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token: add, then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has an entry for the class token: concat, then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any]:

        x = self.patch_embed(x)

        # Position embedding
        x = self._pos_embed(x)

        # Forward pass through the transformer
        x = self.norm_pre(x)

        x = self.blocks(x)
        x = self.norm(x)

        # Compute per-landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = ViT feature maps, a = 1x1 convolution kernel (landmark prototypes)
        batch_size = x.shape[0]
        x = x[:, self.num_prefix_tokens:, :]  # [B, num_patch_tokens, embed_dim]
        x = self.unflatten(x)  # [B, H, W, embed_dim]
        x = x.permute(0, 3, 1, 2).contiguous()  # [B, embed_dim, H, W]
        ab = self.fc_landmarks(x)  # [B, num_landmarks + 1, H, W]
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1, keepdim=True).expand(-1, batch_size, x.shape[-2],
                                                                           x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax over the landmark dimension so the assignment weights at each pixel sum to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        all_features = all_features.mean(-1).mean(-1).contiguous()  # [B, embed_dim, num_landmarks + 1]

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation  # [B, embed_dim, num_landmarks + 1]
        else:
            all_features_mod = self.modulation(all_features)  # [B, embed_dim, num_landmarks + 1]

        # Classification based on the landmark features (the last map is the background and is dropped)
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                            1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        if n == -1:
            if return_qkv:
                raise ValueError("n cannot be -1 when return_qkv is True")
            else:
                return x

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)

                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
| 275 | 
         
            +
                    """ Intermediate layer accessor (NOTE: This is a WIP experiment).
         
     | 
| 276 | 
         
            +
                    Inspired by DINO / DINOv2 interface
         
     | 
| 277 | 
         
            +
                    """
         
     | 
| 278 | 
         
            +
                    # take last n blocks if n is an int, if in is a sequence, select by matching indices
         
     | 
| 279 | 
         
            +
                    if self.return_transformer_qkv:
         
     | 
| 280 | 
         
            +
                        outputs, qkv = self._intermediate_layers(x, n)
         
     | 
| 281 | 
         
            +
                    else:
         
     | 
| 282 | 
         
            +
                        outputs = self._intermediate_layers(x, n)
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                    if norm:
         
     | 
| 285 | 
         
            +
                        outputs = [self.norm(out) for out in outputs]
         
     | 
| 286 | 
         
            +
                    prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
         
     | 
| 287 | 
         
            +
                    outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
                    if reshape:
         
     | 
| 290 | 
         
            +
                        grid_size = self.patch_embed.grid_size
         
     | 
| 291 | 
         
            +
                        outputs = [
         
     | 
| 292 | 
         
            +
                            out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
         
     | 
| 293 | 
         
            +
                            for out in outputs
         
     | 
| 294 | 
         
            +
                        ]
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                    if return_prefix_tokens:
         
     | 
| 297 | 
         
            +
                        return_out = tuple(zip(outputs, prefix_tokens))
         
     | 
| 298 | 
         
            +
                    else:
         
     | 
| 299 | 
         
            +
                        return_out = tuple(outputs)
         
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
            +
                    if self.return_transformer_qkv:
         
     | 
| 302 | 
         
            +
                        return return_out, qkv
         
     | 
| 303 | 
         
            +
                    else:
         
     | 
| 304 | 
         
            +
                        return return_out
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                @classmethod
         
     | 
| 307 | 
         
            +
                def _from_pretrained(
         
     | 
| 308 | 
         
            +
                        cls,
         
     | 
| 309 | 
         
            +
                        *,
         
     | 
| 310 | 
         
            +
                        model_id: str,
         
     | 
| 311 | 
         
            +
                        revision: Optional[str],
         
     | 
| 312 | 
         
            +
                        cache_dir: Optional[Union[str, Path]],
         
     | 
| 313 | 
         
            +
                        force_download: bool,
         
     | 
| 314 | 
         
            +
                        proxies: Optional[Dict],
         
     | 
| 315 | 
         
            +
                        resume_download: Optional[bool],
         
     | 
| 316 | 
         
            +
                        local_files_only: bool,
         
     | 
| 317 | 
         
            +
                        token: Union[str, bool, None],
         
     | 
| 318 | 
         
            +
                        map_location: str = "cpu",
         
     | 
| 319 | 
         
            +
                        strict: bool = False,
         
     | 
| 320 | 
         
            +
                        timm_backbone: str = "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",
         
     | 
| 321 | 
         
            +
                        input_size: int = 518,
         
     | 
| 322 | 
         
            +
                        **model_kwargs):
         
     | 
| 323 | 
         
            +
                    base_model = create_model(timm_backbone, pretrained=False, img_size=input_size)
         
     | 
| 324 | 
         
            +
                    model = cls(base_model, **model_kwargs)
         
     | 
| 325 | 
         
            +
                    if os.path.isdir(model_id):
         
     | 
| 326 | 
         
            +
                        print("Loading weights from local directory")
         
     | 
| 327 | 
         
            +
                        model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
         
     | 
| 328 | 
         
            +
                        return cls._load_as_safetensor(model, model_file, map_location, strict)
         
     | 
| 329 | 
         
            +
                    else:
         
     | 
| 330 | 
         
            +
                        model_file = hf_hub_download(
         
     | 
| 331 | 
         
            +
                            repo_id=model_id,
         
     | 
| 332 | 
         
            +
                            filename=SAFETENSORS_SINGLE_FILE,
         
     | 
| 333 | 
         
            +
                            revision=revision,
         
     | 
| 334 | 
         
            +
                            cache_dir=cache_dir,
         
     | 
| 335 | 
         
            +
                            force_download=force_download,
         
     | 
| 336 | 
         
            +
                            proxies=proxies,
         
     | 
| 337 | 
         
            +
                            resume_download=resume_download,
         
     | 
| 338 | 
         
            +
                            token=token,
         
     | 
| 339 | 
         
            +
                            local_files_only=local_files_only,
         
     | 
| 340 | 
         
            +
                        )
         
     | 
| 341 | 
         
            +
                        return cls._load_as_safetensor(model, model_file, map_location, strict)
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
             
     | 
| 344 | 
         
            +
            def pdiscoformer_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
         
     | 
| 345 | 
         
            +
                base_model = create_model(
         
     | 
| 346 | 
         
            +
                    backbone,
         
     | 
| 347 | 
         
            +
                    pretrained=False,
         
     | 
| 348 | 
         
            +
                    img_size=img_size,
         
     | 
| 349 | 
         
            +
                )
         
     | 
| 350 | 
         
            +
             
     | 
| 351 | 
         
            +
                model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
         
     | 
| 352 | 
         
            +
                                              modulation_type="layer_norm", gumbel_softmax=True,
         
     | 
| 353 | 
         
            +
                                              modulation_orth=True)
         
     | 
| 354 | 
         
            +
                return model
         
     | 
| 355 | 
         
            +
             
     | 
| 356 | 
         
            +
             
     | 
| 357 | 
         
            +
            def pdisconet_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
         
     | 
| 358 | 
         
            +
                base_model = create_model(
         
     | 
| 359 | 
         
            +
                    backbone,
         
     | 
| 360 | 
         
            +
                    pretrained=False,
         
     | 
| 361 | 
         
            +
                    img_size=img_size,
         
     | 
| 362 | 
         
            +
                )
         
     | 
| 363 | 
         
            +
             
     | 
| 364 | 
         
            +
                model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
         
     | 
| 365 | 
         
            +
                                              modulation_type="original")
         
     | 
| 366 | 
         
            +
                return model
         
     | 
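A minimal usage sketch for the two factory functions above (not part of the commit; the backbone string and input size mirror the defaults used in `_from_pretrained`, and the import path is assumed from the repository layout):

import torch
from models.individual_landmark_vit import pdiscoformer_vit_bb

# Build the PDiscoFormer ViT with an untrained timm backbone (no weights are downloaded here).
model = pdiscoformer_vit_bb(
    "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",  # same default as _from_pretrained
    img_size=518,
    num_cls=200,  # e.g. CUB-200-2011 classes
    k=8,          # number of discovered foreground landmarks
)
model.eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 518, 518))  # output structure is defined by IndividualLandmarkViT.forward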
    	
models/vit_baseline.py ADDED
@@ -0,0 +1,239 @@
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import torch
import torch.nn as nn
from typing import Tuple, Union, Sequence, Any
from timm.layers import trunc_normal_
from timm.models.vision_transformer import Block, Attention
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn

from utils.misc_utils import compute_attention


class BaselineViT(torch.nn.Module):
    """
    Modifications:
    - Use BlockWQKVReturn instead of Block
    - Use AttentionWQKVReturn instead of Attention
    - Return the mean of k over heads from attention
    - Option to use only class tokens or only patch tokens or both (concat) for classification
    """

    def __init__(self, init_model: torch.nn.Module, num_classes: int,
                 class_tokens_only: bool = False,
                 patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.class_tokens_only = class_tokens_only
        self.patch_tokens_only = patch_tokens_only
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token

        self.patch_embed = init_model.patch_embed

        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.part_embed = nn.Identity()
        self.patch_prune = nn.Identity()
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm

        self.fc_norm = init_model.fc_norm
        if class_tokens_only or patch_tokens_only:
            self.head = nn.Linear(init_model.embed_dim, num_classes)
        else:
            self.head = nn.Linear(init_model.embed_dim * 2, num_classes)

        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

        self.return_transformer_qkv = return_transformer_qkv
        self.convert_blocks_and_attention()
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def _init_weights_head(self):
        trunc_normal_(self.head.weight, std=.02)
        if self.head.bias is not None:
            nn.init.constant_(self.head.bias, 0.)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:

        x = self.patch_embed(x)

        # Position Embedding
        x = self._pos_embed(x)

        x = self.part_embed(x)
        x = self.patch_prune(x)

        # Forward pass through transformer
        x = self.norm_pre(x)

        if self.return_transformer_qkv:
            # Return keys of last attention layer
            for i, blk in enumerate(self.blocks):
                x, qkv = blk(x, return_qkv=True)
        else:
            x = self.blocks(x)

        x = self.norm(x)

        # Classification head
        x = self.fc_norm(x)
        if self.class_tokens_only:  # only use class token
            x = x[:, 0, :]
        elif self.patch_tokens_only:  # only use patch tokens
            x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
        else:
            x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
        x = self.head(x)
        if self.return_transformer_qkv:
            return x, qkv
        else:
            return x

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        if n == -1:
            if return_qkv:
                raise ValueError("n cannot be -1 if return_qkv is True")
            else:
                return x

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)

                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int; if n is a sequence, select by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)

        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)

        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out
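For comparison, a hedged sketch of how `BaselineViT` might be instantiated (not part of the commit; the backbone identifier is illustrative):

import torch
from timm import create_model
from models.vit_baseline import BaselineViT

# Wrap an untrained timm ViT; the blocks are converted to the QKV-returning variants internally.
backbone = create_model("hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m", pretrained=False, img_size=518)
model = BaselineViT(backbone, num_classes=200, class_tokens_only=True)  # classification head on the CLS token only

model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 518, 518))  # shape (2, 200); with return_transformer_qkv=True a (logits, qkv) tuple is returned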
    	
requirements.txt CHANGED
@@ -3,4 +3,8 @@ timm
 colorcet
 matplotlib
 torchvision
-streamlit
+streamlit
+numpy
+pillow
+scikit-image
+huggingface-hub
    	
utils/__init__.py ADDED
@@ -0,0 +1,6 @@
from .data_utils import *
from .visualize_att_maps import *
from .misc_utils import *
from .get_landmark_coordinates import *
    	
utils/data_utils/__init__.py ADDED
@@ -0,0 +1,5 @@
from .dataset_utils import *
from .reversible_affine_transform import *
from .transform_utils import *
from .class_balanced_distributed_sampler import *
from .class_balanced_sampler import *
    	
utils/data_utils/class_balanced_distributed_sampler.py ADDED
@@ -0,0 +1,100 @@
import torch
from torch.utils.data import Dataset
from typing import Optional
import math
import torch.distributed as dist


class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
    """
    A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
    Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
                 shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:

        if not shuffle:
            raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")

        # Check if the dataset has a generate_class_balanced_indices method
        if not hasattr(dataset, 'generate_class_balanced_indices'):
            raise ValueError("Dataset does not have a generate_class_balanced_indices method")

        self.shuffle = shuffle
        self.seed = seed
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        # Calculate the number of samples
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        self.num_samples_per_class = num_samples_per_class
        indices = dataset.generate_class_balanced_indices(torch.Generator(),
                                                          num_samples_per_class=num_samples_per_class)
        dataset_size = len(indices)

        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (dataset_size - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(dataset_size / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
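A rough DDP usage sketch (not part of the commit): inside an initialized process group, and assuming `train_dataset` implements `generate_class_balanced_indices`, the sampler plugs into a `DataLoader` like the stock `DistributedSampler`, with `set_epoch` called once per epoch so each epoch draws a fresh class-balanced subset:

from torch.utils.data import DataLoader

# `train_dataset` is a hypothetical dataset exposing generate_class_balanced_indices().
sampler = ClassBalancedDistributedSampler(train_dataset, shuffle=True, num_samples_per_class=100)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)

for epoch in range(100):
    sampler.set_epoch(epoch)  # re-seeds the per-epoch generator on every rank
    for images, targets in loader:
        pass  # training step goes here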
    	
utils/data_utils/class_balanced_sampler.py ADDED
@@ -0,0 +1,31 @@
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            from torch.utils.data import Dataset
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            class ClassBalancedRandomSampler(torch.utils.data.Sampler):
         
     | 
| 6 | 
         
            +
                """
         
     | 
| 7 | 
         
            +
                A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
         
     | 
| 8 | 
         
            +
                This is essentially the non-ddp version of ClassBalancedDistributedSampler
         
     | 
| 9 | 
         
            +
                Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
         
     | 
| 10 | 
         
            +
                """
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
                def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
         
     | 
| 13 | 
         
            +
                    self.dataset = dataset
         
     | 
| 14 | 
         
            +
                    self.seed = seed
         
     | 
| 15 | 
         
            +
                    # Calculate the number of samples
         
     | 
| 16 | 
         
            +
                    self.generator = torch.Generator()
         
     | 
| 17 | 
         
            +
                    self.generator.manual_seed(self.seed)
         
     | 
| 18 | 
         
            +
                    self.num_samples_per_class = num_samples_per_class
         
     | 
| 19 | 
         
            +
                    indices = dataset.generate_class_balanced_indices(self.generator,
         
     | 
| 20 | 
         
            +
                                                                      num_samples_per_class=num_samples_per_class)
         
     | 
| 21 | 
         
            +
                    self.num_samples = len(indices)
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                def __iter__(self):
         
     | 
| 24 | 
         
            +
                    # Change seed for every function call
         
     | 
| 25 | 
         
            +
                    seed = int(torch.empty((), dtype=torch.int64).random_().item())
         
     | 
| 26 | 
         
            +
                    self.generator.manual_seed(seed)
         
     | 
| 27 | 
         
            +
                    indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
         
     | 
| 28 | 
         
            +
                    return iter(indices)
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                def __len__(self) -> int:
         
     | 
| 31 | 
         
            +
                    return self.num_samples
         
     | 
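
A usage sketch, not part of the commit: the dataset handed to ClassBalancedRandomSampler must implement generate_class_balanced_indices(generator, num_samples_per_class=...), since the sampler delegates the actual sub-sampling to it; train_dataset is a placeholder.

from torch.utils.data import DataLoader

sampler = ClassBalancedRandomSampler(train_dataset, num_samples_per_class=100, seed=0)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
print(len(sampler))  # number of indices in one class-balanced pass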
    	
utils/data_utils/dataset_utils.py
ADDED
@@ -0,0 +1,161 @@
+from PIL import Image
+from torch import Tensor
+from typing import List, Optional
+import numpy as np
+import torchvision
+import json
+
+
+def load_json(path: str):
+    """
+    Load a json file from path and return the data
+    :param path: Path to the json file
+    :return:
+    data: Data in the json file
+    """
+    with open(path, 'r') as f:
+        data = json.load(f)
+    return data
+
+
+def save_json(data: dict, path: str):
+    """
+    Save data to a json file
+    :param data: Data to be saved
+    :param path: Path to save the data
+    :return:
+    """
+    with open(path, "w") as f:
+        json.dump(data, f)
+
+
+def pil_loader(path):
+    """
+    Load an image from path using PIL
+    :param path: Path to the image
+    :return:
+    img: PIL Image
+    """
+    with open(path, 'rb') as f:
+        img = Image.open(f)
+        return img.convert('RGB')
+
+
+def get_dimensions(image: Tensor):
+    """
+    Get the dimensions of the image
+    :param image: Tensor or PIL Image or np.ndarray
+    :return:
+    h: Height of the image
+    w: Width of the image
+    """
+    if isinstance(image, Tensor):
+        _, h, w = image.shape
+    elif isinstance(image, np.ndarray):
+        h, w, _ = image.shape
+    elif isinstance(image, Image.Image):
+        w, h = image.size
+    else:
+        raise ValueError(f"Invalid image type: {type(image)}")
+    return h, w
+
+
+def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
+                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
+    """
+    Calculate the center crop parameters for the bounding boxes and landmarks and update them
+    :param img: Image
+    :param output_size: Output size of the cropped image
+    :param parts: Locations of the landmarks in the following format: <part_id> <x> <y> <visible>
+    :param boxes: Bounding boxes of the landmarks in the following format: <image_id> <x> <y> <width> <height>
+    :param num_keypoints: Number of keypoints
+    :return:
+    cropped_img: Center cropped image
+    parts: Updated locations of the landmarks
+    boxes: Updated bounding boxes of the landmarks
+    """
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+        output_size = (output_size[0], output_size[0])
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
+        output_size = output_size
+    else:
+        raise ValueError(f"Invalid output size: {output_size}")
+
+    crop_height, crop_width = output_size
+    image_height, image_width = get_dimensions(img)
+    img = torchvision.transforms.functional.center_crop(img, output_size)
+
+    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)
+
+    if parts is not None:
+        for j in range(num_keypoints):
+            # Skip if part is invisible
+            if parts[j][-1] == 0:
+                continue
+            parts[j][1] -= crop_left
+            parts[j][2] -= crop_top
+
+            # Mark the part as invisible if it falls outside the crop
+            if parts[j][1] > crop_width or parts[j][2] > crop_height:
+                parts[j][-1] = 0
+            if parts[j][1] < 0 or parts[j][2] < 0:
+                parts[j][-1] = 0
+
+            parts[j][1] = min(crop_width, parts[j][1])
+            parts[j][2] = min(crop_height, parts[j][2])
+            parts[j][1] = max(0, parts[j][1])
+            parts[j][2] = max(0, parts[j][2])
+
+    if boxes is not None:
+        boxes[1] -= crop_left
+        boxes[2] -= crop_top
+        boxes[1] = max(0, boxes[1])
+        boxes[2] = max(0, boxes[2])
+        boxes[1] = min(crop_width, boxes[1])
+        boxes[2] = min(crop_height, boxes[2])
+
+    return img, parts, boxes
+
+
+def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
+    """
+    Get the parameters for center cropping the image
+    :param image_height: Height of the image
+    :param image_width: Width of the image
+    :param output_size: Output size of the cropped image
+    :return:
+    crop_top: Top coordinate of the cropped image
+    crop_left: Left coordinate of the cropped image
+    """
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+        output_size = (output_size[0], output_size[0])
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
+        output_size = output_size
+    else:
+        raise ValueError(f"Invalid output size: {output_size}")
+
+    crop_height, crop_width = output_size
+
+    if crop_width > image_width or crop_height > image_height:
+        padding_ltrb = [
+            (crop_width - image_width) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height) // 2 if crop_height > image_height else 0,
+            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+        ]
+        crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
+        return crop_top, crop_left
+
+    if crop_width == image_width and crop_height == image_height:
+        crop_top = 0
+        crop_left = 0
+        return crop_top, crop_left
+
+    crop_top = int(round((image_height - crop_height) / 2.0))
+    crop_left = int(round((image_width - crop_width) / 2.0))
+
+    return crop_top, crop_left
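
An illustrative call, assuming the repository root is on PYTHONPATH; the image, keypoints, and box below are made-up values that follow the formats documented in center_crop_boxes_kps.

import torch
from utils.data_utils.dataset_utils import center_crop_boxes_kps

img = torch.rand(3, 500, 600)  # C x H x W
parts = torch.tensor([[float(j), 300.0, 250.0, 1.0] for j in range(15)])  # <part_id> <x> <y> <visible>
boxes = torch.tensor([0.0, 100.0, 120.0, 200.0, 180.0])  # <image_id> <x> <y> <width> <height>
cropped, parts, boxes = center_crop_boxes_kps(img, output_size=448, parts=parts, boxes=boxes)
print(cropped.shape)  # torch.Size([3, 448, 448]); coordinates are now relative to the crop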
    	
utils/data_utils/reversible_affine_transform.py
ADDED
@@ -0,0 +1,82 @@
+# Description: This file contains the code for the reversible affine transform
+import torchvision.transforms as transforms
+import torch
+from typing import List, Optional, Tuple, Any
+
+
+def generate_affine_trans_params(
+        degrees: List[float],
+        translate: Optional[List[float]],
+        scale_ranges: Optional[List[float]],
+        shears: Optional[List[float]],
+        img_size: List[int],
+) -> Tuple[float, Tuple[int, int], float, Any]:
+    """Get parameters for affine transformation
+
+    Returns:
+        params to be passed to the affine transformation
+    """
+    angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
+    if translate is not None:
+        max_dx = float(translate[0] * img_size[0])
+        max_dy = float(translate[1] * img_size[1])
+        tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
+        ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
+        translations = (tx, ty)
+    else:
+        translations = (0, 0)
+
+    if scale_ranges is not None:
+        scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
+    else:
+        scale = 1.0
+
+    shear_x = shear_y = 0.0
+    if shears is not None:
+        shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
+        if len(shears) == 4:
+            shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
+
+    shear = (shear_x, shear_y)
+    if shear_x == 0.0 and shear_y == 0.0:
+        shear = 0.0
+
+    return angle, translations, scale, shear
+
+
+def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
+                    interpolation=transforms.InterpolationMode.BILINEAR):
+    """
+    Affine transforms the input image
+    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
+    Parameters
+    ----------
+    img: Tensor
+        Input image
+    angle: int
+        Rotation angle between -180 and 180 degrees
+    translate: [int]
+        Sequence of horizontal/vertical translations
+    scale: float
+        How to scale the image
+    invert: bool
+        Whether to invert the transformation
+    shear: float
+        Shear angle in degrees
+    interpolation: InterpolationMode
+        Interpolation mode to calculate output values
+    Returns
+    ----------
+    img: Tensor
+        Transformed image
+
+    """
+    if not invert:
+        img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
+                                           interpolation=interpolation)
+    else:
+        translate = [-t for t in translate]
+        img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
+        img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)
+
+    return img
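
A sketch of the intended forward/inverse round trip, assuming the repository root is on PYTHONPATH; the random image and parameter ranges are placeholders.

import torch
from utils.data_utils.reversible_affine_transform import generate_affine_trans_params, rigid_transform

img = torch.rand(3, 448, 448)
angle, translations, scale, shear = generate_affine_trans_params(
    degrees=[-90.0, 90.0], translate=[0.2, 0.2], scale_ranges=[0.8, 1.2], shears=None, img_size=[448, 448])
augmented = rigid_transform(img, angle, list(translations), scale, invert=False, shear=shear)
restored = rigid_transform(augmented, angle, list(translations), scale, invert=True, shear=shear)
# restored approximates img up to interpolation error and pixels pushed outside the frame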
    	
utils/data_utils/transform_utils.py
ADDED
@@ -0,0 +1,118 @@
+import torch
+from torchvision import transforms as transforms
+from torchvision.transforms import Compose
+
+from timm.data.constants import \
+    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.data import create_transform
+
+
+def make_train_transforms(args):
+    train_transforms: Compose = transforms.Compose([
+        transforms.Resize(size=args.image_size, antialias=True),
+        transforms.RandomHorizontalFlip(p=args.hflip),
+        transforms.RandomVerticalFlip(p=args.vflip),
+        transforms.ColorJitter(),
+        transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)),
+        transforms.RandomCrop(args.image_size),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+    ])
+    return train_transforms
+
+
+def make_test_transforms(args):
+    test_transforms: Compose = transforms.Compose([
+        transforms.Resize(size=args.image_size, antialias=True),
+        transforms.CenterCrop(args.image_size),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+    ])
+    return test_transforms
+
+
+def build_transform_timm(args, is_train=True):
+    resize_im = args.image_size > 32
+    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
+    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
+    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
+
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=args.image_size,
+            is_training=True,
+            color_jitter=args.color_jitter,
+            hflip=args.hflip,
+            vflip=args.vflip,
+            auto_augment=args.aa,
+            interpolation=args.train_interpolation,
+            re_prob=args.reprob,
+            re_mode=args.remode,
+            re_count=args.recount,
+            mean=mean,
+            std=std,
+        )
+        if not resize_im:
+            transform.transforms[0] = transforms.RandomCrop(
+                args.image_size, padding=4)
+        return transform
+
+    t = []
+    if resize_im:
+        # warping (no cropping) when evaluated at 384 or larger
+        if args.image_size >= 384:
+            t.append(
+                transforms.Resize((args.image_size, args.image_size),
+                                  interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
+            )
+            print(f"Warping {args.image_size} size input images...")
+        else:
+            if args.crop_pct is None:
+                args.crop_pct = 224 / 256
+            size = int(args.image_size / args.crop_pct)
+            t.append(
+                # to maintain same ratio w.r.t. 224 images
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
+            )
+            t.append(transforms.CenterCrop(args.image_size))
+
+    t.append(transforms.ToTensor())
+    t.append(transforms.Normalize(mean, std))
+    return transforms.Compose(t)
+
+
+def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
+    mean = torch.as_tensor(mean)
+    std = torch.as_tensor(std)
+    un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
+    return un_normalize
+
+
+def normalize_only(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
+    normalize = transforms.Normalize(mean=mean, std=std)
+    return normalize
+
+
+def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+                               resize_resolution=(256, 256)):
+    mean = torch.as_tensor(mean)
+    std = torch.as_tensor(std)
+    resize_unnorm = transforms.Compose([
+        transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
+        transforms.Resize(size=resize_resolution, antialias=True)])
+    return resize_unnorm
+
+
+def load_transforms(args):
+    # Get the transforms and load the dataset
+    if args.augmentations_to_use == 'timm':
+        train_transforms = build_transform_timm(args, is_train=True)
+    elif args.augmentations_to_use == 'cub_original':
+        train_transforms = make_train_transforms(args)
+    else:
+        raise ValueError('Augmentations not supported.')
+    test_transforms = make_test_transforms(args)
+    return train_transforms, test_transforms
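
For reference, a hypothetical args object for the 'cub_original' path, which only reads image_size, hflip, and vflip; the 'timm' path reads many more attributes (color_jitter, aa, train_interpolation, reprob, remode, recount, imagenet_default_mean_and_std). The values below are placeholders, not the project's defaults.

from types import SimpleNamespace

args = SimpleNamespace(image_size=448, hflip=0.5, vflip=0.0, augmentations_to_use='cub_original')
train_transforms, test_transforms = load_transforms(args)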
    	
utils/get_landmark_coordinates.py
ADDED
@@ -0,0 +1,41 @@
+# This file contains the function to generate the center coordinates as tensor for the current net.
+import torch
+
+
+def landmark_coordinates(maps, grid_x=None, grid_y=None):
+    """
+    Generate the center coordinates as tensor for the current net.
+    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19
+    Parameters
+    ----------
+    maps: torch.Tensor
+        Attention map with shape (batch_size, channels, height, width) where each channel is a landmark probability map
+    grid_x: torch.Tensor
+        The grid x coordinates
+    grid_y: torch.Tensor
+        The grid y coordinates
+    Returns
+    ----------
+    loc_x: Tensor
+        The centroid x coordinates
+    loc_y: Tensor
+        The centroid y coordinates
+    grid_x: Tensor
+    grid_y: Tensor
+    """
+    return_grid = False
+    if grid_x is None or grid_y is None:
+        return_grid = True
+        grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]),
+                                        torch.arange(maps.shape[3]), indexing='ij')
+        grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
+        grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
+    map_sums = maps.sum(3).sum(2).detach()
+    maps_x = grid_x * maps
+    maps_y = grid_y * maps
+    loc_x = maps_x.sum(3).sum(2) / map_sums
+    loc_y = maps_y.sum(3).sum(2) / map_sums
+    if return_grid:
+        return loc_x, loc_y, grid_x, grid_y
+    else:
+        return loc_x, loc_y
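
A quick shape check, assuming the repository root is on PYTHONPATH; the softmaxed random tensor stands in for the model's part-assignment maps.

import torch
from utils.get_landmark_coordinates import landmark_coordinates

maps = torch.softmax(torch.rand(2, 9, 28, 28), dim=1)  # B x K x H x W part probability maps
loc_x, loc_y, grid_x, grid_y = landmark_coordinates(maps)
print(loc_x.shape, loc_y.shape)  # torch.Size([2, 9]) torch.Size([2, 9])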
    	
        utils/misc_utils.py
    ADDED
    
    | 
         @@ -0,0 +1,135 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import math
         
     | 
| 2 | 
         
            +
            from functools import reduce
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
            import os
         
     | 
| 7 | 
         
            +
            from pathlib import Path
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            def factors(n):
         
     | 
| 11 | 
         
            +
                return reduce(list.__add__,
         
     | 
| 12 | 
         
            +
                              ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            def file_line_count(filename: str) -> int:
         
     | 
| 16 | 
         
            +
                """Count the number of lines in a file"""
         
     | 
| 17 | 
         
            +
                with open(filename, 'rb') as f:
         
     | 
| 18 | 
         
            +
                    return sum(1 for _ in f)
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            def compute_attention(qkv, scale=None):
         
     | 
| 22 | 
         
            +
                """
         
     | 
| 23 | 
         
            +
                Compute attention matrix (same as in the pytorch scaled dot product attention)
         
     | 
| 24 | 
         
            +
                Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
         
     | 
| 25 | 
         
            +
                :param qkv: Query, key and value tensors concatenated along the first dimension
         
     | 
| 26 | 
         
            +
                :param scale: Scale factor for the attention computation
         
     | 
| 27 | 
         
            +
                :return:
         
     | 
| 28 | 
         
            +
                """
         
     | 
| 29 | 
         
            +
                if isinstance(qkv, torch.Tensor):
         
     | 
| 30 | 
         
            +
                    query, key, value = qkv.unbind(0)
         
     | 
| 31 | 
         
            +
                else:
         
     | 
| 32 | 
         
            +
                    query, key, value = qkv
         
     | 
| 33 | 
         
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    L, S = query.size(-2), key.size(-2)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_out = attn_weight @ value
    return attn_weight, attn_out


def compute_dot_product_similarity(a, b):
    scores = a @ b.transpose(-1, -2)
    return scores


def compute_cross_entropy(p, q):
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return -loss.mean()


def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
    """
    Perform attention rollout.
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
    Parameters
    ----------
    attentions : list
        List of attention matrices, one for each transformer layer
    discard_ratio : float
        Ratio of lowest attention values to discard
    head_fusion : str
        Type of fusion to use for attention heads. One of "mean", "max", "min"
    device : torch.device
        Device to use for computation
    Returns
    -------
    result : torch.Tensor
        Rollout attention matrix of shape (batch_size, num_tokens, num_tokens); its CLS-token row can be
        reshaped into a (width, width) mask, where width is the square root of the number of patches
    """
    result = torch.eye(attentions[0].size(-1), device=device)
    attentions = [attention.to(device) for attention in attentions]
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1).values
            else:
                raise ValueError("Attention head fusion type not supported")

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1), device=device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

    # Normalize the result by max value in each row
    result = result / result.max(dim=-1, keepdim=True)[0]
    return result


def sync_bn_conversion(model: torch.nn.Module):
    """
    Convert BatchNorm to SyncBatchNorm (used for DDP)
    :param model: PyTorch model
    :return:
    model: PyTorch model with SyncBatchNorm layers
    """
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model


def check_snapshot(args):
    """
    Create directory to save training checkpoints, otherwise load the existing checkpoint.
    Additionally, if it is an array training job, create a new directory for each training job.
    :param args: Arguments from the argument parser
    :return:
    """
    # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
    if args.array_training_job and not args.resume_training:
        args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
        if not os.path.exists(args.snapshot_dir):
            save_dir = Path(args.snapshot_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
    else:
        # Create directory to save training checkpoints, otherwise load the existing checkpoint
        if not os.path.exists(args.snapshot_dir):
            if ".pt" not in args.snapshot_dir and ".pth" not in args.snapshot_dir:
                save_dir = Path(args.snapshot_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
            else:
                raise ValueError('Snapshot checkpoint does not exist.')
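
A minimal usage sketch for the rollout helper above (illustrative only: the layer count, head count, and token count are assumed, and the import path assumes the function is exposed from utils.misc_utils):

import torch

from utils.misc_utils import rollout  # assumed import path

# Hypothetical per-layer ViT attention maps: 12 layers, batch of 1, 6 heads, 197 tokens (CLS + 14x14 patches)
attentions = [torch.softmax(torch.randn(1, 6, 197, 197), dim=-1) for _ in range(12)]

# Fuse heads with "max", discard the weakest 90% of attention values, then multiply the layers together
mask = rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cpu"))
print(mask.shape)  # torch.Size([1, 197, 197]); the first row of mask[0] is the CLS-token attention after rollout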
    	
utils/visualize_att_maps.py
ADDED
@@ -0,0 +1,135 @@
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import colorcet as cc
import numpy as np
import skimage
from pathlib import Path
import os
import torch

from utils.data_utils.transform_utils import inverse_normalize_w_resize
from utils.misc_utils import factors

# Define the colors to use for the attention maps
colors = cc.glasbey_category10


class VisualizeAttentionMaps:
    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
                 dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
                 plot_landmark_amaps=False):
        """
        Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory to save the visualization results
        :param save_resolution: Size of the images to save
        :param alpha: The transparency of the attention maps
        :param sub_path_test: The sub-path of the test dataset
        :param dataset_name: The name of the dataset
        :param bg_label: The background label index in the attention maps
        :param batch_size: The batch size
        :param num_parts: The number of parts in the attention maps
        :param plot_ims_separately: Whether to plot the images separately
        :param plot_landmark_amaps: Whether to plot the landmark attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.sub_path_test = sub_path_test
        self.dataset_name = dataset_name
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir

        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.batch_size = batch_size
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        self.num_parts = num_parts
        self.req_colors = colors[:num_parts]
        self.plot_ims_separately = plot_ims_separately
        self.plot_landmark_amaps = plot_landmark_amaps
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    def recalculate_nrows_ncols(self):
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    @torch.no_grad()
    def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
        """
        Plot images, attention maps and landmark centroids.
        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
        epoch: int
            The epoch number
        curr_iter: int
            The current iteration number
        extra_info: str
            Any extra information to add to the file name
        """
        ims = self.resize_unnorm(ims)
        if ims.shape[0] != self.batch_size:
            self.batch_size = ims.shape[0]
            self.recalculate_nrows_ncols()
        fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
                                                     mode='bilinear',
                                                     align_corners=True).argmax(dim=1).cpu().numpy()
        for i, ax in enumerate(axs.ravel()):
            curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
                                               bg_label=self.bg_label, alpha=self.alpha)
            ax.imshow(curr_map)
            ax.axis('off')
        save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
        fig.tight_layout()
        if self.snapshot_dir != "":
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        else:
            plt.show()
        plt.close('all')

        if self.plot_ims_separately:
            fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
            for i, ax in enumerate(axs.ravel()):
                ax.imshow(ims[i])
                ax.axis('off')
            save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
            fig.tight_layout()
            if self.snapshot_dir != "":
                plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
            else:
                plt.show()
        plt.close('all')

        if self.plot_landmark_amaps:
            if self.batch_size > 1:
                raise ValueError('Not implemented for batch size > 1')
            for i in range(self.num_parts):
                fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes('right', size='5%', pad=0.05)
                im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
                fig.colorbar(im, cax=cax, orientation='vertical')
                ax.axis('off')
                save_path = os.path.join(save_dir,
                                         f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
                fig.tight_layout()
                if self.snapshot_dir != "":
                    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
                else:
                    plt.show()
                plt.close()

        plt.close('all')
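
The per-image overlay that show_maps builds can be sketched in isolation with the same library calls (the image and map shapes below are arbitrary assumptions, and random data stands in for real images and attention maps):

import matplotlib.pyplot as plt
import numpy as np
import skimage
import torch

# Hypothetical inputs: one 256x256 RGB image and 15 part maps plus 1 background map at lower resolution
im = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
maps = torch.softmax(torch.randn(1, 16, 32, 32), dim=1)

# Upsample the maps to the image resolution and assign each pixel to its strongest part
map_argmax = torch.nn.functional.interpolate(maps, size=(256, 256), mode='bilinear',
                                             align_corners=True).argmax(dim=1)[0].numpy()

# Colorize the part assignment and blend it over the image; label 0 is treated as background
overlay = skimage.color.label2rgb(label=map_argmax, image=im, bg_label=0, alpha=0.5)
plt.imshow(overlay)
plt.axis('off')
plt.show()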