smhh24 commited on
Commit
560b597
1 Parent(s): e26c454

Upload 90 files

Browse files

Add Initial files

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +104 -0
  2. configs/config_v1_cnvnxtl.json +24 -0
  3. configs/config_v1_vitl14.json +23 -0
  4. configs/config_v2_vitl14.json +32 -0
  5. configs/config_v2_vits14.json +32 -0
  6. unidepth/layers/__init__.py +22 -0
  7. unidepth/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  8. unidepth/layers/__pycache__/activation.cpython-311.pyc +0 -0
  9. unidepth/layers/__pycache__/attention.cpython-311.pyc +0 -0
  10. unidepth/layers/__pycache__/convnext.cpython-311.pyc +0 -0
  11. unidepth/layers/__pycache__/layer_scale.cpython-311.pyc +0 -0
  12. unidepth/layers/__pycache__/mlp.cpython-311.pyc +0 -0
  13. unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc +0 -0
  14. unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc +0 -0
  15. unidepth/layers/__pycache__/upsample.cpython-311.pyc +0 -0
  16. unidepth/layers/activation.py +15 -0
  17. unidepth/layers/attention.py +308 -0
  18. unidepth/layers/convnext.py +44 -0
  19. unidepth/layers/drop_path.py +25 -0
  20. unidepth/layers/layer_scale.py +17 -0
  21. unidepth/layers/mlp.py +35 -0
  22. unidepth/layers/nystrom_attention.py +74 -0
  23. unidepth/layers/positional_encoding.py +227 -0
  24. unidepth/layers/upsample.py +134 -0
  25. unidepth/models/__init__.py +7 -0
  26. unidepth/models/__pycache__/__init__.cpython-311.pyc +0 -0
  27. unidepth/models/__pycache__/encoder.cpython-311.pyc +0 -0
  28. unidepth/models/backbones/__init__.py +9 -0
  29. unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc +0 -0
  30. unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc +0 -0
  31. unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc +0 -0
  32. unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc +0 -0
  33. unidepth/models/backbones/convnext.py +580 -0
  34. unidepth/models/backbones/convnext2.py +288 -0
  35. unidepth/models/backbones/dinov2.py +455 -0
  36. unidepth/models/backbones/metadinov2/__init__.py +12 -0
  37. unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc +0 -0
  38. unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc +0 -0
  39. unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc +0 -0
  40. unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc +0 -0
  41. unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc +0 -0
  42. unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc +0 -0
  43. unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc +0 -0
  44. unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc +0 -0
  45. unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc +0 -0
  46. unidepth/models/backbones/metadinov2/attention.py +84 -0
  47. unidepth/models/backbones/metadinov2/block.py +282 -0
  48. unidepth/models/backbones/metadinov2/dino_head.py +68 -0
  49. unidepth/models/backbones/metadinov2/drop_path.py +37 -0
  50. unidepth/models/backbones/metadinov2/layer_scale.py +28 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import json
6
+ from unidepth.models import UniDepthV2
7
+ import os
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib
10
+ from PIL import Image
11
+
12
+
13
+ # Load model configurations and initialize model
14
+ def load_model(config_path, model_path, encoder, device):
15
+ with open(config_path) as f:
16
+ config = json.load(f)
17
+
18
+ model = UniDepthV2(config)
19
+ model.load_state_dict(torch.load(model_path, map_location=device)['model'], strict=True)
20
+ model = model.to(device).eval()
21
+
22
+ return model
23
+
24
+ # Inference function
25
+ def depth_estimation(image, model_path, encoder='vits'):
26
+ try:
27
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
28
+ config_path = 'configs/config_v2_vits14.json'
29
+
30
+ # Ensure model path exists or download if needed
31
+ if not os.path.exists(model_path):
32
+ return "Model checkpoint not found. Please upload a valid model path."
33
+
34
+ model = load_model(config_path, model_path, encoder, device)
35
+
36
+ # Preprocess image
37
+ rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device) # C, H, W
38
+ predictions = model.infer(rgb)
39
+ depth = predictions["depth"].squeeze().to('cpu').numpy()
40
+
41
+ min_depth = depth.min()
42
+ max_depth = depth.max()
43
+
44
+ depth_normalized = (depth - min_depth) / (max_depth - min_depth)
45
+
46
+ # Apply colormap
47
+ cmap = matplotlib.colormaps.get_cmap('Spectral')
48
+ depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
49
+
50
+ # Create a figure and axis for the colorbar
51
+ fig, ax = plt.subplots(figsize=(6, 0.4))
52
+ fig.subplots_adjust(bottom=0.5)
53
+
54
+ # Create a colorbar
55
+ norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
56
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
57
+ sm.set_array([])
58
+ cbar = fig.colorbar(sm, cax=ax, orientation='horizontal', label='Depth (meters)')
59
+
60
+ # Save the colorbar to a BytesIO object
61
+ from io import BytesIO
62
+ buf = BytesIO()
63
+ fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
64
+ plt.close(fig)
65
+ buf.seek(0)
66
+
67
+ # Open the colorbar image
68
+ colorbar_img = Image.open(buf)
69
+
70
+ # Create a new image with space for the colorbar
71
+ new_height = depth_color.shape[0] + colorbar_img.size[1]
72
+ new_img = Image.new('RGB', (depth_color.shape[1], new_height), (255, 255, 255))
73
+
74
+ # Paste the depth image and colorbar
75
+ new_img.paste(Image.fromarray(depth_color), (0, 0))
76
+ new_img.paste(colorbar_img, (0, depth_color.shape[0]))
77
+
78
+ return new_img
79
+
80
+
81
+ except Exception as e:
82
+ return f"Error occurred: {str(e)}"
83
+
84
+ # Gradio Interface
85
+ def main():
86
+ iface = gr.Interface(
87
+ fn=depth_estimation,
88
+ inputs=[
89
+ gr.Image(type="numpy", label="Input Image"),
90
+ gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
91
+ gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
92
+ ],
93
+ outputs=[
94
+ gr.Image(type="pil", label="Predicted Depth")
95
+ ],
96
+ title="Depth Anything V2 Metric Depth Estimation",
97
+ description="Upload an image to get its estimated depth map using Depth Anything V2.",
98
+ )
99
+
100
+ iface.launch()
101
+
102
+
103
+ if __name__ == "__main__":
104
+ main()
configs/config_v1_cnvnxtl.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "generic": {
3
+ "seed": 13
4
+ },
5
+ "training": {
6
+ },
7
+ "data": {
8
+ "image_shape": [462, 616]
9
+ },
10
+ "model": {
11
+ "name": "UniDepthV1",
12
+ "num_heads": 8,
13
+ "expansion": 4,
14
+ "pixel_decoder": {
15
+ "hidden_dim": 512,
16
+ "depths": [3, 2, 1],
17
+ "dropout": 0.0
18
+ },
19
+ "pixel_encoder": {
20
+ "name": "convnext_large",
21
+ "pretrained": null
22
+ }
23
+ }
24
+ }
configs/config_v1_vitl14.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "generic": {
3
+ "seed": 13
4
+ },
5
+ "training": {},
6
+ "data": {
7
+ "image_shape": [462, 616]
8
+ },
9
+ "model": {
10
+ "name": "UniDepthV1",
11
+ "num_heads": 8,
12
+ "expansion": 4,
13
+ "pixel_decoder": {
14
+ "hidden_dim": 512,
15
+ "depths": [3, 2, 1],
16
+ "dropout": 0.0
17
+ },
18
+ "pixel_encoder": {
19
+ "name": "dinov2_vitl14",
20
+ "pretrained": null
21
+ }
22
+ }
23
+ }
configs/config_v2_vitl14.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "generic": {
3
+ "seed": 13,
4
+ "deterministic": true
5
+ },
6
+ "training": {},
7
+ "data": {
8
+ "image_shape": [420, 560],
9
+ "shape_constraints": {
10
+ "ratio_bounds": [0.66, 2.0],
11
+ "pixels_bounds": [1400, 2400],
12
+ "patch_size": 14
13
+ }
14
+ },
15
+ "model": {
16
+ "name": "UniDepthV2",
17
+ "num_heads": 8,
18
+ "expansion": 4,
19
+ "pixel_decoder": {
20
+ "hidden_dim": 512,
21
+ "depths": [6, 0, 0],
22
+ "dropout": 0.0
23
+ },
24
+ "pixel_encoder": {
25
+ "name": "dinov2_vitl14",
26
+ "pretrained": null,
27
+ "use_norm": true,
28
+ "stacking_fn": "last",
29
+ "output_idx": [21,22,23,24]
30
+ }
31
+ }
32
+ }
configs/config_v2_vits14.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "generic": {
3
+ "seed": 13,
4
+ "deterministic": true
5
+ },
6
+ "training": {},
7
+ "data": {
8
+ "image_shape": [420, 560],
9
+ "shape_constraints": {
10
+ "ratio_bounds": [0.66, 2.0],
11
+ "pixels_bounds": [1400, 2400],
12
+ "patch_size": 14
13
+ }
14
+ },
15
+ "model": {
16
+ "name": "UniDepthV2",
17
+ "num_heads": 8,
18
+ "expansion": 4,
19
+ "pixel_decoder": {
20
+ "hidden_dim": 512,
21
+ "depths": [6, 0, 0],
22
+ "dropout": 0.0
23
+ },
24
+ "pixel_encoder": {
25
+ "name": "dinov2_vits14",
26
+ "pretrained": null,
27
+ "use_norm": true,
28
+ "stacking_fn": "last",
29
+ "output_idx": [9,10,11,12]
30
+ }
31
+ }
32
+ }
unidepth/layers/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .activation import GEGLU, SwiGLU
2
+ from .attention import AttentionBlock, AttentionDecoderBlock
3
+ from .convnext import CvnxtBlock
4
+ from .mlp import MLP
5
+ from .nystrom_attention import NystromBlock
6
+ from .positional_encoding import PositionEmbeddingSine
7
+ from .upsample import (ConvUpsample, ConvUpsampleShuffle,
8
+ ConvUpsampleShuffleResidual)
9
+
10
+ __all__ = [
11
+ "SwiGLU",
12
+ "GEGLU",
13
+ "CvnxtBlock",
14
+ "AttentionBlock",
15
+ "NystromBlock",
16
+ "PositionEmbeddingSine",
17
+ "ConvUpsample",
18
+ "MLP",
19
+ "ConvUpsampleShuffle",
20
+ "AttentionDecoderBlock",
21
+ "ConvUpsampleShuffleResidual",
22
+ ]
unidepth/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (845 Bytes). View file
 
unidepth/layers/__pycache__/activation.cpython-311.pyc ADDED
Binary file (1.38 kB). View file
 
unidepth/layers/__pycache__/attention.cpython-311.pyc ADDED
Binary file (14.5 kB). View file
 
unidepth/layers/__pycache__/convnext.cpython-311.pyc ADDED
Binary file (2.41 kB). View file
 
unidepth/layers/__pycache__/layer_scale.cpython-311.pyc ADDED
Binary file (1.53 kB). View file
 
unidepth/layers/__pycache__/mlp.cpython-311.pyc ADDED
Binary file (2.41 kB). View file
 
unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc ADDED
Binary file (3.57 kB). View file
 
unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
unidepth/layers/__pycache__/upsample.cpython-311.pyc ADDED
Binary file (6.19 kB). View file
 
unidepth/layers/activation.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SwiGLU(nn.Module):
7
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
8
+ x, gates = x.chunk(2, dim=-1)
9
+ return x * F.silu(gates)
10
+
11
+
12
+ class GEGLU(nn.Module):
13
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
14
+ x, gates = x.chunk(2, dim=-1)
15
+ return x * F.gelu(gates)
unidepth/layers/attention.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ from .layer_scale import LayerScale
14
+ from .mlp import MLP
15
+
16
+
17
+ class SimpleAttention(nn.Module):
18
+ def __init__(
19
+ self,
20
+ dim: int,
21
+ num_heads: int = 4,
22
+ dropout: float = 0.0,
23
+ cosine: bool = False,
24
+ context_dim: int | None = None,
25
+ ):
26
+ super().__init__()
27
+ self.dropout = dropout
28
+ self.num_heads = num_heads
29
+ self.hidden_dim = dim
30
+ context_dim = context_dim or dim
31
+
32
+ self.kv = nn.Linear(context_dim, dim * 2, bias=False)
33
+ self.q = nn.Linear(dim, dim, bias=False)
34
+ self.norm_attnx = nn.LayerNorm(dim)
35
+ self.norm_attnctx = nn.LayerNorm(context_dim)
36
+ self.cosine = cosine
37
+ self.out = nn.Linear(dim, dim)
38
+
39
+ def forward(
40
+ self,
41
+ x: torch.Tensor,
42
+ attn_bias: torch.Tensor | None = None,
43
+ context: torch.Tensor | None = None,
44
+ pos_embed: torch.Tensor | None = None,
45
+ pos_embed_context: torch.Tensor | None = None,
46
+ rope: nn.Module | None = None,
47
+ ) -> torch.Tensor:
48
+ context = x if context is None else context
49
+ x = self.norm_attnx(x)
50
+ context = self.norm_attnctx(context)
51
+ k, v = rearrange(
52
+ self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
53
+ ).unbind(dim=-1)
54
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
55
+
56
+ if rope is not None:
57
+ q = rope(q)
58
+ k = rope(k)
59
+ else:
60
+ if pos_embed is not None:
61
+ pos_embed = rearrange(
62
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
63
+ )
64
+ q = q + pos_embed
65
+ if pos_embed_context is not None:
66
+ pos_embed_context = rearrange(
67
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
68
+ )
69
+ k = k + pos_embed_context
70
+
71
+ if self.cosine:
72
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
73
+ x = F.scaled_dot_product_attention(
74
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
75
+ )
76
+ x = rearrange(x, "b h n d -> b n (h d)")
77
+ x = self.out(x)
78
+ return x
79
+
80
+
81
+ class AttentionBlock(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim: int,
85
+ num_heads: int = 4,
86
+ expansion: int = 4,
87
+ dropout: float = 0.0,
88
+ cosine: bool = False,
89
+ gated: bool = False,
90
+ layer_scale: float = 1.0,
91
+ context_dim: int | None = None,
92
+ ):
93
+ super().__init__()
94
+ self.dropout = dropout
95
+ self.num_heads = num_heads
96
+ self.hidden_dim = dim
97
+ context_dim = context_dim or dim
98
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
99
+ self.kv = nn.Linear(context_dim, dim * 2)
100
+ self.q = nn.Linear(dim, dim)
101
+ self.norm_attnx = nn.LayerNorm(dim)
102
+ self.norm_attnctx = nn.LayerNorm(context_dim)
103
+ self.cosine = cosine
104
+ self.out = nn.Linear(dim, dim)
105
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
106
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
107
+
108
+ def attn(
109
+ self,
110
+ x: torch.Tensor,
111
+ attn_bias: torch.Tensor | None = None,
112
+ context: torch.Tensor | None = None,
113
+ pos_embed: torch.Tensor | None = None,
114
+ pos_embed_context: torch.Tensor | None = None,
115
+ rope: nn.Module | None = None,
116
+ ) -> torch.Tensor:
117
+ x = self.norm_attnx(x)
118
+ context = self.norm_attnctx(context)
119
+ k, v = rearrange(
120
+ self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
121
+ ).unbind(dim=-1)
122
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
123
+
124
+ if rope is not None:
125
+ q = rope(q)
126
+ k = rope(k)
127
+ else:
128
+ if pos_embed is not None:
129
+ pos_embed = rearrange(
130
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
131
+ )
132
+ q = q + pos_embed
133
+ if pos_embed_context is not None:
134
+ pos_embed_context = rearrange(
135
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
136
+ )
137
+ k = k + pos_embed_context
138
+
139
+ if self.cosine:
140
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
141
+
142
+ x = F.scaled_dot_product_attention(
143
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
144
+ )
145
+ x = rearrange(x, "b h n d -> b n (h d)")
146
+ x = self.out(x)
147
+ return x
148
+
149
+ def forward(
150
+ self,
151
+ x: torch.Tensor,
152
+ attn_bias: torch.Tensor | None = None,
153
+ context: torch.Tensor | None = None,
154
+ pos_embed: torch.Tensor | None = None,
155
+ pos_embed_context: torch.Tensor | None = None,
156
+ rope: nn.Module | None = None,
157
+ ) -> torch.Tensor:
158
+ context = x if context is None else context
159
+ x = (
160
+ self.ls1(
161
+ self.attn(
162
+ x,
163
+ rope=rope,
164
+ attn_bias=attn_bias,
165
+ context=context,
166
+ pos_embed=pos_embed,
167
+ pos_embed_context=pos_embed_context,
168
+ )
169
+ )
170
+ + x
171
+ )
172
+ x = self.ls2(self.mlp(x)) + x
173
+ return x
174
+
175
+
176
+ class AttentionDecoderBlock(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim: int,
180
+ num_heads: int = 4,
181
+ expansion: int = 4,
182
+ dropout: float = 0.0,
183
+ cosine: bool = False,
184
+ gated: bool = False,
185
+ layer_scale: float = 1.0,
186
+ context_dim: int | None = None,
187
+ single_head_ca: bool = True,
188
+ ):
189
+ super().__init__()
190
+ self.dropout = dropout
191
+ self.num_heads = num_heads
192
+ self.hidden_dim = dim
193
+ self.single_head_ca = single_head_ca
194
+ context_dim = context_dim or dim
195
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
196
+ self.kv_ca = nn.Linear(context_dim, dim * 2)
197
+ self.q_ca = nn.Linear(dim, dim)
198
+ self.kv_sa = nn.Linear(dim, dim * 2)
199
+ self.q_sa = nn.Linear(dim, dim)
200
+ self.norm_x_sa = nn.LayerNorm(dim)
201
+ self.norm_x_ca = nn.LayerNorm(dim)
202
+ self.norm_ctx_ca = nn.LayerNorm(context_dim)
203
+ self.cosine = cosine
204
+ self.out_ca = nn.Linear(dim, dim)
205
+ self.out_sa = nn.Linear(dim, dim)
206
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
207
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
208
+ self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
209
+
210
+ def cross_attn(
211
+ self,
212
+ x: torch.Tensor,
213
+ attn_bias: torch.Tensor | None = None,
214
+ context: torch.Tensor | None = None,
215
+ pos_embed: torch.Tensor | None = None,
216
+ pos_embed_context: torch.Tensor | None = None,
217
+ rope: nn.Module | None = None,
218
+ ) -> torch.Tensor:
219
+ num_heads = 1 if self.single_head_ca else self.num_heads
220
+ x = self.norm_x_ca(x)
221
+ context = self.norm_ctx_ca(context)
222
+ k, v = rearrange(
223
+ self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
224
+ ).unbind(dim=-1)
225
+ q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
226
+
227
+ if rope is not None:
228
+ q = rope(q)
229
+ k = rope(k)
230
+ else:
231
+ if pos_embed is not None:
232
+ pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
233
+ q = q + pos_embed
234
+ if pos_embed_context is not None:
235
+ pos_embed_context = rearrange(
236
+ pos_embed_context, "b n (h d) -> b h n d", h=num_heads
237
+ )
238
+ k = k + pos_embed_context
239
+
240
+ if self.cosine:
241
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
242
+ x = F.scaled_dot_product_attention(
243
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
244
+ )
245
+ x = rearrange(x, "b h n d -> b n (h d)")
246
+ x = self.out_ca(x)
247
+ return x
248
+
249
+ def self_attn(
250
+ self,
251
+ x: torch.Tensor,
252
+ attn_bias: torch.Tensor | None = None,
253
+ pos_embed: torch.Tensor | None = None,
254
+ rope: nn.Module | None = None,
255
+ ) -> torch.Tensor:
256
+ x = self.norm_x_sa(x)
257
+ k, v = rearrange(
258
+ self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
259
+ ).unbind(dim=-1)
260
+ q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
261
+
262
+ if rope is not None:
263
+ q = rope(q)
264
+ k = rope(k)
265
+ elif pos_embed is not None:
266
+ pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
267
+ q = q + pos_embed
268
+
269
+ if self.cosine:
270
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
271
+ x = F.scaled_dot_product_attention(
272
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
273
+ )
274
+ x = rearrange(x, "b h n d -> b n (h d)")
275
+ x = self.out_sa(x)
276
+ return x
277
+
278
+ def forward(
279
+ self,
280
+ x: torch.Tensor,
281
+ attn_bias: torch.Tensor | None = None,
282
+ context: torch.Tensor | None = None,
283
+ pos_embed: torch.Tensor | None = None,
284
+ pos_embed_context: torch.Tensor | None = None,
285
+ rope: nn.Module | None = None,
286
+ ) -> torch.Tensor:
287
+ context = x if context is None else context
288
+ x = (
289
+ self.ls1(
290
+ self.cross_attn(
291
+ x,
292
+ rope=rope,
293
+ attn_bias=attn_bias,
294
+ context=context,
295
+ pos_embed=pos_embed,
296
+ pos_embed_context=pos_embed_context,
297
+ )
298
+ )
299
+ + x
300
+ )
301
+ x = (
302
+ self.ls2(
303
+ self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
304
+ )
305
+ + x
306
+ )
307
+ x = self.ls3(self.mlp(x)) + x
308
+ return x
unidepth/layers/convnext.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class CvnxtBlock(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim,
9
+ kernel_size=7,
10
+ layer_scale=1.0,
11
+ expansion=4,
12
+ dilation=1,
13
+ padding_mode: str = "zeros",
14
+ ):
15
+ super().__init__()
16
+ self.dwconv = nn.Conv2d(
17
+ dim,
18
+ dim,
19
+ kernel_size=kernel_size,
20
+ padding=dilation * (kernel_size - 1) // 2,
21
+ groups=dim,
22
+ dilation=dilation,
23
+ padding_mode=padding_mode,
24
+ ) # depthwise conv
25
+ self.norm = nn.LayerNorm(dim)
26
+ self.pwconv1 = nn.Linear(dim, expansion * dim)
27
+ self.act = nn.GELU()
28
+ self.pwconv2 = nn.Linear(expansion * dim, dim)
29
+ self.gamma = (
30
+ nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0
31
+ )
32
+
33
+ def forward(self, x):
34
+ input = x
35
+ x = self.dwconv(x)
36
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
37
+ x = self.norm(x)
38
+ x = self.pwconv1(x)
39
+ x = self.act(x)
40
+ x = self.pwconv2(x)
41
+
42
+ x = self.gamma * x
43
+ x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
44
+ return x
unidepth/layers/drop_path.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
6
+ if drop_prob == 0.0 or not training:
7
+ return x
8
+ keep_prob = 1 - drop_prob
9
+ shape = (x.shape[0],) + (1,) * (
10
+ x.ndim - 1
11
+ ) # work with diff dim tensors, not just 2D ConvNets
12
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
13
+ if keep_prob > 0.0:
14
+ random_tensor.div_(keep_prob)
15
+ output = x * random_tensor
16
+ return output
17
+
18
+
19
+ class DropPath(nn.Module):
20
+ def __init__(self, drop_prob=None):
21
+ super(DropPath, self).__init__()
22
+ self.drop_prob = drop_prob
23
+
24
+ def forward(self, x):
25
+ return drop_path(x, self.drop_prob, self.training)
unidepth/layers/layer_scale.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LayerScale(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ init_values: float | torch.Tensor = 1e-5,
10
+ inplace: bool = False,
11
+ ) -> None:
12
+ super().__init__()
13
+ self.inplace = inplace
14
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
unidepth/layers/mlp.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from unidepth.utils.misc import default
5
+
6
+ from .activation import SwiGLU
7
+
8
+
9
+ class MLP(nn.Module):
10
+ def __init__(
11
+ self,
12
+ input_dim: int,
13
+ expansion: int = 4,
14
+ dropout: float = 0.0,
15
+ gated: bool = False,
16
+ output_dim: int | None = None,
17
+ ):
18
+ super().__init__()
19
+ if gated:
20
+ expansion = int(expansion * 2 / 3)
21
+ hidden_dim = int(input_dim * expansion)
22
+ output_dim = default(output_dim, input_dim)
23
+ self.norm = nn.LayerNorm(input_dim)
24
+ self.proj1 = nn.Linear(input_dim, hidden_dim)
25
+ self.proj2 = nn.Linear(hidden_dim, output_dim)
26
+ self.act = nn.GELU() if not gated else SwiGLU()
27
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ x = self.norm(x)
31
+ x = self.proj1(x)
32
+ x = self.act(x)
33
+ x = self.proj2(x)
34
+ x = self.dropout(x)
35
+ return x
unidepth/layers/nystrom_attention.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from xformers.components.attention import NystromAttention
8
+
9
+ from .attention import AttentionBlock
10
+
11
+
12
+ class NystromBlock(AttentionBlock):
13
+ def __init__(
14
+ self,
15
+ dim: int,
16
+ num_heads: int = 4,
17
+ expansion: int = 4,
18
+ dropout: float = 0.0,
19
+ cosine: bool = False,
20
+ gated: bool = False,
21
+ layer_scale: float = 1.0,
22
+ context_dim: int | None = None,
23
+ ):
24
+ super().__init__(
25
+ dim=dim,
26
+ num_heads=num_heads,
27
+ expansion=expansion,
28
+ dropout=dropout,
29
+ cosine=cosine,
30
+ gated=gated,
31
+ layer_scale=layer_scale,
32
+ context_dim=context_dim,
33
+ )
34
+ self.attention_fn = NystromAttention(
35
+ num_landmarks=128, num_heads=num_heads, dropout=dropout
36
+ )
37
+
38
+ def attn(
39
+ self,
40
+ x: torch.Tensor,
41
+ attn_bias: torch.Tensor | None = None,
42
+ context: torch.Tensor | None = None,
43
+ pos_embed: torch.Tensor | None = None,
44
+ pos_embed_context: torch.Tensor | None = None,
45
+ rope: nn.Module | None = None,
46
+ ) -> torch.Tensor:
47
+ x = self.norm_attnx(x)
48
+ context = self.norm_attnctx(context)
49
+ k, v = rearrange(
50
+ self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2
51
+ ).unbind(dim=-1)
52
+ q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
53
+
54
+ if rope is not None:
55
+ q = rope(q)
56
+ k = rope(k)
57
+ else:
58
+ if pos_embed is not None:
59
+ pos_embed = rearrange(
60
+ pos_embed, "b n (h d) -> b n h d", h=self.num_heads
61
+ )
62
+ q = q + pos_embed
63
+ if pos_embed_context is not None:
64
+ pos_embed_context = rearrange(
65
+ pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
66
+ )
67
+ k = k + pos_embed_context
68
+
69
+ if self.cosine:
70
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
71
+ x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
72
+ x = rearrange(x, "b n h d -> b n (h d)")
73
+ x = self.out(x)
74
+ return x
unidepth/layers/positional_encoding.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from math import pi
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange, repeat
12
+
13
+
14
+ class PositionEmbeddingSine(nn.Module):
15
+ def __init__(
16
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
17
+ ):
18
+ super().__init__()
19
+ self.num_pos_feats = num_pos_feats
20
+ self.temperature = temperature
21
+ self.normalize = normalize
22
+ if scale is not None and normalize is False:
23
+ raise ValueError("normalize should be True if scale is passed")
24
+ if scale is None:
25
+ scale = 2 * pi
26
+ self.scale = scale
27
+
28
+ def forward(
29
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
30
+ ) -> torch.Tensor:
31
+ if mask is None:
32
+ mask = torch.zeros(
33
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
34
+ )
35
+ not_mask = ~mask
36
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
37
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
38
+ if self.normalize:
39
+ eps = 1e-6
40
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
41
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
42
+
43
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
44
+ dim_t = self.temperature ** (
45
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
46
+ )
47
+
48
+ pos_x = x_embed[:, :, :, None] / dim_t
49
+ pos_y = y_embed[:, :, :, None] / dim_t
50
+ pos_x = torch.stack(
51
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
52
+ ).flatten(3)
53
+ pos_y = torch.stack(
54
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
55
+ ).flatten(3)
56
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
57
+ return pos
58
+
59
+ def __repr__(self, _repr_indent=4):
60
+ head = "Positional encoding " + self.__class__.__name__
61
+ body = [
62
+ "num_pos_feats: {}".format(self.num_pos_feats),
63
+ "temperature: {}".format(self.temperature),
64
+ "normalize: {}".format(self.normalize),
65
+ "scale: {}".format(self.scale),
66
+ ]
67
+ # _repr_indent = 4
68
+ lines = [head] + [" " * _repr_indent + line for line in body]
69
+ return "\n".join(lines)
70
+
71
+
72
+ class LearnedSinusoidalPosEmb(nn.Module):
73
+ def __init__(self, dim):
74
+ super().__init__()
75
+ assert (dim % 2) == 0
76
+ half_dim = dim // 2
77
+ self.weights = nn.Parameter(torch.randn(half_dim))
78
+
79
+ def forward(self, x):
80
+ x = rearrange(x, "b -> b 1")
81
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
82
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
83
+ fouriered = torch.cat((x, fouriered), dim=-1)
84
+ return fouriered
85
+
86
+
87
+ def generate_fourier_features(x, max_freq=64, num_bands=16):
88
+ x = x.unsqueeze(-1)
89
+ device, dtype, orig_x = x.device, x.dtype, x
90
+
91
+ scales = torch.linspace(
92
+ -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
93
+ )
94
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
95
+
96
+ x = x * scales * pi
97
+ x = torch.cat([x.sin(), x.cos()], dim=-1)
98
+ x = torch.cat((x, orig_x), dim=-1)
99
+ return x.flatten(-2)
100
+
101
+
102
+ def broadcat(tensors, dim=-1):
103
+ num_tensors = len(tensors)
104
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
105
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
106
+ shape_len = list(shape_lens)[0]
107
+ dim = (dim + shape_len) if dim < 0 else dim
108
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
109
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
110
+ assert all(
111
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
112
+ ), "invalid dimensions for broadcastable concatentation"
113
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
114
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
115
+ expanded_dims.insert(dim, (dim, dims[dim]))
116
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
117
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
118
+ return torch.cat(tensors, dim=dim)
119
+
120
+
121
+ def rotate_half(x):
122
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
123
+ x1, x2 = x.unbind(dim=-1)
124
+ x = torch.stack((-x2, x1), dim=-1)
125
+ return rearrange(x, "... d r -> ... (d r)")
126
+
127
+
128
+ class VisionRotaryEmbedding(nn.Module):
129
+ def __init__(
130
+ self,
131
+ dim,
132
+ pt_seq_len,
133
+ ft_seq_len=None,
134
+ custom_freqs=None,
135
+ freqs_for="lang",
136
+ theta=10000,
137
+ max_freq=10,
138
+ num_freqs=1,
139
+ ):
140
+ super().__init__()
141
+ if custom_freqs:
142
+ freqs = custom_freqs
143
+ elif freqs_for == "lang":
144
+ freqs = 1.0 / (
145
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
146
+ )
147
+ elif freqs_for == "pixel":
148
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
149
+ elif freqs_for == "constant":
150
+ freqs = torch.ones(num_freqs).float()
151
+ else:
152
+ raise ValueError(f"unknown modality {freqs_for}")
153
+
154
+ if ft_seq_len is None:
155
+ ft_seq_len = pt_seq_len
156
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
157
+
158
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
159
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
160
+
161
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
162
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
163
+
164
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
165
+
166
+ self.register_buffer("freqs_cos", freqs.cos())
167
+ self.register_buffer("freqs_sin", freqs.sin())
168
+
169
+ print("======== shape of rope freq", self.freqs_cos.shape, "========")
170
+
171
+ def forward(self, t, start_index=0):
172
+ rot_dim = self.freqs_cos.shape[-1]
173
+ end_index = start_index + rot_dim
174
+ assert (
175
+ rot_dim <= t.shape[-1]
176
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
177
+ t_left, t, t_right = (
178
+ t[..., :start_index],
179
+ t[..., start_index:end_index],
180
+ t[..., end_index:],
181
+ )
182
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
183
+ return torch.cat((t_left, t, t_right), dim=-1)
184
+
185
+
186
+ class VisionRotaryEmbeddingFast(nn.Module):
187
+ def __init__(
188
+ self,
189
+ dim,
190
+ pt_seq_len,
191
+ ft_seq_len=None,
192
+ custom_freqs=None,
193
+ freqs_for="lang",
194
+ theta=10000,
195
+ max_freq=10,
196
+ num_freqs=1,
197
+ ):
198
+ super().__init__()
199
+ if custom_freqs:
200
+ freqs = custom_freqs
201
+ elif freqs_for == "lang":
202
+ freqs = 1.0 / (
203
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
204
+ )
205
+ elif freqs_for == "pixel":
206
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
207
+ elif freqs_for == "constant":
208
+ freqs = torch.ones(num_freqs).float()
209
+ else:
210
+ raise ValueError(f"unknown modality {freqs_for}")
211
+
212
+ if ft_seq_len is None:
213
+ ft_seq_len = pt_seq_len
214
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
215
+
216
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
217
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
218
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
219
+
220
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
221
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
222
+
223
+ self.register_buffer("freqs_cos", freqs_cos)
224
+ self.register_buffer("freqs_sin", freqs_sin)
225
+
226
+ def forward(self, t):
227
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
unidepth/layers/upsample.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+
10
+ from .convnext import CvnxtBlock
11
+
12
+
13
+ class ConvUpsample(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_dim,
17
+ num_layers: int = 2,
18
+ expansion: int = 4,
19
+ layer_scale: float = 1.0,
20
+ kernel_size: int = 7,
21
+ **kwargs,
22
+ ):
23
+ super().__init__()
24
+ self.convs = nn.ModuleList([])
25
+ for _ in range(num_layers):
26
+ self.convs.append(
27
+ CvnxtBlock(
28
+ hidden_dim,
29
+ kernel_size=kernel_size,
30
+ expansion=expansion,
31
+ layer_scale=layer_scale,
32
+ )
33
+ )
34
+ self.up = nn.Sequential(
35
+ nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
36
+ nn.UpsamplingBilinear2d(scale_factor=2),
37
+ nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
38
+ )
39
+
40
+ def forward(self, x: torch.Tensor):
41
+ for conv in self.convs:
42
+ x = conv(x)
43
+ x = self.up(x)
44
+ x = rearrange(x, "b c h w -> b (h w) c")
45
+ return x
46
+
47
+
48
+ class ConvUpsampleShuffle(nn.Module):
49
+ def __init__(
50
+ self,
51
+ hidden_dim,
52
+ num_layers: int = 2,
53
+ expansion: int = 4,
54
+ layer_scale: float = 1.0,
55
+ kernel_size: int = 7,
56
+ **kwargs,
57
+ ):
58
+ super().__init__()
59
+ self.convs = nn.ModuleList([])
60
+ for _ in range(num_layers):
61
+ self.convs.append(
62
+ CvnxtBlock(
63
+ hidden_dim,
64
+ kernel_size=kernel_size,
65
+ expansion=expansion,
66
+ layer_scale=layer_scale,
67
+ )
68
+ )
69
+ self.up = nn.Sequential(
70
+ nn.PixelShuffle(2),
71
+ nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
72
+ )
73
+
74
+ def forward(self, x: torch.Tensor):
75
+ for conv in self.convs:
76
+ x = conv(x)
77
+ x = self.up(x)
78
+ x = rearrange(x, "b c h w -> b (h w) c")
79
+ return x
80
+
81
+
82
+ class ConvUpsampleShuffleResidual(nn.Module):
83
+ def __init__(
84
+ self,
85
+ hidden_dim,
86
+ num_layers: int = 2,
87
+ expansion: int = 4,
88
+ layer_scale: float = 1.0,
89
+ kernel_size: int = 7,
90
+ padding_mode: str = "zeros",
91
+ **kwargs,
92
+ ):
93
+ super().__init__()
94
+ self.convs = nn.ModuleList([])
95
+ for _ in range(num_layers):
96
+ self.convs.append(
97
+ CvnxtBlock(
98
+ hidden_dim,
99
+ kernel_size=kernel_size,
100
+ expansion=expansion,
101
+ layer_scale=layer_scale,
102
+ padding_mode=padding_mode,
103
+ )
104
+ )
105
+ self.up = nn.Sequential(
106
+ nn.PixelShuffle(2),
107
+ nn.Conv2d(
108
+ hidden_dim // 4,
109
+ hidden_dim // 4,
110
+ kernel_size=7,
111
+ padding=3,
112
+ padding_mode=padding_mode,
113
+ groups=hidden_dim // 4,
114
+ ),
115
+ nn.ReLU(),
116
+ nn.Conv2d(
117
+ hidden_dim // 4,
118
+ hidden_dim // 2,
119
+ kernel_size=3,
120
+ padding=1,
121
+ padding_mode=padding_mode,
122
+ ),
123
+ )
124
+ self.residual = nn.Sequential(
125
+ nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
126
+ nn.UpsamplingBilinear2d(scale_factor=2),
127
+ )
128
+
129
+ def forward(self, x: torch.Tensor):
130
+ for conv in self.convs:
131
+ x = conv(x)
132
+ x = self.up(x) + self.residual(x)
133
+ x = rearrange(x, "b c h w -> b (h w) c")
134
+ return x
unidepth/models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .unidepthv1 import UniDepthV1
2
+ from .unidepthv2 import UniDepthV2
3
+
4
+ __all__ = [
5
+ "UniDepthV1",
6
+ "UniDepthV2",
7
+ ]
unidepth/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (319 Bytes). View file
 
unidepth/models/__pycache__/encoder.cpython-311.pyc ADDED
Binary file (9.56 kB). View file
 
unidepth/models/backbones/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .convnext import ConvNeXt
2
+ from .convnext2 import ConvNeXtV2
3
+ from .dinov2 import _make_dinov2_model
4
+
5
+ __all__ = [
6
+ "ConvNeXt",
7
+ "ConvNeXtV2",
8
+ "_make_dinov2_model",
9
+ ]
unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (398 Bytes). View file
 
unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc ADDED
Binary file (28.2 kB). View file
 
unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc ADDED
Binary file (22.4 kB). View file
 
unidepth/models/backbones/convnext.py ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from functools import partial
3
+ from typing import Callable, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp,
8
+ LayerNorm, LayerNorm2d, Mlp, create_conv2d,
9
+ get_act_layer, make_divisible, to_ntuple,
10
+ trunc_normal_)
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+
14
+ def get_num_layer_for_convnext(var_name):
15
+ """
16
+ Divide [3, 3, 27, 3] layers into 12 groups; each group is three
17
+ consecutive blocks, including possible neighboring downsample layers;
18
+ adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
19
+ """
20
+ if var_name.startswith("downsample_layers"):
21
+ stage_id = int(var_name.split(".")[1])
22
+ if stage_id == 0:
23
+ layer_id = 0
24
+ elif stage_id == 1 or stage_id == 2:
25
+ layer_id = stage_id + 1
26
+ elif stage_id == 3:
27
+ layer_id = 12
28
+
29
+ elif var_name.startswith("stages"):
30
+ stage_id = int(var_name.split(".")[1])
31
+ block_id = int(var_name.split(".")[3])
32
+ if stage_id == 0 or stage_id == 1:
33
+ layer_id = stage_id + 1
34
+ elif stage_id == 2:
35
+ layer_id = 3 + block_id // 3
36
+ elif stage_id == 3:
37
+ layer_id = 12
38
+
39
+ elif var_name.startswith("stem"):
40
+ return 0
41
+ else:
42
+ layer_id = 12
43
+ return layer_id + 1
44
+
45
+
46
+ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
47
+ parameter_group_names = {}
48
+ parameter_group_vars = {}
49
+ skip = set()
50
+ if skip_list is not None:
51
+ skip = skip_list
52
+ if hasattr(model, "no_weight_decay"):
53
+ skip.update(model.no_weight_decay())
54
+ num_layers = 12
55
+ layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
56
+ for name, param in model.named_parameters():
57
+ if not param.requires_grad:
58
+ continue # frozen weights
59
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
60
+ group_name = "no_decay"
61
+ this_wd = 0.0
62
+ else:
63
+ group_name = "decay"
64
+ this_wd = wd
65
+
66
+ layer_id = get_num_layer_for_convnext(name)
67
+ group_name = "layer_%d_%s" % (layer_id, group_name)
68
+
69
+ if group_name not in parameter_group_names:
70
+ scale = layer_scale[layer_id]
71
+ cur_lr = lr * scale
72
+
73
+ parameter_group_names[group_name] = {
74
+ "weight_decay": this_wd,
75
+ "weight_decay_init": this_wd,
76
+ "weight_decay_base": this_wd,
77
+ "params": [],
78
+ "lr_init": cur_lr,
79
+ "lr_base": lr,
80
+ "lr": cur_lr,
81
+ }
82
+ parameter_group_vars[group_name] = {
83
+ "weight_decay": this_wd,
84
+ "weight_decay_init": this_wd,
85
+ "weight_decay_base": this_wd,
86
+ "params": [],
87
+ "lr_init": cur_lr,
88
+ "lr_base": lr,
89
+ "lr": cur_lr,
90
+ }
91
+ if this_wd == 0.0:
92
+ parameter_group_names[group_name]["weight_decay_final"] = 0.0
93
+ parameter_group_vars[group_name]["weight_decay_final"] = 0.0
94
+ parameter_group_vars[group_name]["params"].append(param)
95
+ parameter_group_names[group_name]["params"].append(name)
96
+ # from unidepth.utils import is_main_process
97
+ # import json
98
+ # if is_main_process():
99
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
100
+ return list(parameter_group_vars.values()), [
101
+ v["lr"] for k, v in parameter_group_vars.items()
102
+ ]
103
+
104
+
105
+ class Downsample(nn.Module):
106
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
107
+ super().__init__()
108
+ avg_stride = stride if dilation == 1 else 1
109
+ if stride > 1 or dilation > 1:
110
+ avg_pool_fn = (
111
+ AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
112
+ )
113
+ self.pool = avg_pool_fn(
114
+ 2, avg_stride, ceil_mode=True, count_include_pad=False
115
+ )
116
+ else:
117
+ self.pool = nn.Identity()
118
+
119
+ if in_chs != out_chs:
120
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
121
+ else:
122
+ self.conv = nn.Identity()
123
+
124
+ def forward(self, x):
125
+ x = self.pool(x)
126
+ x = self.conv(x)
127
+ return x
128
+
129
+
130
+ class ConvNeXtBlock(nn.Module):
131
+ """ConvNeXt Block
132
+ There are two equivalent implementations:
133
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
134
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
135
+
136
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
137
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
138
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ in_chs: int,
144
+ out_chs: Optional[int] = None,
145
+ kernel_size: int = 7,
146
+ stride: int = 1,
147
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
148
+ mlp_ratio: float = 4,
149
+ conv_mlp: bool = False,
150
+ conv_bias: bool = True,
151
+ use_grn: bool = False,
152
+ ls_init_value: Optional[float] = 1e-6,
153
+ act_layer: Union[str, Callable] = "gelu",
154
+ norm_layer: Optional[Callable] = None,
155
+ drop_path: float = 0.0,
156
+ ):
157
+ """
158
+
159
+ Args:
160
+ in_chs: Block input channels.
161
+ out_chs: Block output channels (same as in_chs if None).
162
+ kernel_size: Depthwise convolution kernel size.
163
+ stride: Stride of depthwise convolution.
164
+ dilation: Tuple specifying input and output dilation of block.
165
+ mlp_ratio: MLP expansion ratio.
166
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
167
+ conv_bias: Apply bias for all convolution (linear) layers.
168
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
169
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
170
+ act_layer: Activation layer.
171
+ norm_layer: Normalization layer (defaults to LN if not specified).
172
+ drop_path: Stochastic depth probability.
173
+ """
174
+ super().__init__()
175
+ out_chs = out_chs or in_chs
176
+ dilation = to_ntuple(2)(dilation)
177
+ act_layer = get_act_layer(act_layer)
178
+ if not norm_layer:
179
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
180
+ mlp_layer = partial(
181
+ GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
182
+ )
183
+ self.use_conv_mlp = conv_mlp
184
+ self.conv_dw = create_conv2d(
185
+ in_chs,
186
+ out_chs,
187
+ kernel_size=kernel_size,
188
+ stride=stride,
189
+ dilation=dilation[0],
190
+ depthwise=True,
191
+ bias=conv_bias,
192
+ )
193
+ self.norm = norm_layer(out_chs)
194
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
195
+ self.gamma = (
196
+ nn.Parameter(ls_init_value * torch.ones(out_chs))
197
+ if ls_init_value is not None
198
+ else None
199
+ )
200
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
201
+ self.shortcut = Downsample(
202
+ in_chs, out_chs, stride=stride, dilation=dilation[0]
203
+ )
204
+ else:
205
+ self.shortcut = nn.Identity()
206
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
207
+
208
+ def forward(self, x):
209
+ shortcut = x
210
+ x = self.conv_dw(x.contiguous())
211
+ if self.use_conv_mlp:
212
+ x = self.norm(x)
213
+ x = self.mlp(x)
214
+ else:
215
+ x = x.permute(0, 2, 3, 1).contiguous()
216
+ x = self.norm(x)
217
+ x = self.mlp(x)
218
+ x = x.permute(0, 3, 1, 2).contiguous()
219
+ if self.gamma is not None:
220
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
221
+
222
+ x = self.drop_path(x) + self.shortcut(shortcut)
223
+ return x.contiguous()
224
+
225
+
226
+ class ConvNeXtStage(nn.Module):
227
+ def __init__(
228
+ self,
229
+ in_chs,
230
+ out_chs,
231
+ kernel_size=7,
232
+ stride=2,
233
+ depth=2,
234
+ dilation=(1, 1),
235
+ drop_path_rates=None,
236
+ ls_init_value=1.0,
237
+ conv_mlp=False,
238
+ conv_bias=True,
239
+ use_grn=False,
240
+ act_layer="gelu",
241
+ norm_layer=None,
242
+ norm_layer_cl=None,
243
+ ):
244
+ super().__init__()
245
+ self.grad_checkpointing = False
246
+
247
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
248
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
249
+ pad = (
250
+ "same" if dilation[1] > 1 else 0
251
+ ) # same padding needed if dilation used
252
+ self.downsample = nn.Sequential(
253
+ norm_layer(in_chs),
254
+ create_conv2d(
255
+ in_chs,
256
+ out_chs,
257
+ kernel_size=ds_ks,
258
+ stride=stride,
259
+ dilation=dilation[0],
260
+ padding=pad,
261
+ bias=conv_bias,
262
+ ),
263
+ )
264
+ in_chs = out_chs
265
+ else:
266
+ self.downsample = nn.Identity()
267
+
268
+ drop_path_rates = drop_path_rates or [0.0] * depth
269
+ stage_blocks = []
270
+ for i in range(depth):
271
+ stage_blocks.append(
272
+ ConvNeXtBlock(
273
+ in_chs=in_chs,
274
+ out_chs=out_chs,
275
+ kernel_size=kernel_size,
276
+ dilation=dilation[1],
277
+ drop_path=drop_path_rates[i],
278
+ ls_init_value=ls_init_value,
279
+ conv_mlp=conv_mlp,
280
+ conv_bias=conv_bias,
281
+ use_grn=use_grn,
282
+ act_layer=act_layer,
283
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
284
+ )
285
+ )
286
+ in_chs = out_chs
287
+ self.blocks = nn.ModuleList(stage_blocks)
288
+
289
+ def forward(self, x):
290
+ xs = []
291
+ x = self.downsample(x)
292
+ for block in self.blocks:
293
+ if self.grad_checkpointing:
294
+ x = checkpoint(block, x)
295
+ else:
296
+ x = block(x)
297
+ xs.append(x)
298
+ return xs
299
+
300
+
301
+ class ConvNeXt(nn.Module):
302
+ def __init__(
303
+ self,
304
+ in_chans: int = 3,
305
+ output_stride: int = 32,
306
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
307
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
308
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
309
+ ls_init_value: Optional[float] = 1e-6,
310
+ stem_type: str = "patch",
311
+ patch_size: int = 4,
312
+ conv_mlp: bool = False,
313
+ conv_bias: bool = True,
314
+ use_grn: bool = False,
315
+ act_layer: Union[str, Callable] = "gelu",
316
+ norm_layer: Optional[Union[str, Callable]] = None,
317
+ norm_eps: Optional[float] = None,
318
+ drop_path_rate: float = 0.0,
319
+ output_idx=[],
320
+ use_checkpoint=False,
321
+ ):
322
+ """
323
+ Args:
324
+ in_chans: Number of input image channels.
325
+ num_classes: Number of classes for classification head.
326
+ global_pool: Global pooling type.
327
+ output_stride: Output stride of network, one of (8, 16, 32).
328
+ depths: Number of blocks at each stage.
329
+ dims: Feature dimension at each stage.
330
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
331
+ ls_init_value: Init value for Layer Scale, disabled if None.
332
+ stem_type: Type of stem.
333
+ patch_size: Stem patch size for patch stem.
334
+ head_init_scale: Init scaling value for classifier weights and biases.
335
+ head_norm_first: Apply normalization before global pool + head.
336
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
337
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
338
+ conv_bias: Use bias layers w/ all convolutions.
339
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
340
+ act_layer: Activation layer type.
341
+ norm_layer: Normalization layer type.
342
+ drop_rate: Head pre-classifier dropout rate.
343
+ drop_path_rate: Stochastic depth drop rate.
344
+ """
345
+ super().__init__()
346
+ self.num_layers = len(depths)
347
+ self.depths = output_idx
348
+ self.embed_dims = [
349
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
350
+ ]
351
+ self.embed_dim = dims[0]
352
+
353
+ assert output_stride in (8, 16, 32)
354
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
355
+ if norm_layer is None:
356
+ norm_layer = LayerNorm2d
357
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
358
+ if norm_eps is not None:
359
+ norm_layer = partial(norm_layer, eps=norm_eps)
360
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
361
+ else:
362
+ assert (
363
+ conv_mlp
364
+ ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
365
+ norm_layer_cl = norm_layer
366
+ if norm_eps is not None:
367
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
368
+
369
+ self.feature_info = []
370
+
371
+ assert stem_type in ("patch", "overlap", "overlap_tiered")
372
+ if stem_type == "patch":
373
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
374
+ self.stem = nn.Sequential(
375
+ nn.Conv2d(
376
+ in_chans,
377
+ dims[0],
378
+ kernel_size=patch_size,
379
+ stride=patch_size,
380
+ bias=conv_bias,
381
+ ),
382
+ norm_layer(dims[0]),
383
+ )
384
+ stem_stride = patch_size
385
+ else:
386
+ mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
387
+ self.stem = nn.Sequential(
388
+ nn.Conv2d(
389
+ in_chans,
390
+ mid_chs,
391
+ kernel_size=3,
392
+ stride=2,
393
+ padding=1,
394
+ bias=conv_bias,
395
+ ),
396
+ nn.Conv2d(
397
+ mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
398
+ ),
399
+ norm_layer(dims[0]),
400
+ )
401
+ stem_stride = 4
402
+
403
+ self.stages = nn.Sequential()
404
+ dp_rates = [
405
+ x.tolist()
406
+ for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
407
+ ]
408
+ stages = []
409
+ prev_chs = dims[0]
410
+ curr_stride = stem_stride
411
+ dilation = 1
412
+ # 4 feature resolution stages, each consisting of multiple residual blocks
413
+ for i in range(4):
414
+ stride = 2 if curr_stride == 2 or i > 0 else 1
415
+ if curr_stride >= output_stride and stride > 1:
416
+ dilation *= stride
417
+ stride = 1
418
+ curr_stride *= stride
419
+ first_dilation = 1 if dilation in (1, 2) else 2
420
+ out_chs = dims[i]
421
+ stages.append(
422
+ ConvNeXtStage(
423
+ prev_chs,
424
+ out_chs,
425
+ kernel_size=kernel_sizes[i],
426
+ stride=stride,
427
+ dilation=(first_dilation, dilation),
428
+ depth=depths[i],
429
+ drop_path_rates=dp_rates[i],
430
+ ls_init_value=ls_init_value,
431
+ conv_mlp=conv_mlp,
432
+ conv_bias=conv_bias,
433
+ use_grn=use_grn,
434
+ act_layer=act_layer,
435
+ norm_layer=norm_layer,
436
+ norm_layer_cl=norm_layer_cl,
437
+ )
438
+ )
439
+ prev_chs = out_chs
440
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
441
+ self.feature_info += [
442
+ dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
443
+ ]
444
+ self.stages = nn.ModuleList(stages)
445
+ self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
446
+ self.num_features = prev_chs
447
+ self.apply(self._init_weights)
448
+ self.set_grad_checkpointing(use_checkpoint)
449
+
450
+ def _init_weights(self, module):
451
+ if isinstance(module, nn.Conv2d):
452
+ trunc_normal_(module.weight, std=0.02)
453
+ if module.bias is not None:
454
+ nn.init.zeros_(module.bias)
455
+ elif isinstance(module, nn.Linear):
456
+ trunc_normal_(module.weight, std=0.02)
457
+ nn.init.zeros_(module.bias)
458
+
459
+ def forward(self, x, masks=None):
460
+ outs = []
461
+ x = self.stem(x)
462
+ if masks is not None:
463
+ masks = torch.nn.functional.interpolate(
464
+ masks.float(), size=x.shape[-2:], mode="nearest"
465
+ )
466
+ x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
467
+ for stage in self.stages:
468
+ xs = stage(x)
469
+ outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
470
+ x = xs[-1]
471
+ return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
472
+
473
+ @torch.jit.ignore
474
+ def group_matcher(self, coarse=False):
475
+ return dict(
476
+ stem=r"^stem",
477
+ blocks=(
478
+ r"^stages\.(\d+)"
479
+ if coarse
480
+ else [
481
+ (r"^stages\.(\d+)\.downsample", (0,)), # blocks
482
+ (r"^stages\.(\d+)\.blocks\.(\d+)", None),
483
+ (r"^norm_pre", (99999,)),
484
+ ]
485
+ ),
486
+ )
487
+
488
+ @torch.jit.ignore
489
+ def set_grad_checkpointing(self, enable=True):
490
+ for s in self.stages:
491
+ s.grad_checkpointing = enable
492
+
493
+ def freeze(self) -> None:
494
+ for module in self.modules():
495
+ module.eval()
496
+ for parameters in self.parameters():
497
+ parameters.requires_grad = False
498
+
499
+ def get_params(self, lr, wd, ld, *args, **kwargs):
500
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
501
+ return encoder_p, encoder_lr
502
+
503
+ def no_weight_decay(self):
504
+ return {"mask_token"}
505
+
506
+ @classmethod
507
+ def build(cls, config):
508
+ obj = globals()[config["model"]["encoder"]["name"]](config)
509
+ return obj
510
+
511
+
512
+ def checkpoint_filter_fn(state_dict, model):
513
+ """Remap FB checkpoints -> timm"""
514
+ if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
515
+ return state_dict # non-FB checkpoint
516
+ if "model" in state_dict:
517
+ state_dict = state_dict["model"]
518
+
519
+ out_dict = {}
520
+ if "visual.trunk.stem.0.weight" in state_dict:
521
+ out_dict = {
522
+ k.replace("visual.trunk.", ""): v
523
+ for k, v in state_dict.items()
524
+ if k.startswith("visual.trunk.")
525
+ }
526
+ if "visual.head.proj.weight" in state_dict:
527
+ out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
528
+ out_dict["head.fc.bias"] = torch.zeros(
529
+ state_dict["visual.head.proj.weight"].shape[0]
530
+ )
531
+ elif "visual.head.mlp.fc1.weight" in state_dict:
532
+ out_dict["head.pre_logits.fc.weight"] = state_dict[
533
+ "visual.head.mlp.fc1.weight"
534
+ ]
535
+ out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
536
+ out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
537
+ out_dict["head.fc.bias"] = torch.zeros(
538
+ state_dict["visual.head.mlp.fc2.weight"].shape[0]
539
+ )
540
+ return out_dict
541
+
542
+ import re
543
+
544
+ for k, v in state_dict.items():
545
+ k = k.replace("downsample_layers.0.", "stem.")
546
+ k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
547
+ k = re.sub(
548
+ r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
549
+ )
550
+ k = k.replace("dwconv", "conv_dw")
551
+ k = k.replace("pwconv", "mlp.fc")
552
+ if "grn" in k:
553
+ k = k.replace("grn.beta", "mlp.grn.bias")
554
+ k = k.replace("grn.gamma", "mlp.grn.weight")
555
+ v = v.reshape(v.shape[-1])
556
+ k = k.replace("head.", "head.fc.")
557
+ if k.startswith("norm."):
558
+ k = k.replace("norm", "head.norm")
559
+ if v.ndim == 2 and "head" not in k:
560
+ model_shape = model.state_dict()[k].shape
561
+ v = v.reshape(model_shape)
562
+ out_dict[k] = v
563
+
564
+ return out_dict
565
+
566
+
567
+ HF_URL = {
568
+ "convnext_xxlarge_pt": (
569
+ "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
570
+ "open_clip_pytorch_model.bin",
571
+ ),
572
+ "convnext_large_pt": (
573
+ "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
574
+ "open_clip_pytorch_model.bin",
575
+ ),
576
+ "convnext_large": (
577
+ "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
578
+ "pytorch_model.bin",
579
+ ),
580
+ }
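A minimal usage sketch for the checkpoint remapping above, assuming an already-constructed ConvNeXt encoder from this file (called `encoder` here purely for illustration): it fetches one of the open_clip checkpoints listed in HF_URL with `huggingface_hub.hf_hub_download` and passes it through `checkpoint_filter_fn` before loading.

import torch
from huggingface_hub import hf_hub_download

# `encoder` is assumed to be a ConvNeXt instance built from this file.
repo_id, filename = HF_URL["convnext_large_pt"]
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
raw_state = torch.load(ckpt_path, map_location="cpu")

remapped = checkpoint_filter_fn(raw_state, encoder)  # open_clip/FB keys -> timm keys
missing, unexpected = encoder.load_state_dict(remapped, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")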
unidepth/models/backbones/convnext2.py ADDED
@@ -0,0 +1,288 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.models.layers import DropPath, trunc_normal_
5
+
6
+
7
+ def get_num_layer_for_convnext_single(var_name, depths):
8
+ """
9
+ Each layer is assigned a distinct layer id.
10
+ """
11
+ if var_name.startswith("downsample_layers"):
12
+ stage_id = int(var_name.split(".")[1])
13
+ layer_id = sum(depths[:stage_id]) + 1
14
+ return layer_id
15
+
16
+ elif var_name.startswith("stages"):
17
+ stage_id = int(var_name.split(".")[1])
18
+ block_id = int(var_name.split(".")[2])
19
+ layer_id = sum(depths[:stage_id]) + block_id + 1
20
+ return layer_id
21
+
22
+ else:
23
+ return sum(depths) + 1
24
+
25
+
26
+ def get_num_layer_for_convnext(var_name):
27
+ """
28
+ Divide [3, 3, 27, 3] layers into 12 groups; each group is three
29
+ consecutive blocks, including possible neighboring downsample layers;
30
+ adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
31
+ """
32
+ num_max_layer = 12
33
+ if var_name.startswith("downsample_layers"):
34
+ stage_id = int(var_name.split(".")[1])
35
+ if stage_id == 0:
36
+ layer_id = 0
37
+ elif stage_id == 1 or stage_id == 2:
38
+ layer_id = stage_id + 1
39
+ elif stage_id == 3:
40
+ layer_id = 12
41
+ return layer_id
42
+
43
+ elif var_name.startswith("stages"):
44
+ stage_id = int(var_name.split(".")[1])
45
+ block_id = int(var_name.split(".")[2])
46
+ if stage_id == 0 or stage_id == 1:
47
+ layer_id = stage_id + 1
48
+ elif stage_id == 2:
49
+ layer_id = 3 + block_id // 3
50
+ elif stage_id == 3:
51
+ layer_id = 12
52
+ return layer_id
53
+ else:
54
+ return num_max_layer + 1
55
+
56
+
57
+ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
58
+ parameter_group_names = {}
59
+ parameter_group_vars = {}
60
+ skip = {}
61
+ if skip_list is not None:
62
+ skip = skip_list
63
+ elif hasattr(model, "no_weight_decay"):
64
+ skip = model.no_weight_decay()
65
+ num_layers = 12 # sum(model.depths)
66
+ layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
67
+ for name, param in model.named_parameters():
68
+ if not param.requires_grad:
69
+ continue # frozen weights
70
+ if (
71
+ len(param.shape) == 1
72
+ or name.endswith(".bias")
73
+ or name in skip
74
+ or name.endswith(".gamma")
75
+ or name.endswith(".beta")
76
+ ):
77
+ group_name = "no_decay"
78
+ this_weight_decay = 0.0
79
+ else:
80
+ group_name = "decay"
81
+ this_weight_decay = wd
82
+
83
+ # layer_id = get_num_layer_for_convnext_single(name, model.depths)
84
+ layer_id = get_num_layer_for_convnext(name)
85
+ group_name = "layer_%d_%s" % (layer_id, group_name)
86
+
87
+ if group_name not in parameter_group_names:
88
+ scale = layer_scale[layer_id]
89
+ cur_lr = lr * scale
90
+
91
+ parameter_group_names[group_name] = {
92
+ "weight_decay": this_weight_decay,
93
+ "params": [],
94
+ "lr_scale": scale,
95
+ "lr": cur_lr,
96
+ }
97
+ parameter_group_vars[group_name] = {
98
+ "weight_decay": this_weight_decay,
99
+ "params": [],
100
+ "lr_scale": scale,
101
+ "lr": cur_lr,
102
+ }
103
+ parameter_group_vars[group_name]["params"].append(param)
104
+ parameter_group_names[group_name]["params"].append(name)
105
+ # if is_main_process():
106
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
107
+ return list(parameter_group_vars.values()), [
108
+ v["lr"] for k, v in parameter_group_vars.items()
109
+ ]
110
+
111
+
112
+ class LayerNorm(nn.Module):
113
+ """LayerNorm that supports two data formats: channels_last (default) or channels_first.
114
+ The data_format argument selects the ordering of the dimensions in the inputs: channels_last corresponds to inputs with
115
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
116
+ with shape (batch_size, channels, height, width).
117
+ """
118
+
119
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
120
+ super().__init__()
121
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
122
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
123
+ self.eps = eps
124
+ self.data_format = data_format
125
+ if self.data_format not in ["channels_last", "channels_first"]:
126
+ raise NotImplementedError
127
+ self.normalized_shape = (normalized_shape,)
128
+
129
+ def forward(self, x):
130
+ if self.data_format == "channels_last":
131
+ return F.layer_norm(
132
+ x, self.normalized_shape, self.weight, self.bias, self.eps
133
+ )
134
+ elif self.data_format == "channels_first":
135
+ u = x.mean(1, keepdim=True)
136
+ s = (x - u).pow(2).mean(1, keepdim=True)
137
+ x = (x - u) / torch.sqrt(s + self.eps)
138
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
139
+ return x
140
+
141
+
142
+ class GRN(nn.Module):
143
+ """GRN (Global Response Normalization) layer"""
144
+
145
+ def __init__(self, dim):
146
+ super().__init__()
147
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
148
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
149
+
150
+ def forward(self, x):
151
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
152
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
153
+ return self.gamma * (x * Nx) + self.beta + x
154
+
155
+
156
+ class Block(nn.Module):
157
+ """ConvNeXtV2 Block.
158
+
159
+ Args:
160
+ dim (int): Number of input channels.
161
+ drop_path (float): Stochastic depth rate. Default: 0.0
162
+ """
163
+
164
+ def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
165
+ super().__init__()
166
+ self.dwconv = nn.Conv2d(
167
+ dim, dim, kernel_size=7, padding=3, groups=dim
168
+ ) # depthwise conv
169
+ self.norm = LayerNorm(dim, eps=1e-6)
170
+ self.pwconv1 = nn.Linear(
171
+ dim, mult * dim
172
+ ) # pointwise/1x1 convs, implemented with linear layers
173
+ self.act = nn.GELU()
174
+ self.grn = GRN(mult * dim)
175
+ self.pwconv2 = nn.Linear(mult * dim, dim)
176
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
177
+ self.use_checkpoint = use_checkpoint
178
+
179
+ def forward(self, x):
180
+ input = x
181
+ x = self.dwconv(x)
182
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
183
+ x = self.norm(x)
184
+ x = self.pwconv1(x)
185
+ x = self.act(x)
186
+ x = self.grn(x)
187
+ x = self.pwconv2(x)
188
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
189
+
190
+ x = input + self.drop_path(x)
191
+ return x
192
+
193
+
194
+ class ConvNeXtV2(nn.Module):
195
+ """ConvNeXt V2
196
+
197
+ Args:
198
+ in_chans (int): Number of input image channels. Default: 3
199
+ num_classes (int): Number of classes for classification head. Default: 1000
200
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
201
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
202
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
203
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ in_chans=3,
209
+ depths=[3, 3, 9, 3],
210
+ dims=96,
211
+ drop_path_rate=0.0,
212
+ output_idx=[],
213
+ use_checkpoint=False,
214
+ ):
215
+ super().__init__()
216
+ self.num_layers = len(depths)
217
+ self.depths = output_idx
218
+ self.embed_dims = [
219
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
220
+ ]
221
+ self.embed_dim = dims[0]
222
+
223
+ self.downsample_layers = (
224
+ nn.ModuleList()
225
+ ) # stem and 3 intermediate downsampling conv layers
226
+ stem = nn.Sequential(
227
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
228
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
229
+ )
230
+ self.downsample_layers.append(stem)
231
+ for i in range(3):
232
+ downsample_layer = nn.Sequential(
233
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
234
+ nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
235
+ )
236
+ self.downsample_layers.append(downsample_layer)
237
+
238
+ self.stages = (
239
+ nn.ModuleList()
240
+ ) # 4 feature resolution stages, each consisting of multiple residual blocks
241
+ self.out_norms = nn.ModuleList()
242
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
243
+ cur = 0
244
+ for i in range(4):
245
+ stage = nn.ModuleList(
246
+ [
247
+ Block(
248
+ dim=dims[i],
249
+ drop_path=dp_rates[cur + j],
250
+ use_checkpoint=use_checkpoint,
251
+ )
252
+ for j in range(depths[i])
253
+ ]
254
+ )
255
+ self.stages.append(stage)
256
+ cur += depths[i]
257
+
258
+ self.apply(self._init_weights)
259
+
260
+ def _init_weights(self, m):
261
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
262
+ trunc_normal_(m.weight, std=0.02)
263
+ nn.init.constant_(m.bias, 0)
264
+
265
+ def forward(self, x):
266
+ outs = []
267
+ for i in range(4):
268
+ x = self.downsample_layers[i](x)
269
+ for stage in self.stages[i]:
270
+ x = stage(x)
271
+ outs.append(x.permute(0, 2, 3, 1))
272
+ cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
273
+ return outs, cls_tokens
274
+
275
+ def get_params(self, lr, wd, ld, *args, **kwargs):
276
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
277
+ return encoder_p, encoder_lr
278
+
279
+ def freeze(self) -> None:
280
+ for module in self.modules():
281
+ module.eval()
282
+ for parameters in self.parameters():
283
+ parameters.requires_grad = False
284
+
285
+ @classmethod
286
+ def build(cls, config):
287
+ obj = globals()[config["model"]["encoder"]["name"]](config)
288
+ return obj
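A minimal sketch of driving ConvNeXtV2 directly; the depths/dims mirror the defaults quoted in the class docstring and the input size is illustrative, not a value fixed by this repository.

import torch

encoder = ConvNeXtV2(in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
encoder.eval()

with torch.no_grad():
    feats, cls_tokens = encoder(torch.randn(1, 3, 224, 224))

# Features come out channels-last (B, H, W, C); cls_tokens are spatially pooled (B, 1, C).
for f, c in zip(feats, cls_tokens):
    print(tuple(f.shape), tuple(c.shape))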
unidepth/models/backbones/dinov2.py ADDED
@@ -0,0 +1,455 @@
1
+ import logging
2
+ import math
3
+ from functools import partial
4
+ from typing import Callable, Sequence
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+
10
+ from .metadinov2 import Attention, MemEffAttention, Mlp
11
+ from .metadinov2 import NestedTensorBlock as Block
12
+ from .metadinov2 import PatchEmbed, SwiGLUFFNFused
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+ logger = logging.getLogger("dinov2")
16
+
17
+
18
+ def named_apply(
19
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
20
+ ) -> nn.Module:
21
+ if not depth_first and include_root:
22
+ fn(module=module, name=name)
23
+ for child_name, child_module in module.named_children():
24
+ child_name = ".".join((name, child_name)) if name else child_name
25
+ named_apply(
26
+ fn=fn,
27
+ module=child_module,
28
+ name=child_name,
29
+ depth_first=depth_first,
30
+ include_root=True,
31
+ )
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
38
+ parameter_group_names = {}
39
+ parameter_group_vars = {}
40
+ skip = {}
41
+ if skip_list is not None:
42
+ skip = skip_list
43
+ elif hasattr(model, "no_weight_decay"):
44
+ skip = model.no_weight_decay()
45
+
46
+ num_layers = model.n_blocks
47
+ layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
48
+
49
+ for name, param in model.named_parameters():
50
+ if not param.requires_grad:
51
+ continue
52
+
53
+ if len(param.shape) == 1: # norm
54
+ group_name = "no_decay"
55
+ this_wd = 0.0
56
+ # layer scale, bias beta?
57
+ elif (
58
+ name in skip
59
+ or name.endswith(".gamma")
60
+ or name.endswith(".beta")
61
+ or name.endswith(".bias")
62
+ ):
63
+ group_name = "no_decay"
64
+ this_wd = 0.0
65
+ elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
66
+ group_name = "no_decay"
67
+ this_wd = 0.0
68
+ else:
69
+ group_name = "decay"
70
+ this_wd = wd
71
+
72
+ if name.startswith("blocks"):
73
+ layer_id = int(name.split(".")[1])
74
+ elif name.startswith("patch_embed"):
75
+ layer_id = 0
76
+ else:
77
+ layer_id = 0
78
+
79
+ group_name = f"layer_{layer_id}_{group_name}"
80
+
81
+ if group_name not in parameter_group_names:
82
+ scale = layer_scale[layer_id]
83
+ cur_lr = lr * scale
84
+
85
+ parameter_group_names[group_name] = {
86
+ "weight_decay": this_wd,
87
+ "params": [],
88
+ "lr_init": cur_lr,
89
+ "lr_base": lr,
90
+ "lr": cur_lr,
91
+ }
92
+ parameter_group_vars[group_name] = {
93
+ "weight_decay": this_wd,
94
+ "params": [],
95
+ "lr_init": cur_lr,
96
+ "lr_base": lr,
97
+ "lr": cur_lr,
98
+ }
99
+ parameter_group_vars[group_name]["params"].append(param)
100
+ parameter_group_names[group_name]["params"].append(name)
101
+ return list(parameter_group_vars.values()), [
102
+ v["lr"] for k, v in parameter_group_vars.items()
103
+ ]
104
+
105
+
106
+ class BlockChunk(nn.ModuleList):
107
+ def forward(self, x):
108
+ for b in self:
109
+ x = b(x)
110
+ return x
111
+
112
+
113
+ class DinoVisionTransformer(nn.Module):
114
+ def __init__(
115
+ self,
116
+ img_size=224,
117
+ patch_size=16,
118
+ in_chans=3,
119
+ embed_dim=768,
120
+ depth=12,
121
+ num_heads=12,
122
+ mlp_ratio=4.0,
123
+ qkv_bias=True,
124
+ ffn_bias=True,
125
+ proj_bias=True,
126
+ drop_path_rate=0.0,
127
+ drop_path_uniform=False,
128
+ init_values=None, # for layerscale: None or 0 => no layerscale
129
+ embed_layer=PatchEmbed,
130
+ act_layer=nn.GELU,
131
+ block_fn=Block,
132
+ ffn_layer="mlp",
133
+ block_chunks=1,
134
+ output_idx=[5, 12, 18, 24],
135
+ checkpoint: bool = False,
136
+ num_register_tokens=0,
137
+ interpolate_antialias=False,
138
+ interpolate_offset=0.0,
139
+ use_norm=False,
140
+ ):
141
+ """
142
+ Args:
143
+ img_size (int, tuple): input image size
144
+ patch_size (int, tuple): patch size
145
+ in_chans (int): number of input channels
146
+ embed_dim (int): embedding dimension
147
+ depth (int): depth of transformer
148
+ num_heads (int): number of attention heads
149
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
150
+ qkv_bias (bool): enable bias for qkv if True
151
+ proj_bias (bool): enable bias for proj in attn if True
152
+ ffn_bias (bool): enable bias for ffn if True
153
+ drop_path_rate (float): stochastic depth rate
154
+ drop_path_uniform (bool): apply uniform drop rate across blocks
155
+ weight_init (str): weight init scheme
156
+ init_values (float): layer-scale init values
157
+ embed_layer (nn.Module): patch embedding layer
158
+ act_layer (nn.Module): MLP activation layer
159
+ block_fn (nn.Module): transformer block class
160
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
161
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
162
+ """
163
+ super().__init__()
164
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
165
+
166
+ self.num_features = self.embed_dim = (
167
+ embed_dim # num_features for consistency with other models
168
+ )
169
+ self.embed_dims = [embed_dim] * output_idx[-1]
170
+ self.num_tokens = 1
171
+ self.n_blocks = depth
172
+ self.num_heads = num_heads
173
+ self.patch_size = patch_size
174
+ self.depths = output_idx
175
+ self.checkpoint = checkpoint
176
+ self.num_register_tokens = num_register_tokens
177
+ self.interpolate_antialias = interpolate_antialias
178
+ self.interpolate_offset = interpolate_offset
179
+
180
+ self.patch_embed = embed_layer(
181
+ img_size=img_size,
182
+ patch_size=patch_size,
183
+ in_chans=in_chans,
184
+ embed_dim=embed_dim,
185
+ )
186
+ num_patches = self.patch_embed.num_patches
187
+
188
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
189
+ self.pos_embed = nn.Parameter(
190
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
191
+ )
192
+ assert num_register_tokens >= 0
193
+ self.register_tokens = nn.Parameter(
194
+ torch.zeros(1, max(1, num_register_tokens), embed_dim)
195
+ )
196
+
197
+ if drop_path_uniform is True:
198
+ dpr = [drop_path_rate] * depth
199
+ else:
200
+ dpr = [
201
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
202
+ ] # stochastic depth decay rule
203
+
204
+ if ffn_layer == "mlp":
205
+ logger.info("using MLP layer as FFN")
206
+ ffn_layer = Mlp
207
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
208
+ logger.info("using SwiGLU layer as FFN")
209
+ ffn_layer = SwiGLUFFNFused
210
+ elif ffn_layer == "identity":
211
+ logger.info("using Identity layer as FFN")
212
+
213
+ def f(*args, **kwargs):
214
+ return nn.Identity()
215
+
216
+ ffn_layer = f
217
+ else:
218
+ raise NotImplementedError
219
+
220
+ blocks_list = [
221
+ block_fn(
222
+ dim=embed_dim,
223
+ num_heads=num_heads,
224
+ mlp_ratio=mlp_ratio,
225
+ qkv_bias=qkv_bias,
226
+ proj_bias=proj_bias,
227
+ ffn_bias=ffn_bias,
228
+ drop_path=dpr[i],
229
+ norm_layer=norm_layer,
230
+ act_layer=act_layer,
231
+ ffn_layer=ffn_layer,
232
+ init_values=init_values,
233
+ )
234
+ for i in range(depth)
235
+ ]
236
+ if block_chunks > 0:
237
+ self.chunked_blocks = True
238
+ chunked_blocks = []
239
+ chunksize = depth // block_chunks
240
+ for i in range(0, depth, chunksize):
241
+ # this is to keep the block index consistent if we chunk the block list
242
+ chunked_blocks.append(
243
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
244
+ )
245
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
246
+ else:
247
+ self.chunked_blocks = False
248
+ self.blocks = nn.ModuleList(blocks_list)
249
+
250
+ self.norm = norm_layer(embed_dim)
251
+ self.use_norm = use_norm
252
+ self.head = nn.Identity()
253
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
254
+ self.init_weights()
255
+
256
+ def init_weights(self):
257
+ trunc_normal_(self.pos_embed, std=0.02)
258
+ nn.init.normal_(self.cls_token, std=1e-6)
259
+ if self.num_register_tokens:
260
+ nn.init.normal_(self.register_tokens, std=1e-6)
261
+ named_apply(init_weights_vit_timm, self)
262
+
263
+ def interpolate_pos_encoding(self, x, w, h):
264
+ previous_dtype = x.dtype
265
+ npatch = x.shape[1] - 1
266
+ N = self.pos_embed.shape[1] - 1
267
+ if npatch == N and w == h:
268
+ return self.pos_embed
269
+ pos_embed = self.pos_embed.float()
270
+ class_pos_embed = pos_embed[:, 0]
271
+ patch_pos_embed = pos_embed[:, 1:]
272
+ dim = x.shape[-1]
273
+ w0 = w // self.patch_size
274
+ h0 = h // self.patch_size
275
+
276
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
277
+ assert N == M * M
278
+ kwargs = {}
279
+ if self.interpolate_offset:
280
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
281
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
282
+ sx = float(w0 + self.interpolate_offset) / M
283
+ sy = float(h0 + self.interpolate_offset) / M
284
+ kwargs["scale_factor"] = (sx, sy)
285
+ else:
286
+ # Simply specify an output size instead of a scale factor
287
+ kwargs["size"] = (w0, h0)
288
+
289
+ patch_pos_embed = nn.functional.interpolate(
290
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
291
+ mode="bicubic",
292
+ antialias=self.interpolate_antialias,
293
+ **kwargs,
294
+ )
295
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
296
+
297
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
298
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
299
+ previous_dtype
300
+ )
301
+
302
+ def prepare_tokens_with_masks(self, x, masks=None):
303
+ B, nc, w, h = x.shape
304
+ x = self.patch_embed(x)
305
+ if masks is not None:
306
+ masks = masks.bool().view(B, -1, 1)
307
+ x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
308
+
309
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
310
+ x = x + self.interpolate_pos_encoding(x, w, h)
311
+
312
+ if self.num_register_tokens:
313
+ x = torch.cat(
314
+ (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
315
+ dim=1,
316
+ )
317
+ return x
318
+
319
+ def forward(self, x, masks=None):
320
+ shapes = [val // self.patch_size for val in x.shape[-2:]]
321
+ batch_size = x.shape[0]
322
+ x = self.prepare_tokens_with_masks(x, masks)
323
+ outputs = []
324
+ for i, blk in enumerate(self.blocks):
325
+ x = blk(x)
326
+ outputs.append(x)
327
+
328
+ if self.use_norm:
329
+ outputs = [self.norm(out) for out in outputs]
330
+ class_tokens = [out[:, :1] for out in outputs]
331
+ outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs]
332
+ outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs]
333
+
334
+ return (outputs, class_tokens)
335
+
336
+ def get_params(self, lr, wd, ld, *args, **kwargs):
337
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
338
+ return encoder_p, encoder_lr
339
+
340
+ def freeze(self) -> None:
341
+ for module in self.modules():
342
+ module.eval()
343
+ for parameters in self.parameters():
344
+ parameters.requires_grad = False
345
+
346
+ def train(self, mode=True):
347
+ super().train(mode)
348
+ self.mask_token.requires_grad = False
349
+ self.register_tokens.requires_grad = False
350
+
351
+
352
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
353
+ """ViT weight initialization, original timm impl (for reproducibility)"""
354
+ if isinstance(module, nn.Linear):
355
+ trunc_normal_(module.weight, std=0.02)
356
+ if module.bias is not None:
357
+ nn.init.zeros_(module.bias)
358
+
359
+
360
+ def vit_small(patch_size=16, num_register_tokens=0, export=False, **kwargs):
361
+ model = DinoVisionTransformer(
362
+ patch_size=patch_size,
363
+ embed_dim=384,
364
+ depth=12,
365
+ num_heads=6,
366
+ mlp_ratio=4,
367
+ num_register_tokens=num_register_tokens,
368
+ block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
369
+ **kwargs,
370
+ )
371
+ return model
372
+
373
+
374
+ def vit_base(patch_size=16, num_register_tokens=0, export=False, **kwargs):
375
+ model = DinoVisionTransformer(
376
+ patch_size=patch_size,
377
+ embed_dim=768,
378
+ depth=12,
379
+ num_heads=12,
380
+ mlp_ratio=4,
381
+ num_register_tokens=num_register_tokens,
382
+ block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
383
+ **kwargs,
384
+ )
385
+ return model
386
+
387
+
388
+ def vit_large(patch_size=16, num_register_tokens=0, export=False, **kwargs):
389
+ model = DinoVisionTransformer(
390
+ patch_size=patch_size,
391
+ embed_dim=1024,
392
+ depth=24,
393
+ num_heads=16,
394
+ mlp_ratio=4,
395
+ num_register_tokens=num_register_tokens,
396
+ block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
397
+ **kwargs,
398
+ )
399
+ return model
400
+
401
+
402
+ def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
403
+ compact_arch_name = arch_name.replace("_", "")[:4]
404
+ return f"dinov2_{compact_arch_name}{patch_size}"
405
+
406
+
407
+ def _make_dinov2_model(
408
+ *,
409
+ arch_name: str = "vit_large",
410
+ img_size: int = 518,
411
+ patch_size: int = 14,
412
+ init_values: float = 1.0,
413
+ ffn_layer: str = "mlp",
414
+ block_chunks: int = 0,
415
+ pretrained: str = "",
416
+ output_idx: Sequence[int] = [],
417
+ num_register_tokens: int = 0,
418
+ drop_path_rate: float = 0.0,
419
+ use_norm: bool = False,
420
+ export: bool = False,
421
+ interpolate_offset: float = 0.0,
422
+ **kwargs,
423
+ ):
424
+ model_name = _make_dinov2_model_name(arch_name, patch_size)
425
+
426
+ vit_kwargs = dict(
427
+ img_size=img_size,
428
+ patch_size=patch_size,
429
+ init_values=init_values,
430
+ ffn_layer=ffn_layer,
431
+ block_chunks=block_chunks,
432
+ output_idx=output_idx,
433
+ drop_path_rate=drop_path_rate,
434
+ num_register_tokens=num_register_tokens,
435
+ use_norm=use_norm,
436
+ export=export,
437
+ interpolate_offset=interpolate_offset,
438
+ )
439
+ vit_kwargs.update(**kwargs)
440
+ model = eval(arch_name)(**vit_kwargs)
441
+ if pretrained == "":
442
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
443
+ if num_register_tokens > 0:
444
+ url += "_reg4"
445
+ url += "_pretrain.pth"
446
+ state_dict = torch.hub.load_state_dict_from_url(
447
+ url, map_location="cpu", progress=False
448
+ )
449
+ info = model.load_state_dict(state_dict, strict=False)
450
+ print(info)
451
+ elif pretrained is not None:
452
+ state_dict = torch.load(pretrained, map_location="cpu")
453
+ info = model.load_state_dict(state_dict, strict=False)
454
+ print(f"loading from {pretrained} with:", info)
455
+ return model
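A minimal sketch of building the ViT-L/14 encoder defined above without fetching weights (passing pretrained=None skips both the URL download and the file load); the output_idx values and input size are assumptions for illustration only.

import torch

encoder = _make_dinov2_model(
    arch_name="vit_large",
    patch_size=14,
    pretrained=None,             # skip the DINOv2 weight download
    output_idx=[6, 12, 18, 24],  # assumed layers of interest
    export=True,                 # plain Attention, no xFormers dependency
)
encoder.eval()

with torch.no_grad():
    feats, cls_tokens = encoder(torch.randn(1, 3, 224, 224))

# One (B, H/14, W/14, C) feature map and one (B, 1, C) class token per block.
print(len(feats), tuple(feats[-1].shape), tuple(cls_tokens[-1].shape))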
unidepth/models/backbones/metadinov2/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .attention import Attention, MemEffAttention
8
+ from .block import NestedTensorBlock
9
+ from .dino_head import DINOHead
10
+ from .mlp import Mlp
11
+ from .patch_embed import PatchEmbed
12
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (592 Bytes)
unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc ADDED
Binary file (4.46 kB)
unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc ADDED
Binary file (15.9 kB)
unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc ADDED
Binary file (3.94 kB)
unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc ADDED
Binary file (1.86 kB)
unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc ADDED
Binary file (1.62 kB)
unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc ADDED
Binary file (2.08 kB)
unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc ADDED
Binary file (4.49 kB)
unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc ADDED
Binary file (3.29 kB)
unidepth/models/backbones/metadinov2/attention.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ import torch.nn as nn
14
+ from torch import Tensor
15
+
16
+ logger = logging.getLogger("dinov2")
17
+
18
+
19
+ try:
20
+ from xformers.ops import fmha, memory_efficient_attention, unbind
21
+
22
+ XFORMERS_AVAILABLE = True
23
+ except ImportError:
24
+ logger.warning("xFormers not available")
25
+ XFORMERS_AVAILABLE = False
26
+
27
+
28
+ class Attention(nn.Module):
29
+ def __init__(
30
+ self,
31
+ dim: int,
32
+ num_heads: int = 8,
33
+ qkv_bias: bool = False,
34
+ proj_bias: bool = True,
35
+ attn_drop: float = 0.0,
36
+ proj_drop: float = 0.0,
37
+ ) -> None:
38
+ super().__init__()
39
+ self.num_heads = num_heads
40
+ head_dim = dim // num_heads
41
+ self.scale = head_dim**-0.5
42
+
43
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
44
+ self.attn_drop = nn.Dropout(attn_drop)
45
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
46
+ self.proj_drop = nn.Dropout(proj_drop)
47
+
48
+ def forward(self, x: Tensor) -> Tensor:
49
+ B, N, C = x.shape
50
+ qkv = (
51
+ self.qkv(x)
52
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
53
+ .permute(2, 0, 3, 1, 4)
54
+ )
55
+
56
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
57
+ attn = q @ k.transpose(-2, -1)
58
+
59
+ attn = attn.softmax(dim=-1)
60
+ attn = self.attn_drop(attn)
61
+
62
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
63
+ x = self.proj(x)
64
+ x = self.proj_drop(x)
65
+ return x
66
+
67
+
68
+ class MemEffAttention(Attention):
69
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
70
+ if not XFORMERS_AVAILABLE:
71
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
72
+ return super().forward(x)
73
+
74
+ B, N, C = x.shape
75
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
76
+
77
+ q, k, v = unbind(qkv, 2)
78
+
79
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
80
+ x = x.reshape([B, N, C])
81
+
82
+ x = self.proj(x)
83
+ x = self.proj_drop(x)
84
+ return x
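A small shape-check sketch for the attention layer above; the dimensions are illustrative. When xFormers is unavailable, MemEffAttention with attn_bias=None falls back to this same vanilla path.

import torch

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
tokens = torch.randn(2, 197, 384)   # (batch, tokens, channels)
with torch.no_grad():
    out = attn(tokens)
print(tuple(out.shape))             # (2, 197, 384): shape is preserved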
unidepth/models/backbones/metadinov2/block.py ADDED
@@ -0,0 +1,282 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Any, Callable, Dict, List, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+ logger = logging.getLogger("dinov2")
23
+
24
+
25
+ try:
26
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
27
+
28
+ XFORMERS_AVAILABLE = True
29
+ except ImportError:
30
+ logger.warning("xFormers not available")
31
+ XFORMERS_AVAILABLE = False
32
+
33
+
34
+ class Block(nn.Module):
35
+ def __init__(
36
+ self,
37
+ dim: int,
38
+ num_heads: int,
39
+ mlp_ratio: float = 4.0,
40
+ qkv_bias: bool = False,
41
+ proj_bias: bool = True,
42
+ ffn_bias: bool = True,
43
+ drop: float = 0.0,
44
+ attn_drop: float = 0.0,
45
+ init_values=None,
46
+ drop_path: float = 0.0,
47
+ act_layer: Callable[..., nn.Module] = nn.GELU,
48
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
49
+ attn_class: Callable[..., nn.Module] = Attention,
50
+ ffn_layer: Callable[..., nn.Module] = Mlp,
51
+ ) -> None:
52
+ super().__init__()
53
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
54
+ self.norm1 = norm_layer(dim)
55
+ self.attn = attn_class(
56
+ dim,
57
+ num_heads=num_heads,
58
+ qkv_bias=qkv_bias,
59
+ proj_bias=proj_bias,
60
+ attn_drop=attn_drop,
61
+ proj_drop=drop,
62
+ )
63
+ self.ls1 = (
64
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ )
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = (
78
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
79
+ )
80
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
81
+
82
+ self.sample_drop_ratio = drop_path
83
+
84
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
86
+ return self.ls1(self.attn(self.norm1(x)))
87
+
88
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
89
+ return self.ls2(self.mlp(self.norm2(x)))
90
+
91
+ if self.training and self.sample_drop_ratio > 0.1:
92
+ # the overhead is compensated only for a drop path rate larger than 0.1
93
+ x = drop_add_residual_stochastic_depth(
94
+ x,
95
+ residual_func=attn_residual_func,
96
+ sample_drop_ratio=self.sample_drop_ratio,
97
+ )
98
+ x = drop_add_residual_stochastic_depth(
99
+ x,
100
+ residual_func=ffn_residual_func,
101
+ sample_drop_ratio=self.sample_drop_ratio,
102
+ )
103
+ elif self.training and self.sample_drop_ratio > 0.0:
104
+ x = x + self.drop_path1(attn_residual_func(x))
105
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
106
+ else:
107
+ x = x + attn_residual_func(x)
108
+ x = x + ffn_residual_func(x)
109
+ return x
110
+
111
+
112
+ def drop_add_residual_stochastic_depth(
113
+ x: torch.Tensor,
114
+ residual_func: Callable[[torch.Tensor], torch.Tensor],
115
+ sample_drop_ratio: float = 0.0,
116
+ ) -> torch.Tensor:
117
+ # 1) extract subset using permutation
118
+ b, n, d = x.shape
119
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
120
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
121
+ x_subset = x[brange]
122
+
123
+ # 2) apply residual_func to get residual
124
+ residual = residual_func(x_subset)
125
+
126
+ x_flat = x.flatten(1)
127
+ residual = residual.flatten(1)
128
+
129
+ residual_scale_factor = b / sample_subset_size
130
+
131
+ # 3) add the residual
132
+ x_plus_residual = torch.index_add(
133
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
134
+ )
135
+ return x_plus_residual.view_as(x)
136
+
137
+
138
+ def get_branges_scales(x, sample_drop_ratio=0.0):
139
+ b, n, d = x.shape
140
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
141
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
142
+ residual_scale_factor = b / sample_subset_size
143
+ return brange, residual_scale_factor
144
+
145
+
146
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
147
+ if scaling_vector is None:
148
+ x_flat = x.flatten(1)
149
+ residual = residual.flatten(1)
150
+ x_plus_residual = torch.index_add(
151
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
152
+ )
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x,
156
+ brange,
157
+ residual.to(dtype=x.dtype),
158
+ scaling=scaling_vector,
159
+ alpha=residual_scale_factor,
160
+ )
161
+ return x_plus_residual
162
+
163
+
164
+ attn_bias_cache: Dict[Tuple, Any] = {}
165
+
166
+
167
+ def get_attn_bias_and_cat(x_list, branges=None):
168
+ """
169
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
170
+ """
171
+ batch_sizes = (
172
+ [b.shape[0] for b in branges]
173
+ if branges is not None
174
+ else [x.shape[0] for x in x_list]
175
+ )
176
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
177
+ if all_shapes not in attn_bias_cache.keys():
178
+ seqlens = []
179
+ for b, x in zip(batch_sizes, x_list):
180
+ for _ in range(b):
181
+ seqlens.append(x.shape[1])
182
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
183
+ attn_bias._batch_sizes = batch_sizes
184
+ attn_bias_cache[all_shapes] = attn_bias
185
+
186
+ if branges is not None:
187
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
188
+ 1, -1, x_list[0].shape[-1]
189
+ )
190
+ else:
191
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
192
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
193
+
194
+ return attn_bias_cache[all_shapes], cat_tensors
195
+
196
+
197
+ def drop_add_residual_stochastic_depth_list(
198
+ x_list: List[torch.Tensor],
199
+ residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
200
+ sample_drop_ratio: float = 0.0,
201
+ scaling_vector=None,
202
+ ) -> torch.Tensor:
203
+ # 1) generate random set of indices for dropping samples in the batch
204
+ branges_scales = [
205
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
206
+ ]
207
+ branges = [s[0] for s in branges_scales]
208
+ residual_scale_factors = [s[1] for s in branges_scales]
209
+
210
+ # 2) get attention bias and index+concat the tensors
211
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
212
+
213
+ # 3) apply residual_func to get residual, and split the result
214
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
215
+
216
+ outputs = []
217
+ for x, brange, residual, residual_scale_factor in zip(
218
+ x_list, branges, residual_list, residual_scale_factors
219
+ ):
220
+ outputs.append(
221
+ add_residual(
222
+ x, brange, residual, residual_scale_factor, scaling_vector
223
+ ).view_as(x)
224
+ )
225
+ return outputs
226
+
227
+
228
+ class NestedTensorBlock(Block):
229
+ def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
230
+ """
231
+ x_list contains a list of tensors to nest together and run
232
+ """
233
+ assert isinstance(self.attn, MemEffAttention)
234
+
235
+ if self.training and self.sample_drop_ratio > 0.0:
236
+
237
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
238
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
239
+
240
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
241
+ return self.mlp(self.norm2(x))
242
+
243
+ x_list = drop_add_residual_stochastic_depth_list(
244
+ x_list,
245
+ residual_func=attn_residual_func,
246
+ sample_drop_ratio=self.sample_drop_ratio,
247
+ scaling_vector=(
248
+ self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
249
+ ),
250
+ )
251
+ x_list = drop_add_residual_stochastic_depth_list(
252
+ x_list,
253
+ residual_func=ffn_residual_func,
254
+ sample_drop_ratio=self.sample_drop_ratio,
255
+ scaling_vector=(
256
+ self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
257
+ ),
258
+ )
259
+ return x_list
260
+ else:
261
+
262
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
263
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
264
+
265
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
266
+ return self.ls2(self.mlp(self.norm2(x)))
267
+
268
+ attn_bias, x = get_attn_bias_and_cat(x_list)
269
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
270
+ x = x + ffn_residual_func(x)
271
+ return attn_bias.split(x)
272
+
273
+ def forward(self, x_or_x_list):
274
+ if isinstance(x_or_x_list, torch.Tensor):
275
+ return super(NestedTensorBlock, self).forward(x_or_x_list)
276
+ elif isinstance(x_or_x_list, list):
277
+ assert (
278
+ XFORMERS_AVAILABLE
279
+ ), "Please install xFormers for nested tensors usage"
280
+ return self.forward_nested(x_or_x_list)
281
+ else:
282
+ raise AssertionError
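A minimal sketch showing that, on a plain tensor, NestedTensorBlock reduces to a standard pre-norm ViT block (the list/nested path above requires xFormers); the dimensions are illustrative.

import torch

block = NestedTensorBlock(dim=384, num_heads=6, qkv_bias=True, init_values=1e-5)
block.eval()
tokens = torch.randn(2, 197, 384)
with torch.no_grad():
    out = block(tokens)
print(tuple(out.shape))  # (2, 197, 384)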
unidepth/models/backbones/metadinov2/dino_head.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn.init import trunc_normal_
10
+ from torch.nn.utils import weight_norm
11
+
12
+
13
+ class DINOHead(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_dim,
17
+ out_dim,
18
+ use_bn=False,
19
+ nlayers=3,
20
+ hidden_dim=2048,
21
+ bottleneck_dim=256,
22
+ mlp_bias=True,
23
+ ):
24
+ super().__init__()
25
+ nlayers = max(nlayers, 1)
26
+ self.mlp = _build_mlp(
27
+ nlayers,
28
+ in_dim,
29
+ bottleneck_dim,
30
+ hidden_dim=hidden_dim,
31
+ use_bn=use_bn,
32
+ bias=mlp_bias,
33
+ )
34
+ self.apply(self._init_weights)
35
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
36
+ self.last_layer.weight_g.data.fill_(1)
37
+
38
+ def _init_weights(self, m):
39
+ if isinstance(m, nn.Linear):
40
+ trunc_normal_(m.weight, std=0.02)
41
+ if isinstance(m, nn.Linear) and m.bias is not None:
42
+ nn.init.constant_(m.bias, 0)
43
+
44
+ def forward(self, x):
45
+ x = self.mlp(x)
46
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
47
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
48
+ x = self.last_layer(x)
49
+ return x
50
+
51
+
52
+ def _build_mlp(
53
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
54
+ ):
55
+ if nlayers == 1:
56
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
57
+ else:
58
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
59
+ if use_bn:
60
+ layers.append(nn.BatchNorm1d(hidden_dim))
61
+ layers.append(nn.GELU())
62
+ for _ in range(nlayers - 2):
63
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
64
+ if use_bn:
65
+ layers.append(nn.BatchNorm1d(hidden_dim))
66
+ layers.append(nn.GELU())
67
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
68
+ return nn.Sequential(*layers)
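A minimal sketch of the projection head above; the in_dim/out_dim values are illustrative, not values fixed by this repository.

import torch

head = DINOHead(in_dim=1024, out_dim=65536)
with torch.no_grad():
    logits = head(torch.randn(8, 1024))
print(tuple(logits.shape))  # (8, 65536): one score per prototype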
unidepth/models/backbones/metadinov2/drop_path.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ import torch.nn as nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (
20
+ x.ndim - 1
21
+ ) # work with diff dim tensors, not just 2D ConvNets
22
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
23
+ if keep_prob > 0.0:
24
+ random_tensor.div_(keep_prob)
25
+ output = x * random_tensor
26
+ return output
27
+
28
+
29
+ class DropPath(nn.Module):
30
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
31
+
32
+ def __init__(self, drop_prob=None):
33
+ super(DropPath, self).__init__()
34
+ self.drop_prob = drop_prob
35
+
36
+ def forward(self, x):
37
+ return drop_path(x, self.drop_prob, self.training)
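A minimal sketch of the stochastic-depth behaviour above: whole samples are zeroed during training and the survivors are rescaled by 1/keep_prob, so the expectation is unchanged, while eval mode is the identity.

import torch

x = torch.ones(4, 3, 8, 8)
dp = DropPath(drop_prob=0.2)

dp.train()
y = dp(x)
kept = (y.flatten(1).abs().sum(dim=1) > 0).sum().item()
print(f"{kept}/4 samples kept; survivors are scaled by 1/0.8")

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time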
unidepth/models/backbones/metadinov2/layer_scale.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
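A minimal sketch of LayerScale: each channel is multiplied by a learnable gamma that starts at init_values, so residual branches begin close to zero; the dimensions are illustrative.

import torch

ls = LayerScale(dim=384, init_values=1e-5)
tokens = torch.randn(2, 197, 384)
out = ls(tokens)
print(out.abs().max().item())  # roughly 1e-5 * max|tokens| right after initialisation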