Spaces:
Sleeping
Sleeping
Upload 90 files
Browse filesAdd Initial files
This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +104 -0
- configs/config_v1_cnvnxtl.json +24 -0
- configs/config_v1_vitl14.json +23 -0
- configs/config_v2_vitl14.json +32 -0
- configs/config_v2_vits14.json +32 -0
- unidepth/layers/__init__.py +22 -0
- unidepth/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/activation.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/attention.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/convnext.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/layer_scale.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/mlp.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc +0 -0
- unidepth/layers/__pycache__/upsample.cpython-311.pyc +0 -0
- unidepth/layers/activation.py +15 -0
- unidepth/layers/attention.py +308 -0
- unidepth/layers/convnext.py +44 -0
- unidepth/layers/drop_path.py +25 -0
- unidepth/layers/layer_scale.py +17 -0
- unidepth/layers/mlp.py +35 -0
- unidepth/layers/nystrom_attention.py +74 -0
- unidepth/layers/positional_encoding.py +227 -0
- unidepth/layers/upsample.py +134 -0
- unidepth/models/__init__.py +7 -0
- unidepth/models/__pycache__/__init__.cpython-311.pyc +0 -0
- unidepth/models/__pycache__/encoder.cpython-311.pyc +0 -0
- unidepth/models/backbones/__init__.py +9 -0
- unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc +0 -0
- unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc +0 -0
- unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc +0 -0
- unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc +0 -0
- unidepth/models/backbones/convnext.py +580 -0
- unidepth/models/backbones/convnext2.py +288 -0
- unidepth/models/backbones/dinov2.py +455 -0
- unidepth/models/backbones/metadinov2/__init__.py +12 -0
- unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc +0 -0
- unidepth/models/backbones/metadinov2/attention.py +84 -0
- unidepth/models/backbones/metadinov2/block.py +282 -0
- unidepth/models/backbones/metadinov2/dino_head.py +68 -0
- unidepth/models/backbones/metadinov2/drop_path.py +37 -0
- unidepth/models/backbones/metadinov2/layer_scale.py +28 -0
app.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
from unidepth.models import UniDepthV2
|
7 |
+
import os
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import matplotlib
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
# Load model configurations and initialize model
|
14 |
+
def load_model(config_path, model_path, encoder, device):
|
15 |
+
with open(config_path) as f:
|
16 |
+
config = json.load(f)
|
17 |
+
|
18 |
+
model = UniDepthV2(config)
|
19 |
+
model.load_state_dict(torch.load(model_path, map_location=device)['model'], strict=True)
|
20 |
+
model = model.to(device).eval()
|
21 |
+
|
22 |
+
return model
|
23 |
+
|
24 |
+
# Inference function
|
25 |
+
def depth_estimation(image, model_path, encoder='vits'):
|
26 |
+
try:
|
27 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
28 |
+
config_path = 'configs/config_v2_vits14.json'
|
29 |
+
|
30 |
+
# Ensure model path exists or download if needed
|
31 |
+
if not os.path.exists(model_path):
|
32 |
+
return "Model checkpoint not found. Please upload a valid model path."
|
33 |
+
|
34 |
+
model = load_model(config_path, model_path, encoder, device)
|
35 |
+
|
36 |
+
# Preprocess image
|
37 |
+
rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device) # C, H, W
|
38 |
+
predictions = model.infer(rgb)
|
39 |
+
depth = predictions["depth"].squeeze().to('cpu').numpy()
|
40 |
+
|
41 |
+
min_depth = depth.min()
|
42 |
+
max_depth = depth.max()
|
43 |
+
|
44 |
+
depth_normalized = (depth - min_depth) / (max_depth - min_depth)
|
45 |
+
|
46 |
+
# Apply colormap
|
47 |
+
cmap = matplotlib.colormaps.get_cmap('Spectral')
|
48 |
+
depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
|
49 |
+
|
50 |
+
# Create a figure and axis for the colorbar
|
51 |
+
fig, ax = plt.subplots(figsize=(6, 0.4))
|
52 |
+
fig.subplots_adjust(bottom=0.5)
|
53 |
+
|
54 |
+
# Create a colorbar
|
55 |
+
norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
|
56 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
57 |
+
sm.set_array([])
|
58 |
+
cbar = fig.colorbar(sm, cax=ax, orientation='horizontal', label='Depth (meters)')
|
59 |
+
|
60 |
+
# Save the colorbar to a BytesIO object
|
61 |
+
from io import BytesIO
|
62 |
+
buf = BytesIO()
|
63 |
+
fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
|
64 |
+
plt.close(fig)
|
65 |
+
buf.seek(0)
|
66 |
+
|
67 |
+
# Open the colorbar image
|
68 |
+
colorbar_img = Image.open(buf)
|
69 |
+
|
70 |
+
# Create a new image with space for the colorbar
|
71 |
+
new_height = depth_color.shape[0] + colorbar_img.size[1]
|
72 |
+
new_img = Image.new('RGB', (depth_color.shape[1], new_height), (255, 255, 255))
|
73 |
+
|
74 |
+
# Paste the depth image and colorbar
|
75 |
+
new_img.paste(Image.fromarray(depth_color), (0, 0))
|
76 |
+
new_img.paste(colorbar_img, (0, depth_color.shape[0]))
|
77 |
+
|
78 |
+
return new_img
|
79 |
+
|
80 |
+
|
81 |
+
except Exception as e:
|
82 |
+
return f"Error occurred: {str(e)}"
|
83 |
+
|
84 |
+
# Gradio Interface
|
85 |
+
def main():
|
86 |
+
iface = gr.Interface(
|
87 |
+
fn=depth_estimation,
|
88 |
+
inputs=[
|
89 |
+
gr.Image(type="numpy", label="Input Image"),
|
90 |
+
gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
|
91 |
+
gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
|
92 |
+
],
|
93 |
+
outputs=[
|
94 |
+
gr.Image(type="pil", label="Predicted Depth")
|
95 |
+
],
|
96 |
+
title="Depth Anything V2 Metric Depth Estimation",
|
97 |
+
description="Upload an image to get its estimated depth map using Depth Anything V2.",
|
98 |
+
)
|
99 |
+
|
100 |
+
iface.launch()
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
main()
|
configs/config_v1_cnvnxtl.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"generic": {
|
3 |
+
"seed": 13
|
4 |
+
},
|
5 |
+
"training": {
|
6 |
+
},
|
7 |
+
"data": {
|
8 |
+
"image_shape": [462, 616]
|
9 |
+
},
|
10 |
+
"model": {
|
11 |
+
"name": "UniDepthV1",
|
12 |
+
"num_heads": 8,
|
13 |
+
"expansion": 4,
|
14 |
+
"pixel_decoder": {
|
15 |
+
"hidden_dim": 512,
|
16 |
+
"depths": [3, 2, 1],
|
17 |
+
"dropout": 0.0
|
18 |
+
},
|
19 |
+
"pixel_encoder": {
|
20 |
+
"name": "convnext_large",
|
21 |
+
"pretrained": null
|
22 |
+
}
|
23 |
+
}
|
24 |
+
}
|
configs/config_v1_vitl14.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"generic": {
|
3 |
+
"seed": 13
|
4 |
+
},
|
5 |
+
"training": {},
|
6 |
+
"data": {
|
7 |
+
"image_shape": [462, 616]
|
8 |
+
},
|
9 |
+
"model": {
|
10 |
+
"name": "UniDepthV1",
|
11 |
+
"num_heads": 8,
|
12 |
+
"expansion": 4,
|
13 |
+
"pixel_decoder": {
|
14 |
+
"hidden_dim": 512,
|
15 |
+
"depths": [3, 2, 1],
|
16 |
+
"dropout": 0.0
|
17 |
+
},
|
18 |
+
"pixel_encoder": {
|
19 |
+
"name": "dinov2_vitl14",
|
20 |
+
"pretrained": null
|
21 |
+
}
|
22 |
+
}
|
23 |
+
}
|
configs/config_v2_vitl14.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"generic": {
|
3 |
+
"seed": 13,
|
4 |
+
"deterministic": true
|
5 |
+
},
|
6 |
+
"training": {},
|
7 |
+
"data": {
|
8 |
+
"image_shape": [420, 560],
|
9 |
+
"shape_constraints": {
|
10 |
+
"ratio_bounds": [0.66, 2.0],
|
11 |
+
"pixels_bounds": [1400, 2400],
|
12 |
+
"patch_size": 14
|
13 |
+
}
|
14 |
+
},
|
15 |
+
"model": {
|
16 |
+
"name": "UniDepthV2",
|
17 |
+
"num_heads": 8,
|
18 |
+
"expansion": 4,
|
19 |
+
"pixel_decoder": {
|
20 |
+
"hidden_dim": 512,
|
21 |
+
"depths": [6, 0, 0],
|
22 |
+
"dropout": 0.0
|
23 |
+
},
|
24 |
+
"pixel_encoder": {
|
25 |
+
"name": "dinov2_vitl14",
|
26 |
+
"pretrained": null,
|
27 |
+
"use_norm": true,
|
28 |
+
"stacking_fn": "last",
|
29 |
+
"output_idx": [21,22,23,24]
|
30 |
+
}
|
31 |
+
}
|
32 |
+
}
|
configs/config_v2_vits14.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"generic": {
|
3 |
+
"seed": 13,
|
4 |
+
"deterministic": true
|
5 |
+
},
|
6 |
+
"training": {},
|
7 |
+
"data": {
|
8 |
+
"image_shape": [420, 560],
|
9 |
+
"shape_constraints": {
|
10 |
+
"ratio_bounds": [0.66, 2.0],
|
11 |
+
"pixels_bounds": [1400, 2400],
|
12 |
+
"patch_size": 14
|
13 |
+
}
|
14 |
+
},
|
15 |
+
"model": {
|
16 |
+
"name": "UniDepthV2",
|
17 |
+
"num_heads": 8,
|
18 |
+
"expansion": 4,
|
19 |
+
"pixel_decoder": {
|
20 |
+
"hidden_dim": 512,
|
21 |
+
"depths": [6, 0, 0],
|
22 |
+
"dropout": 0.0
|
23 |
+
},
|
24 |
+
"pixel_encoder": {
|
25 |
+
"name": "dinov2_vits14",
|
26 |
+
"pretrained": null,
|
27 |
+
"use_norm": true,
|
28 |
+
"stacking_fn": "last",
|
29 |
+
"output_idx": [9,10,11,12]
|
30 |
+
}
|
31 |
+
}
|
32 |
+
}
|
unidepth/layers/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .activation import GEGLU, SwiGLU
|
2 |
+
from .attention import AttentionBlock, AttentionDecoderBlock
|
3 |
+
from .convnext import CvnxtBlock
|
4 |
+
from .mlp import MLP
|
5 |
+
from .nystrom_attention import NystromBlock
|
6 |
+
from .positional_encoding import PositionEmbeddingSine
|
7 |
+
from .upsample import (ConvUpsample, ConvUpsampleShuffle,
|
8 |
+
ConvUpsampleShuffleResidual)
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
"SwiGLU",
|
12 |
+
"GEGLU",
|
13 |
+
"CvnxtBlock",
|
14 |
+
"AttentionBlock",
|
15 |
+
"NystromBlock",
|
16 |
+
"PositionEmbeddingSine",
|
17 |
+
"ConvUpsample",
|
18 |
+
"MLP",
|
19 |
+
"ConvUpsampleShuffle",
|
20 |
+
"AttentionDecoderBlock",
|
21 |
+
"ConvUpsampleShuffleResidual",
|
22 |
+
]
|
unidepth/layers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (845 Bytes). View file
|
|
unidepth/layers/__pycache__/activation.cpython-311.pyc
ADDED
Binary file (1.38 kB). View file
|
|
unidepth/layers/__pycache__/attention.cpython-311.pyc
ADDED
Binary file (14.5 kB). View file
|
|
unidepth/layers/__pycache__/convnext.cpython-311.pyc
ADDED
Binary file (2.41 kB). View file
|
|
unidepth/layers/__pycache__/layer_scale.cpython-311.pyc
ADDED
Binary file (1.53 kB). View file
|
|
unidepth/layers/__pycache__/mlp.cpython-311.pyc
ADDED
Binary file (2.41 kB). View file
|
|
unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc
ADDED
Binary file (3.57 kB). View file
|
|
unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc
ADDED
Binary file (16.6 kB). View file
|
|
unidepth/layers/__pycache__/upsample.cpython-311.pyc
ADDED
Binary file (6.19 kB). View file
|
|
unidepth/layers/activation.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class SwiGLU(nn.Module):
|
7 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
8 |
+
x, gates = x.chunk(2, dim=-1)
|
9 |
+
return x * F.silu(gates)
|
10 |
+
|
11 |
+
|
12 |
+
class GEGLU(nn.Module):
|
13 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
14 |
+
x, gates = x.chunk(2, dim=-1)
|
15 |
+
return x * F.gelu(gates)
|
unidepth/layers/attention.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Luigi Piccinelli
|
3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
4 |
+
"""
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
from .layer_scale import LayerScale
|
14 |
+
from .mlp import MLP
|
15 |
+
|
16 |
+
|
17 |
+
class SimpleAttention(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
dim: int,
|
21 |
+
num_heads: int = 4,
|
22 |
+
dropout: float = 0.0,
|
23 |
+
cosine: bool = False,
|
24 |
+
context_dim: int | None = None,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
self.dropout = dropout
|
28 |
+
self.num_heads = num_heads
|
29 |
+
self.hidden_dim = dim
|
30 |
+
context_dim = context_dim or dim
|
31 |
+
|
32 |
+
self.kv = nn.Linear(context_dim, dim * 2, bias=False)
|
33 |
+
self.q = nn.Linear(dim, dim, bias=False)
|
34 |
+
self.norm_attnx = nn.LayerNorm(dim)
|
35 |
+
self.norm_attnctx = nn.LayerNorm(context_dim)
|
36 |
+
self.cosine = cosine
|
37 |
+
self.out = nn.Linear(dim, dim)
|
38 |
+
|
39 |
+
def forward(
|
40 |
+
self,
|
41 |
+
x: torch.Tensor,
|
42 |
+
attn_bias: torch.Tensor | None = None,
|
43 |
+
context: torch.Tensor | None = None,
|
44 |
+
pos_embed: torch.Tensor | None = None,
|
45 |
+
pos_embed_context: torch.Tensor | None = None,
|
46 |
+
rope: nn.Module | None = None,
|
47 |
+
) -> torch.Tensor:
|
48 |
+
context = x if context is None else context
|
49 |
+
x = self.norm_attnx(x)
|
50 |
+
context = self.norm_attnctx(context)
|
51 |
+
k, v = rearrange(
|
52 |
+
self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
|
53 |
+
).unbind(dim=-1)
|
54 |
+
q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
|
55 |
+
|
56 |
+
if rope is not None:
|
57 |
+
q = rope(q)
|
58 |
+
k = rope(k)
|
59 |
+
else:
|
60 |
+
if pos_embed is not None:
|
61 |
+
pos_embed = rearrange(
|
62 |
+
pos_embed, "b n (h d) -> b h n d", h=self.num_heads
|
63 |
+
)
|
64 |
+
q = q + pos_embed
|
65 |
+
if pos_embed_context is not None:
|
66 |
+
pos_embed_context = rearrange(
|
67 |
+
pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
|
68 |
+
)
|
69 |
+
k = k + pos_embed_context
|
70 |
+
|
71 |
+
if self.cosine:
|
72 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
73 |
+
x = F.scaled_dot_product_attention(
|
74 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
75 |
+
)
|
76 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
77 |
+
x = self.out(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class AttentionBlock(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
dim: int,
|
85 |
+
num_heads: int = 4,
|
86 |
+
expansion: int = 4,
|
87 |
+
dropout: float = 0.0,
|
88 |
+
cosine: bool = False,
|
89 |
+
gated: bool = False,
|
90 |
+
layer_scale: float = 1.0,
|
91 |
+
context_dim: int | None = None,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
self.dropout = dropout
|
95 |
+
self.num_heads = num_heads
|
96 |
+
self.hidden_dim = dim
|
97 |
+
context_dim = context_dim or dim
|
98 |
+
self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
|
99 |
+
self.kv = nn.Linear(context_dim, dim * 2)
|
100 |
+
self.q = nn.Linear(dim, dim)
|
101 |
+
self.norm_attnx = nn.LayerNorm(dim)
|
102 |
+
self.norm_attnctx = nn.LayerNorm(context_dim)
|
103 |
+
self.cosine = cosine
|
104 |
+
self.out = nn.Linear(dim, dim)
|
105 |
+
self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
106 |
+
self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
107 |
+
|
108 |
+
def attn(
|
109 |
+
self,
|
110 |
+
x: torch.Tensor,
|
111 |
+
attn_bias: torch.Tensor | None = None,
|
112 |
+
context: torch.Tensor | None = None,
|
113 |
+
pos_embed: torch.Tensor | None = None,
|
114 |
+
pos_embed_context: torch.Tensor | None = None,
|
115 |
+
rope: nn.Module | None = None,
|
116 |
+
) -> torch.Tensor:
|
117 |
+
x = self.norm_attnx(x)
|
118 |
+
context = self.norm_attnctx(context)
|
119 |
+
k, v = rearrange(
|
120 |
+
self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
|
121 |
+
).unbind(dim=-1)
|
122 |
+
q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
|
123 |
+
|
124 |
+
if rope is not None:
|
125 |
+
q = rope(q)
|
126 |
+
k = rope(k)
|
127 |
+
else:
|
128 |
+
if pos_embed is not None:
|
129 |
+
pos_embed = rearrange(
|
130 |
+
pos_embed, "b n (h d) -> b h n d", h=self.num_heads
|
131 |
+
)
|
132 |
+
q = q + pos_embed
|
133 |
+
if pos_embed_context is not None:
|
134 |
+
pos_embed_context = rearrange(
|
135 |
+
pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
|
136 |
+
)
|
137 |
+
k = k + pos_embed_context
|
138 |
+
|
139 |
+
if self.cosine:
|
140 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
141 |
+
|
142 |
+
x = F.scaled_dot_product_attention(
|
143 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
144 |
+
)
|
145 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
146 |
+
x = self.out(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
def forward(
|
150 |
+
self,
|
151 |
+
x: torch.Tensor,
|
152 |
+
attn_bias: torch.Tensor | None = None,
|
153 |
+
context: torch.Tensor | None = None,
|
154 |
+
pos_embed: torch.Tensor | None = None,
|
155 |
+
pos_embed_context: torch.Tensor | None = None,
|
156 |
+
rope: nn.Module | None = None,
|
157 |
+
) -> torch.Tensor:
|
158 |
+
context = x if context is None else context
|
159 |
+
x = (
|
160 |
+
self.ls1(
|
161 |
+
self.attn(
|
162 |
+
x,
|
163 |
+
rope=rope,
|
164 |
+
attn_bias=attn_bias,
|
165 |
+
context=context,
|
166 |
+
pos_embed=pos_embed,
|
167 |
+
pos_embed_context=pos_embed_context,
|
168 |
+
)
|
169 |
+
)
|
170 |
+
+ x
|
171 |
+
)
|
172 |
+
x = self.ls2(self.mlp(x)) + x
|
173 |
+
return x
|
174 |
+
|
175 |
+
|
176 |
+
class AttentionDecoderBlock(nn.Module):
|
177 |
+
def __init__(
|
178 |
+
self,
|
179 |
+
dim: int,
|
180 |
+
num_heads: int = 4,
|
181 |
+
expansion: int = 4,
|
182 |
+
dropout: float = 0.0,
|
183 |
+
cosine: bool = False,
|
184 |
+
gated: bool = False,
|
185 |
+
layer_scale: float = 1.0,
|
186 |
+
context_dim: int | None = None,
|
187 |
+
single_head_ca: bool = True,
|
188 |
+
):
|
189 |
+
super().__init__()
|
190 |
+
self.dropout = dropout
|
191 |
+
self.num_heads = num_heads
|
192 |
+
self.hidden_dim = dim
|
193 |
+
self.single_head_ca = single_head_ca
|
194 |
+
context_dim = context_dim or dim
|
195 |
+
self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
|
196 |
+
self.kv_ca = nn.Linear(context_dim, dim * 2)
|
197 |
+
self.q_ca = nn.Linear(dim, dim)
|
198 |
+
self.kv_sa = nn.Linear(dim, dim * 2)
|
199 |
+
self.q_sa = nn.Linear(dim, dim)
|
200 |
+
self.norm_x_sa = nn.LayerNorm(dim)
|
201 |
+
self.norm_x_ca = nn.LayerNorm(dim)
|
202 |
+
self.norm_ctx_ca = nn.LayerNorm(context_dim)
|
203 |
+
self.cosine = cosine
|
204 |
+
self.out_ca = nn.Linear(dim, dim)
|
205 |
+
self.out_sa = nn.Linear(dim, dim)
|
206 |
+
self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
207 |
+
self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
208 |
+
self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
209 |
+
|
210 |
+
def cross_attn(
|
211 |
+
self,
|
212 |
+
x: torch.Tensor,
|
213 |
+
attn_bias: torch.Tensor | None = None,
|
214 |
+
context: torch.Tensor | None = None,
|
215 |
+
pos_embed: torch.Tensor | None = None,
|
216 |
+
pos_embed_context: torch.Tensor | None = None,
|
217 |
+
rope: nn.Module | None = None,
|
218 |
+
) -> torch.Tensor:
|
219 |
+
num_heads = 1 if self.single_head_ca else self.num_heads
|
220 |
+
x = self.norm_x_ca(x)
|
221 |
+
context = self.norm_ctx_ca(context)
|
222 |
+
k, v = rearrange(
|
223 |
+
self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
|
224 |
+
).unbind(dim=-1)
|
225 |
+
q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
|
226 |
+
|
227 |
+
if rope is not None:
|
228 |
+
q = rope(q)
|
229 |
+
k = rope(k)
|
230 |
+
else:
|
231 |
+
if pos_embed is not None:
|
232 |
+
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
|
233 |
+
q = q + pos_embed
|
234 |
+
if pos_embed_context is not None:
|
235 |
+
pos_embed_context = rearrange(
|
236 |
+
pos_embed_context, "b n (h d) -> b h n d", h=num_heads
|
237 |
+
)
|
238 |
+
k = k + pos_embed_context
|
239 |
+
|
240 |
+
if self.cosine:
|
241 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
242 |
+
x = F.scaled_dot_product_attention(
|
243 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
244 |
+
)
|
245 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
246 |
+
x = self.out_ca(x)
|
247 |
+
return x
|
248 |
+
|
249 |
+
def self_attn(
|
250 |
+
self,
|
251 |
+
x: torch.Tensor,
|
252 |
+
attn_bias: torch.Tensor | None = None,
|
253 |
+
pos_embed: torch.Tensor | None = None,
|
254 |
+
rope: nn.Module | None = None,
|
255 |
+
) -> torch.Tensor:
|
256 |
+
x = self.norm_x_sa(x)
|
257 |
+
k, v = rearrange(
|
258 |
+
self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
|
259 |
+
).unbind(dim=-1)
|
260 |
+
q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
|
261 |
+
|
262 |
+
if rope is not None:
|
263 |
+
q = rope(q)
|
264 |
+
k = rope(k)
|
265 |
+
elif pos_embed is not None:
|
266 |
+
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
|
267 |
+
q = q + pos_embed
|
268 |
+
|
269 |
+
if self.cosine:
|
270 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
271 |
+
x = F.scaled_dot_product_attention(
|
272 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
273 |
+
)
|
274 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
275 |
+
x = self.out_sa(x)
|
276 |
+
return x
|
277 |
+
|
278 |
+
def forward(
|
279 |
+
self,
|
280 |
+
x: torch.Tensor,
|
281 |
+
attn_bias: torch.Tensor | None = None,
|
282 |
+
context: torch.Tensor | None = None,
|
283 |
+
pos_embed: torch.Tensor | None = None,
|
284 |
+
pos_embed_context: torch.Tensor | None = None,
|
285 |
+
rope: nn.Module | None = None,
|
286 |
+
) -> torch.Tensor:
|
287 |
+
context = x if context is None else context
|
288 |
+
x = (
|
289 |
+
self.ls1(
|
290 |
+
self.cross_attn(
|
291 |
+
x,
|
292 |
+
rope=rope,
|
293 |
+
attn_bias=attn_bias,
|
294 |
+
context=context,
|
295 |
+
pos_embed=pos_embed,
|
296 |
+
pos_embed_context=pos_embed_context,
|
297 |
+
)
|
298 |
+
)
|
299 |
+
+ x
|
300 |
+
)
|
301 |
+
x = (
|
302 |
+
self.ls2(
|
303 |
+
self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
|
304 |
+
)
|
305 |
+
+ x
|
306 |
+
)
|
307 |
+
x = self.ls3(self.mlp(x)) + x
|
308 |
+
return x
|
unidepth/layers/convnext.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class CvnxtBlock(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim,
|
9 |
+
kernel_size=7,
|
10 |
+
layer_scale=1.0,
|
11 |
+
expansion=4,
|
12 |
+
dilation=1,
|
13 |
+
padding_mode: str = "zeros",
|
14 |
+
):
|
15 |
+
super().__init__()
|
16 |
+
self.dwconv = nn.Conv2d(
|
17 |
+
dim,
|
18 |
+
dim,
|
19 |
+
kernel_size=kernel_size,
|
20 |
+
padding=dilation * (kernel_size - 1) // 2,
|
21 |
+
groups=dim,
|
22 |
+
dilation=dilation,
|
23 |
+
padding_mode=padding_mode,
|
24 |
+
) # depthwise conv
|
25 |
+
self.norm = nn.LayerNorm(dim)
|
26 |
+
self.pwconv1 = nn.Linear(dim, expansion * dim)
|
27 |
+
self.act = nn.GELU()
|
28 |
+
self.pwconv2 = nn.Linear(expansion * dim, dim)
|
29 |
+
self.gamma = (
|
30 |
+
nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0
|
31 |
+
)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
input = x
|
35 |
+
x = self.dwconv(x)
|
36 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
37 |
+
x = self.norm(x)
|
38 |
+
x = self.pwconv1(x)
|
39 |
+
x = self.act(x)
|
40 |
+
x = self.pwconv2(x)
|
41 |
+
|
42 |
+
x = self.gamma * x
|
43 |
+
x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
44 |
+
return x
|
unidepth/layers/drop_path.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
|
6 |
+
if drop_prob == 0.0 or not training:
|
7 |
+
return x
|
8 |
+
keep_prob = 1 - drop_prob
|
9 |
+
shape = (x.shape[0],) + (1,) * (
|
10 |
+
x.ndim - 1
|
11 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
12 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
13 |
+
if keep_prob > 0.0:
|
14 |
+
random_tensor.div_(keep_prob)
|
15 |
+
output = x * random_tensor
|
16 |
+
return output
|
17 |
+
|
18 |
+
|
19 |
+
class DropPath(nn.Module):
|
20 |
+
def __init__(self, drop_prob=None):
|
21 |
+
super(DropPath, self).__init__()
|
22 |
+
self.drop_prob = drop_prob
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return drop_path(x, self.drop_prob, self.training)
|
unidepth/layers/layer_scale.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class LayerScale(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim: int,
|
9 |
+
init_values: float | torch.Tensor = 1e-5,
|
10 |
+
inplace: bool = False,
|
11 |
+
) -> None:
|
12 |
+
super().__init__()
|
13 |
+
self.inplace = inplace
|
14 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
15 |
+
|
16 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
17 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
unidepth/layers/mlp.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from unidepth.utils.misc import default
|
5 |
+
|
6 |
+
from .activation import SwiGLU
|
7 |
+
|
8 |
+
|
9 |
+
class MLP(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
input_dim: int,
|
13 |
+
expansion: int = 4,
|
14 |
+
dropout: float = 0.0,
|
15 |
+
gated: bool = False,
|
16 |
+
output_dim: int | None = None,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
if gated:
|
20 |
+
expansion = int(expansion * 2 / 3)
|
21 |
+
hidden_dim = int(input_dim * expansion)
|
22 |
+
output_dim = default(output_dim, input_dim)
|
23 |
+
self.norm = nn.LayerNorm(input_dim)
|
24 |
+
self.proj1 = nn.Linear(input_dim, hidden_dim)
|
25 |
+
self.proj2 = nn.Linear(hidden_dim, output_dim)
|
26 |
+
self.act = nn.GELU() if not gated else SwiGLU()
|
27 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
|
28 |
+
|
29 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
30 |
+
x = self.norm(x)
|
31 |
+
x = self.proj1(x)
|
32 |
+
x = self.act(x)
|
33 |
+
x = self.proj2(x)
|
34 |
+
x = self.dropout(x)
|
35 |
+
return x
|
unidepth/layers/nystrom_attention.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
from xformers.components.attention import NystromAttention
|
8 |
+
|
9 |
+
from .attention import AttentionBlock
|
10 |
+
|
11 |
+
|
12 |
+
class NystromBlock(AttentionBlock):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
dim: int,
|
16 |
+
num_heads: int = 4,
|
17 |
+
expansion: int = 4,
|
18 |
+
dropout: float = 0.0,
|
19 |
+
cosine: bool = False,
|
20 |
+
gated: bool = False,
|
21 |
+
layer_scale: float = 1.0,
|
22 |
+
context_dim: int | None = None,
|
23 |
+
):
|
24 |
+
super().__init__(
|
25 |
+
dim=dim,
|
26 |
+
num_heads=num_heads,
|
27 |
+
expansion=expansion,
|
28 |
+
dropout=dropout,
|
29 |
+
cosine=cosine,
|
30 |
+
gated=gated,
|
31 |
+
layer_scale=layer_scale,
|
32 |
+
context_dim=context_dim,
|
33 |
+
)
|
34 |
+
self.attention_fn = NystromAttention(
|
35 |
+
num_landmarks=128, num_heads=num_heads, dropout=dropout
|
36 |
+
)
|
37 |
+
|
38 |
+
def attn(
|
39 |
+
self,
|
40 |
+
x: torch.Tensor,
|
41 |
+
attn_bias: torch.Tensor | None = None,
|
42 |
+
context: torch.Tensor | None = None,
|
43 |
+
pos_embed: torch.Tensor | None = None,
|
44 |
+
pos_embed_context: torch.Tensor | None = None,
|
45 |
+
rope: nn.Module | None = None,
|
46 |
+
) -> torch.Tensor:
|
47 |
+
x = self.norm_attnx(x)
|
48 |
+
context = self.norm_attnctx(context)
|
49 |
+
k, v = rearrange(
|
50 |
+
self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2
|
51 |
+
).unbind(dim=-1)
|
52 |
+
q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
|
53 |
+
|
54 |
+
if rope is not None:
|
55 |
+
q = rope(q)
|
56 |
+
k = rope(k)
|
57 |
+
else:
|
58 |
+
if pos_embed is not None:
|
59 |
+
pos_embed = rearrange(
|
60 |
+
pos_embed, "b n (h d) -> b n h d", h=self.num_heads
|
61 |
+
)
|
62 |
+
q = q + pos_embed
|
63 |
+
if pos_embed_context is not None:
|
64 |
+
pos_embed_context = rearrange(
|
65 |
+
pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
|
66 |
+
)
|
67 |
+
k = k + pos_embed_context
|
68 |
+
|
69 |
+
if self.cosine:
|
70 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
71 |
+
x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
|
72 |
+
x = rearrange(x, "b n h d -> b n (h d)")
|
73 |
+
x = self.out(x)
|
74 |
+
return x
|
unidepth/layers/positional_encoding.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Luigi Piccinelli
|
3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
4 |
+
"""
|
5 |
+
|
6 |
+
from math import pi
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
|
14 |
+
class PositionEmbeddingSine(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.num_pos_feats = num_pos_feats
|
20 |
+
self.temperature = temperature
|
21 |
+
self.normalize = normalize
|
22 |
+
if scale is not None and normalize is False:
|
23 |
+
raise ValueError("normalize should be True if scale is passed")
|
24 |
+
if scale is None:
|
25 |
+
scale = 2 * pi
|
26 |
+
self.scale = scale
|
27 |
+
|
28 |
+
def forward(
|
29 |
+
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
|
30 |
+
) -> torch.Tensor:
|
31 |
+
if mask is None:
|
32 |
+
mask = torch.zeros(
|
33 |
+
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
34 |
+
)
|
35 |
+
not_mask = ~mask
|
36 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
37 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
38 |
+
if self.normalize:
|
39 |
+
eps = 1e-6
|
40 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
41 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
42 |
+
|
43 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
44 |
+
dim_t = self.temperature ** (
|
45 |
+
2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
|
46 |
+
)
|
47 |
+
|
48 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
49 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
50 |
+
pos_x = torch.stack(
|
51 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
52 |
+
).flatten(3)
|
53 |
+
pos_y = torch.stack(
|
54 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
55 |
+
).flatten(3)
|
56 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
57 |
+
return pos
|
58 |
+
|
59 |
+
def __repr__(self, _repr_indent=4):
|
60 |
+
head = "Positional encoding " + self.__class__.__name__
|
61 |
+
body = [
|
62 |
+
"num_pos_feats: {}".format(self.num_pos_feats),
|
63 |
+
"temperature: {}".format(self.temperature),
|
64 |
+
"normalize: {}".format(self.normalize),
|
65 |
+
"scale: {}".format(self.scale),
|
66 |
+
]
|
67 |
+
# _repr_indent = 4
|
68 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
69 |
+
return "\n".join(lines)
|
70 |
+
|
71 |
+
|
72 |
+
class LearnedSinusoidalPosEmb(nn.Module):
|
73 |
+
def __init__(self, dim):
|
74 |
+
super().__init__()
|
75 |
+
assert (dim % 2) == 0
|
76 |
+
half_dim = dim // 2
|
77 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
x = rearrange(x, "b -> b 1")
|
81 |
+
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
|
82 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
83 |
+
fouriered = torch.cat((x, fouriered), dim=-1)
|
84 |
+
return fouriered
|
85 |
+
|
86 |
+
|
87 |
+
def generate_fourier_features(x, max_freq=64, num_bands=16):
|
88 |
+
x = x.unsqueeze(-1)
|
89 |
+
device, dtype, orig_x = x.device, x.dtype, x
|
90 |
+
|
91 |
+
scales = torch.linspace(
|
92 |
+
-max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
|
93 |
+
)
|
94 |
+
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
|
95 |
+
|
96 |
+
x = x * scales * pi
|
97 |
+
x = torch.cat([x.sin(), x.cos()], dim=-1)
|
98 |
+
x = torch.cat((x, orig_x), dim=-1)
|
99 |
+
return x.flatten(-2)
|
100 |
+
|
101 |
+
|
102 |
+
def broadcat(tensors, dim=-1):
|
103 |
+
num_tensors = len(tensors)
|
104 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
105 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
106 |
+
shape_len = list(shape_lens)[0]
|
107 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
108 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
109 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
110 |
+
assert all(
|
111 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
112 |
+
), "invalid dimensions for broadcastable concatentation"
|
113 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
114 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
115 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
116 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
117 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
118 |
+
return torch.cat(tensors, dim=dim)
|
119 |
+
|
120 |
+
|
121 |
+
def rotate_half(x):
|
122 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
123 |
+
x1, x2 = x.unbind(dim=-1)
|
124 |
+
x = torch.stack((-x2, x1), dim=-1)
|
125 |
+
return rearrange(x, "... d r -> ... (d r)")
|
126 |
+
|
127 |
+
|
128 |
+
class VisionRotaryEmbedding(nn.Module):
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
dim,
|
132 |
+
pt_seq_len,
|
133 |
+
ft_seq_len=None,
|
134 |
+
custom_freqs=None,
|
135 |
+
freqs_for="lang",
|
136 |
+
theta=10000,
|
137 |
+
max_freq=10,
|
138 |
+
num_freqs=1,
|
139 |
+
):
|
140 |
+
super().__init__()
|
141 |
+
if custom_freqs:
|
142 |
+
freqs = custom_freqs
|
143 |
+
elif freqs_for == "lang":
|
144 |
+
freqs = 1.0 / (
|
145 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
146 |
+
)
|
147 |
+
elif freqs_for == "pixel":
|
148 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
149 |
+
elif freqs_for == "constant":
|
150 |
+
freqs = torch.ones(num_freqs).float()
|
151 |
+
else:
|
152 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
153 |
+
|
154 |
+
if ft_seq_len is None:
|
155 |
+
ft_seq_len = pt_seq_len
|
156 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
157 |
+
|
158 |
+
freqs_h = torch.einsum("..., f -> ... f", t, freqs)
|
159 |
+
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
160 |
+
|
161 |
+
freqs_w = torch.einsum("..., f -> ... f", t, freqs)
|
162 |
+
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
163 |
+
|
164 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
|
165 |
+
|
166 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
167 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
168 |
+
|
169 |
+
print("======== shape of rope freq", self.freqs_cos.shape, "========")
|
170 |
+
|
171 |
+
def forward(self, t, start_index=0):
|
172 |
+
rot_dim = self.freqs_cos.shape[-1]
|
173 |
+
end_index = start_index + rot_dim
|
174 |
+
assert (
|
175 |
+
rot_dim <= t.shape[-1]
|
176 |
+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
177 |
+
t_left, t, t_right = (
|
178 |
+
t[..., :start_index],
|
179 |
+
t[..., start_index:end_index],
|
180 |
+
t[..., end_index:],
|
181 |
+
)
|
182 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
183 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
184 |
+
|
185 |
+
|
186 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
dim,
|
190 |
+
pt_seq_len,
|
191 |
+
ft_seq_len=None,
|
192 |
+
custom_freqs=None,
|
193 |
+
freqs_for="lang",
|
194 |
+
theta=10000,
|
195 |
+
max_freq=10,
|
196 |
+
num_freqs=1,
|
197 |
+
):
|
198 |
+
super().__init__()
|
199 |
+
if custom_freqs:
|
200 |
+
freqs = custom_freqs
|
201 |
+
elif freqs_for == "lang":
|
202 |
+
freqs = 1.0 / (
|
203 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
204 |
+
)
|
205 |
+
elif freqs_for == "pixel":
|
206 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
207 |
+
elif freqs_for == "constant":
|
208 |
+
freqs = torch.ones(num_freqs).float()
|
209 |
+
else:
|
210 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
211 |
+
|
212 |
+
if ft_seq_len is None:
|
213 |
+
ft_seq_len = pt_seq_len
|
214 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
215 |
+
|
216 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
217 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
218 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
219 |
+
|
220 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
221 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
222 |
+
|
223 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
224 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
225 |
+
|
226 |
+
def forward(self, t):
|
227 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
unidepth/layers/upsample.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Luigi Piccinelli
|
3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
from .convnext import CvnxtBlock
|
11 |
+
|
12 |
+
|
13 |
+
class ConvUpsample(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
hidden_dim,
|
17 |
+
num_layers: int = 2,
|
18 |
+
expansion: int = 4,
|
19 |
+
layer_scale: float = 1.0,
|
20 |
+
kernel_size: int = 7,
|
21 |
+
**kwargs,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
self.convs = nn.ModuleList([])
|
25 |
+
for _ in range(num_layers):
|
26 |
+
self.convs.append(
|
27 |
+
CvnxtBlock(
|
28 |
+
hidden_dim,
|
29 |
+
kernel_size=kernel_size,
|
30 |
+
expansion=expansion,
|
31 |
+
layer_scale=layer_scale,
|
32 |
+
)
|
33 |
+
)
|
34 |
+
self.up = nn.Sequential(
|
35 |
+
nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
|
36 |
+
nn.UpsamplingBilinear2d(scale_factor=2),
|
37 |
+
nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
|
38 |
+
)
|
39 |
+
|
40 |
+
def forward(self, x: torch.Tensor):
|
41 |
+
for conv in self.convs:
|
42 |
+
x = conv(x)
|
43 |
+
x = self.up(x)
|
44 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class ConvUpsampleShuffle(nn.Module):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
hidden_dim,
|
52 |
+
num_layers: int = 2,
|
53 |
+
expansion: int = 4,
|
54 |
+
layer_scale: float = 1.0,
|
55 |
+
kernel_size: int = 7,
|
56 |
+
**kwargs,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
self.convs = nn.ModuleList([])
|
60 |
+
for _ in range(num_layers):
|
61 |
+
self.convs.append(
|
62 |
+
CvnxtBlock(
|
63 |
+
hidden_dim,
|
64 |
+
kernel_size=kernel_size,
|
65 |
+
expansion=expansion,
|
66 |
+
layer_scale=layer_scale,
|
67 |
+
)
|
68 |
+
)
|
69 |
+
self.up = nn.Sequential(
|
70 |
+
nn.PixelShuffle(2),
|
71 |
+
nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
|
72 |
+
)
|
73 |
+
|
74 |
+
def forward(self, x: torch.Tensor):
|
75 |
+
for conv in self.convs:
|
76 |
+
x = conv(x)
|
77 |
+
x = self.up(x)
|
78 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
class ConvUpsampleShuffleResidual(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
hidden_dim,
|
86 |
+
num_layers: int = 2,
|
87 |
+
expansion: int = 4,
|
88 |
+
layer_scale: float = 1.0,
|
89 |
+
kernel_size: int = 7,
|
90 |
+
padding_mode: str = "zeros",
|
91 |
+
**kwargs,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
self.convs = nn.ModuleList([])
|
95 |
+
for _ in range(num_layers):
|
96 |
+
self.convs.append(
|
97 |
+
CvnxtBlock(
|
98 |
+
hidden_dim,
|
99 |
+
kernel_size=kernel_size,
|
100 |
+
expansion=expansion,
|
101 |
+
layer_scale=layer_scale,
|
102 |
+
padding_mode=padding_mode,
|
103 |
+
)
|
104 |
+
)
|
105 |
+
self.up = nn.Sequential(
|
106 |
+
nn.PixelShuffle(2),
|
107 |
+
nn.Conv2d(
|
108 |
+
hidden_dim // 4,
|
109 |
+
hidden_dim // 4,
|
110 |
+
kernel_size=7,
|
111 |
+
padding=3,
|
112 |
+
padding_mode=padding_mode,
|
113 |
+
groups=hidden_dim // 4,
|
114 |
+
),
|
115 |
+
nn.ReLU(),
|
116 |
+
nn.Conv2d(
|
117 |
+
hidden_dim // 4,
|
118 |
+
hidden_dim // 2,
|
119 |
+
kernel_size=3,
|
120 |
+
padding=1,
|
121 |
+
padding_mode=padding_mode,
|
122 |
+
),
|
123 |
+
)
|
124 |
+
self.residual = nn.Sequential(
|
125 |
+
nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
|
126 |
+
nn.UpsamplingBilinear2d(scale_factor=2),
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(self, x: torch.Tensor):
|
130 |
+
for conv in self.convs:
|
131 |
+
x = conv(x)
|
132 |
+
x = self.up(x) + self.residual(x)
|
133 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
134 |
+
return x
|
unidepth/models/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .unidepthv1 import UniDepthV1
|
2 |
+
from .unidepthv2 import UniDepthV2
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"UniDepthV1",
|
6 |
+
"UniDepthV2",
|
7 |
+
]
|
unidepth/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (319 Bytes). View file
|
|
unidepth/models/__pycache__/encoder.cpython-311.pyc
ADDED
Binary file (9.56 kB). View file
|
|
unidepth/models/backbones/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .convnext import ConvNeXt
|
2 |
+
from .convnext2 import ConvNeXtV2
|
3 |
+
from .dinov2 import _make_dinov2_model
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
"ConvNeXt",
|
7 |
+
"ConvNeXtV2",
|
8 |
+
"_make_dinov2_model",
|
9 |
+
]
|
unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (398 Bytes). View file
|
|
unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc
ADDED
Binary file (28.2 kB). View file
|
|
unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc
ADDED
Binary file (17.4 kB). View file
|
|
unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc
ADDED
Binary file (22.4 kB). View file
|
|
unidepth/models/backbones/convnext.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from functools import partial
|
3 |
+
from typing import Callable, Optional, Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp,
|
8 |
+
LayerNorm, LayerNorm2d, Mlp, create_conv2d,
|
9 |
+
get_act_layer, make_divisible, to_ntuple,
|
10 |
+
trunc_normal_)
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
|
14 |
+
def get_num_layer_for_convnext(var_name):
|
15 |
+
"""
|
16 |
+
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
|
17 |
+
consecutive blocks, including possible neighboring downsample layers;
|
18 |
+
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
|
19 |
+
"""
|
20 |
+
if var_name.startswith("downsample_layers"):
|
21 |
+
stage_id = int(var_name.split(".")[1])
|
22 |
+
if stage_id == 0:
|
23 |
+
layer_id = 0
|
24 |
+
elif stage_id == 1 or stage_id == 2:
|
25 |
+
layer_id = stage_id + 1
|
26 |
+
elif stage_id == 3:
|
27 |
+
layer_id = 12
|
28 |
+
|
29 |
+
elif var_name.startswith("stages"):
|
30 |
+
stage_id = int(var_name.split(".")[1])
|
31 |
+
block_id = int(var_name.split(".")[3])
|
32 |
+
if stage_id == 0 or stage_id == 1:
|
33 |
+
layer_id = stage_id + 1
|
34 |
+
elif stage_id == 2:
|
35 |
+
layer_id = 3 + block_id // 3
|
36 |
+
elif stage_id == 3:
|
37 |
+
layer_id = 12
|
38 |
+
|
39 |
+
elif var_name.startswith("stem"):
|
40 |
+
return 0
|
41 |
+
else:
|
42 |
+
layer_id = 12
|
43 |
+
return layer_id + 1
|
44 |
+
|
45 |
+
|
46 |
+
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
|
47 |
+
parameter_group_names = {}
|
48 |
+
parameter_group_vars = {}
|
49 |
+
skip = set()
|
50 |
+
if skip_list is not None:
|
51 |
+
skip = skip_list
|
52 |
+
if hasattr(model, "no_weight_decay"):
|
53 |
+
skip.update(model.no_weight_decay())
|
54 |
+
num_layers = 12
|
55 |
+
layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
|
56 |
+
for name, param in model.named_parameters():
|
57 |
+
if not param.requires_grad:
|
58 |
+
continue # frozen weights
|
59 |
+
if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
|
60 |
+
group_name = "no_decay"
|
61 |
+
this_wd = 0.0
|
62 |
+
else:
|
63 |
+
group_name = "decay"
|
64 |
+
this_wd = wd
|
65 |
+
|
66 |
+
layer_id = get_num_layer_for_convnext(name)
|
67 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
68 |
+
|
69 |
+
if group_name not in parameter_group_names:
|
70 |
+
scale = layer_scale[layer_id]
|
71 |
+
cur_lr = lr * scale
|
72 |
+
|
73 |
+
parameter_group_names[group_name] = {
|
74 |
+
"weight_decay": this_wd,
|
75 |
+
"weight_decay_init": this_wd,
|
76 |
+
"weight_decay_base": this_wd,
|
77 |
+
"params": [],
|
78 |
+
"lr_init": cur_lr,
|
79 |
+
"lr_base": lr,
|
80 |
+
"lr": cur_lr,
|
81 |
+
}
|
82 |
+
parameter_group_vars[group_name] = {
|
83 |
+
"weight_decay": this_wd,
|
84 |
+
"weight_decay_init": this_wd,
|
85 |
+
"weight_decay_base": this_wd,
|
86 |
+
"params": [],
|
87 |
+
"lr_init": cur_lr,
|
88 |
+
"lr_base": lr,
|
89 |
+
"lr": cur_lr,
|
90 |
+
}
|
91 |
+
if this_wd == 0.0:
|
92 |
+
parameter_group_names[group_name]["weight_decay_final"] = 0.0
|
93 |
+
parameter_group_vars[group_name]["weight_decay_final"] = 0.0
|
94 |
+
parameter_group_vars[group_name]["params"].append(param)
|
95 |
+
parameter_group_names[group_name]["params"].append(name)
|
96 |
+
# from unidepth.utils import is_main_process
|
97 |
+
# import json
|
98 |
+
# if is_main_process():
|
99 |
+
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
100 |
+
return list(parameter_group_vars.values()), [
|
101 |
+
v["lr"] for k, v in parameter_group_vars.items()
|
102 |
+
]
|
103 |
+
|
104 |
+
|
105 |
+
class Downsample(nn.Module):
|
106 |
+
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
|
107 |
+
super().__init__()
|
108 |
+
avg_stride = stride if dilation == 1 else 1
|
109 |
+
if stride > 1 or dilation > 1:
|
110 |
+
avg_pool_fn = (
|
111 |
+
AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
112 |
+
)
|
113 |
+
self.pool = avg_pool_fn(
|
114 |
+
2, avg_stride, ceil_mode=True, count_include_pad=False
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
self.pool = nn.Identity()
|
118 |
+
|
119 |
+
if in_chs != out_chs:
|
120 |
+
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
|
121 |
+
else:
|
122 |
+
self.conv = nn.Identity()
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
x = self.pool(x)
|
126 |
+
x = self.conv(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
class ConvNeXtBlock(nn.Module):
|
131 |
+
"""ConvNeXt Block
|
132 |
+
There are two equivalent implementations:
|
133 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
134 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
135 |
+
|
136 |
+
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
137 |
+
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
138 |
+
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
in_chs: int,
|
144 |
+
out_chs: Optional[int] = None,
|
145 |
+
kernel_size: int = 7,
|
146 |
+
stride: int = 1,
|
147 |
+
dilation: Union[int, Tuple[int, int]] = (1, 1),
|
148 |
+
mlp_ratio: float = 4,
|
149 |
+
conv_mlp: bool = False,
|
150 |
+
conv_bias: bool = True,
|
151 |
+
use_grn: bool = False,
|
152 |
+
ls_init_value: Optional[float] = 1e-6,
|
153 |
+
act_layer: Union[str, Callable] = "gelu",
|
154 |
+
norm_layer: Optional[Callable] = None,
|
155 |
+
drop_path: float = 0.0,
|
156 |
+
):
|
157 |
+
"""
|
158 |
+
|
159 |
+
Args:
|
160 |
+
in_chs: Block input channels.
|
161 |
+
out_chs: Block output channels (same as in_chs if None).
|
162 |
+
kernel_size: Depthwise convolution kernel size.
|
163 |
+
stride: Stride of depthwise convolution.
|
164 |
+
dilation: Tuple specifying input and output dilation of block.
|
165 |
+
mlp_ratio: MLP expansion ratio.
|
166 |
+
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
|
167 |
+
conv_bias: Apply bias for all convolution (linear) layers.
|
168 |
+
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
|
169 |
+
ls_init_value: Layer-scale init values, layer-scale applied if not None.
|
170 |
+
act_layer: Activation layer.
|
171 |
+
norm_layer: Normalization layer (defaults to LN if not specified).
|
172 |
+
drop_path: Stochastic depth probability.
|
173 |
+
"""
|
174 |
+
super().__init__()
|
175 |
+
out_chs = out_chs or in_chs
|
176 |
+
dilation = to_ntuple(2)(dilation)
|
177 |
+
act_layer = get_act_layer(act_layer)
|
178 |
+
if not norm_layer:
|
179 |
+
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
180 |
+
mlp_layer = partial(
|
181 |
+
GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
|
182 |
+
)
|
183 |
+
self.use_conv_mlp = conv_mlp
|
184 |
+
self.conv_dw = create_conv2d(
|
185 |
+
in_chs,
|
186 |
+
out_chs,
|
187 |
+
kernel_size=kernel_size,
|
188 |
+
stride=stride,
|
189 |
+
dilation=dilation[0],
|
190 |
+
depthwise=True,
|
191 |
+
bias=conv_bias,
|
192 |
+
)
|
193 |
+
self.norm = norm_layer(out_chs)
|
194 |
+
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
195 |
+
self.gamma = (
|
196 |
+
nn.Parameter(ls_init_value * torch.ones(out_chs))
|
197 |
+
if ls_init_value is not None
|
198 |
+
else None
|
199 |
+
)
|
200 |
+
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
201 |
+
self.shortcut = Downsample(
|
202 |
+
in_chs, out_chs, stride=stride, dilation=dilation[0]
|
203 |
+
)
|
204 |
+
else:
|
205 |
+
self.shortcut = nn.Identity()
|
206 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
shortcut = x
|
210 |
+
x = self.conv_dw(x.contiguous())
|
211 |
+
if self.use_conv_mlp:
|
212 |
+
x = self.norm(x)
|
213 |
+
x = self.mlp(x)
|
214 |
+
else:
|
215 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
216 |
+
x = self.norm(x)
|
217 |
+
x = self.mlp(x)
|
218 |
+
x = x.permute(0, 3, 1, 2).contiguous()
|
219 |
+
if self.gamma is not None:
|
220 |
+
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
221 |
+
|
222 |
+
x = self.drop_path(x) + self.shortcut(shortcut)
|
223 |
+
return x.contiguous()
|
224 |
+
|
225 |
+
|
226 |
+
class ConvNeXtStage(nn.Module):
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
in_chs,
|
230 |
+
out_chs,
|
231 |
+
kernel_size=7,
|
232 |
+
stride=2,
|
233 |
+
depth=2,
|
234 |
+
dilation=(1, 1),
|
235 |
+
drop_path_rates=None,
|
236 |
+
ls_init_value=1.0,
|
237 |
+
conv_mlp=False,
|
238 |
+
conv_bias=True,
|
239 |
+
use_grn=False,
|
240 |
+
act_layer="gelu",
|
241 |
+
norm_layer=None,
|
242 |
+
norm_layer_cl=None,
|
243 |
+
):
|
244 |
+
super().__init__()
|
245 |
+
self.grad_checkpointing = False
|
246 |
+
|
247 |
+
if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
|
248 |
+
ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
|
249 |
+
pad = (
|
250 |
+
"same" if dilation[1] > 1 else 0
|
251 |
+
) # same padding needed if dilation used
|
252 |
+
self.downsample = nn.Sequential(
|
253 |
+
norm_layer(in_chs),
|
254 |
+
create_conv2d(
|
255 |
+
in_chs,
|
256 |
+
out_chs,
|
257 |
+
kernel_size=ds_ks,
|
258 |
+
stride=stride,
|
259 |
+
dilation=dilation[0],
|
260 |
+
padding=pad,
|
261 |
+
bias=conv_bias,
|
262 |
+
),
|
263 |
+
)
|
264 |
+
in_chs = out_chs
|
265 |
+
else:
|
266 |
+
self.downsample = nn.Identity()
|
267 |
+
|
268 |
+
drop_path_rates = drop_path_rates or [0.0] * depth
|
269 |
+
stage_blocks = []
|
270 |
+
for i in range(depth):
|
271 |
+
stage_blocks.append(
|
272 |
+
ConvNeXtBlock(
|
273 |
+
in_chs=in_chs,
|
274 |
+
out_chs=out_chs,
|
275 |
+
kernel_size=kernel_size,
|
276 |
+
dilation=dilation[1],
|
277 |
+
drop_path=drop_path_rates[i],
|
278 |
+
ls_init_value=ls_init_value,
|
279 |
+
conv_mlp=conv_mlp,
|
280 |
+
conv_bias=conv_bias,
|
281 |
+
use_grn=use_grn,
|
282 |
+
act_layer=act_layer,
|
283 |
+
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
|
284 |
+
)
|
285 |
+
)
|
286 |
+
in_chs = out_chs
|
287 |
+
self.blocks = nn.ModuleList(stage_blocks)
|
288 |
+
|
289 |
+
def forward(self, x):
|
290 |
+
xs = []
|
291 |
+
x = self.downsample(x)
|
292 |
+
for block in self.blocks:
|
293 |
+
if self.grad_checkpointing:
|
294 |
+
x = checkpoint(block, x)
|
295 |
+
else:
|
296 |
+
x = block(x)
|
297 |
+
xs.append(x)
|
298 |
+
return xs
|
299 |
+
|
300 |
+
|
301 |
+
class ConvNeXt(nn.Module):
|
302 |
+
def __init__(
|
303 |
+
self,
|
304 |
+
in_chans: int = 3,
|
305 |
+
output_stride: int = 32,
|
306 |
+
depths: Tuple[int, ...] = (3, 3, 9, 3),
|
307 |
+
dims: Tuple[int, ...] = (96, 192, 384, 768),
|
308 |
+
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
|
309 |
+
ls_init_value: Optional[float] = 1e-6,
|
310 |
+
stem_type: str = "patch",
|
311 |
+
patch_size: int = 4,
|
312 |
+
conv_mlp: bool = False,
|
313 |
+
conv_bias: bool = True,
|
314 |
+
use_grn: bool = False,
|
315 |
+
act_layer: Union[str, Callable] = "gelu",
|
316 |
+
norm_layer: Optional[Union[str, Callable]] = None,
|
317 |
+
norm_eps: Optional[float] = None,
|
318 |
+
drop_path_rate: float = 0.0,
|
319 |
+
output_idx=[],
|
320 |
+
use_checkpoint=False,
|
321 |
+
):
|
322 |
+
"""
|
323 |
+
Args:
|
324 |
+
in_chans: Number of input image channels.
|
325 |
+
num_classes: Number of classes for classification head.
|
326 |
+
global_pool: Global pooling type.
|
327 |
+
output_stride: Output stride of network, one of (8, 16, 32).
|
328 |
+
depths: Number of blocks at each stage.
|
329 |
+
dims: Feature dimension at each stage.
|
330 |
+
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
|
331 |
+
ls_init_value: Init value for Layer Scale, disabled if None.
|
332 |
+
stem_type: Type of stem.
|
333 |
+
patch_size: Stem patch size for patch stem.
|
334 |
+
head_init_scale: Init scaling value for classifier weights and biases.
|
335 |
+
head_norm_first: Apply normalization before global pool + head.
|
336 |
+
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
|
337 |
+
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
|
338 |
+
conv_bias: Use bias layers w/ all convolutions.
|
339 |
+
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
|
340 |
+
act_layer: Activation layer type.
|
341 |
+
norm_layer: Normalization layer type.
|
342 |
+
drop_rate: Head pre-classifier dropout rate.
|
343 |
+
drop_path_rate: Stochastic depth drop rate.
|
344 |
+
"""
|
345 |
+
super().__init__()
|
346 |
+
self.num_layers = len(depths)
|
347 |
+
self.depths = output_idx
|
348 |
+
self.embed_dims = [
|
349 |
+
int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
|
350 |
+
]
|
351 |
+
self.embed_dim = dims[0]
|
352 |
+
|
353 |
+
assert output_stride in (8, 16, 32)
|
354 |
+
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
355 |
+
if norm_layer is None:
|
356 |
+
norm_layer = LayerNorm2d
|
357 |
+
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
358 |
+
if norm_eps is not None:
|
359 |
+
norm_layer = partial(norm_layer, eps=norm_eps)
|
360 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
361 |
+
else:
|
362 |
+
assert (
|
363 |
+
conv_mlp
|
364 |
+
), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
|
365 |
+
norm_layer_cl = norm_layer
|
366 |
+
if norm_eps is not None:
|
367 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
368 |
+
|
369 |
+
self.feature_info = []
|
370 |
+
|
371 |
+
assert stem_type in ("patch", "overlap", "overlap_tiered")
|
372 |
+
if stem_type == "patch":
|
373 |
+
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
374 |
+
self.stem = nn.Sequential(
|
375 |
+
nn.Conv2d(
|
376 |
+
in_chans,
|
377 |
+
dims[0],
|
378 |
+
kernel_size=patch_size,
|
379 |
+
stride=patch_size,
|
380 |
+
bias=conv_bias,
|
381 |
+
),
|
382 |
+
norm_layer(dims[0]),
|
383 |
+
)
|
384 |
+
stem_stride = patch_size
|
385 |
+
else:
|
386 |
+
mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
|
387 |
+
self.stem = nn.Sequential(
|
388 |
+
nn.Conv2d(
|
389 |
+
in_chans,
|
390 |
+
mid_chs,
|
391 |
+
kernel_size=3,
|
392 |
+
stride=2,
|
393 |
+
padding=1,
|
394 |
+
bias=conv_bias,
|
395 |
+
),
|
396 |
+
nn.Conv2d(
|
397 |
+
mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
|
398 |
+
),
|
399 |
+
norm_layer(dims[0]),
|
400 |
+
)
|
401 |
+
stem_stride = 4
|
402 |
+
|
403 |
+
self.stages = nn.Sequential()
|
404 |
+
dp_rates = [
|
405 |
+
x.tolist()
|
406 |
+
for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
|
407 |
+
]
|
408 |
+
stages = []
|
409 |
+
prev_chs = dims[0]
|
410 |
+
curr_stride = stem_stride
|
411 |
+
dilation = 1
|
412 |
+
# 4 feature resolution stages, each consisting of multiple residual blocks
|
413 |
+
for i in range(4):
|
414 |
+
stride = 2 if curr_stride == 2 or i > 0 else 1
|
415 |
+
if curr_stride >= output_stride and stride > 1:
|
416 |
+
dilation *= stride
|
417 |
+
stride = 1
|
418 |
+
curr_stride *= stride
|
419 |
+
first_dilation = 1 if dilation in (1, 2) else 2
|
420 |
+
out_chs = dims[i]
|
421 |
+
stages.append(
|
422 |
+
ConvNeXtStage(
|
423 |
+
prev_chs,
|
424 |
+
out_chs,
|
425 |
+
kernel_size=kernel_sizes[i],
|
426 |
+
stride=stride,
|
427 |
+
dilation=(first_dilation, dilation),
|
428 |
+
depth=depths[i],
|
429 |
+
drop_path_rates=dp_rates[i],
|
430 |
+
ls_init_value=ls_init_value,
|
431 |
+
conv_mlp=conv_mlp,
|
432 |
+
conv_bias=conv_bias,
|
433 |
+
use_grn=use_grn,
|
434 |
+
act_layer=act_layer,
|
435 |
+
norm_layer=norm_layer,
|
436 |
+
norm_layer_cl=norm_layer_cl,
|
437 |
+
)
|
438 |
+
)
|
439 |
+
prev_chs = out_chs
|
440 |
+
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
|
441 |
+
self.feature_info += [
|
442 |
+
dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
|
443 |
+
]
|
444 |
+
self.stages = nn.ModuleList(stages)
|
445 |
+
self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
|
446 |
+
self.num_features = prev_chs
|
447 |
+
self.apply(self._init_weights)
|
448 |
+
self.set_grad_checkpointing(use_checkpoint)
|
449 |
+
|
450 |
+
def _init_weights(self, module):
|
451 |
+
if isinstance(module, nn.Conv2d):
|
452 |
+
trunc_normal_(module.weight, std=0.02)
|
453 |
+
if module.bias is not None:
|
454 |
+
nn.init.zeros_(module.bias)
|
455 |
+
elif isinstance(module, nn.Linear):
|
456 |
+
trunc_normal_(module.weight, std=0.02)
|
457 |
+
nn.init.zeros_(module.bias)
|
458 |
+
|
459 |
+
def forward(self, x, masks=None):
|
460 |
+
outs = []
|
461 |
+
x = self.stem(x)
|
462 |
+
if masks is not None:
|
463 |
+
masks = torch.nn.functional.interpolate(
|
464 |
+
masks.float(), size=x.shape[-2:], mode="nearest"
|
465 |
+
)
|
466 |
+
x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
|
467 |
+
for stage in self.stages:
|
468 |
+
xs = stage(x)
|
469 |
+
outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
|
470 |
+
x = xs[-1]
|
471 |
+
return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
|
472 |
+
|
473 |
+
@torch.jit.ignore
|
474 |
+
def group_matcher(self, coarse=False):
|
475 |
+
return dict(
|
476 |
+
stem=r"^stem",
|
477 |
+
blocks=(
|
478 |
+
r"^stages\.(\d+)"
|
479 |
+
if coarse
|
480 |
+
else [
|
481 |
+
(r"^stages\.(\d+)\.downsample", (0,)), # blocks
|
482 |
+
(r"^stages\.(\d+)\.blocks\.(\d+)", None),
|
483 |
+
(r"^norm_pre", (99999,)),
|
484 |
+
]
|
485 |
+
),
|
486 |
+
)
|
487 |
+
|
488 |
+
@torch.jit.ignore
|
489 |
+
def set_grad_checkpointing(self, enable=True):
|
490 |
+
for s in self.stages:
|
491 |
+
s.grad_checkpointing = enable
|
492 |
+
|
493 |
+
def freeze(self) -> None:
|
494 |
+
for module in self.modules():
|
495 |
+
module.eval()
|
496 |
+
for parameters in self.parameters():
|
497 |
+
parameters.requires_grad = False
|
498 |
+
|
499 |
+
def get_params(self, lr, wd, ld, *args, **kwargs):
|
500 |
+
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
|
501 |
+
return encoder_p, encoder_lr
|
502 |
+
|
503 |
+
def no_weight_decay(self):
|
504 |
+
return {"mask_token"}
|
505 |
+
|
506 |
+
@classmethod
|
507 |
+
def build(cls, config):
|
508 |
+
obj = globals()[config["model"]["encoder"]["name"]](config)
|
509 |
+
return obj
|
510 |
+
|
511 |
+
|
512 |
+
def checkpoint_filter_fn(state_dict, model):
|
513 |
+
"""Remap FB checkpoints -> timm"""
|
514 |
+
if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
|
515 |
+
return state_dict # non-FB checkpoint
|
516 |
+
if "model" in state_dict:
|
517 |
+
state_dict = state_dict["model"]
|
518 |
+
|
519 |
+
out_dict = {}
|
520 |
+
if "visual.trunk.stem.0.weight" in state_dict:
|
521 |
+
out_dict = {
|
522 |
+
k.replace("visual.trunk.", ""): v
|
523 |
+
for k, v in state_dict.items()
|
524 |
+
if k.startswith("visual.trunk.")
|
525 |
+
}
|
526 |
+
if "visual.head.proj.weight" in state_dict:
|
527 |
+
out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
|
528 |
+
out_dict["head.fc.bias"] = torch.zeros(
|
529 |
+
state_dict["visual.head.proj.weight"].shape[0]
|
530 |
+
)
|
531 |
+
elif "visual.head.mlp.fc1.weight" in state_dict:
|
532 |
+
out_dict["head.pre_logits.fc.weight"] = state_dict[
|
533 |
+
"visual.head.mlp.fc1.weight"
|
534 |
+
]
|
535 |
+
out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
|
536 |
+
out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
|
537 |
+
out_dict["head.fc.bias"] = torch.zeros(
|
538 |
+
state_dict["visual.head.mlp.fc2.weight"].shape[0]
|
539 |
+
)
|
540 |
+
return out_dict
|
541 |
+
|
542 |
+
import re
|
543 |
+
|
544 |
+
for k, v in state_dict.items():
|
545 |
+
k = k.replace("downsample_layers.0.", "stem.")
|
546 |
+
k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
|
547 |
+
k = re.sub(
|
548 |
+
r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
|
549 |
+
)
|
550 |
+
k = k.replace("dwconv", "conv_dw")
|
551 |
+
k = k.replace("pwconv", "mlp.fc")
|
552 |
+
if "grn" in k:
|
553 |
+
k = k.replace("grn.beta", "mlp.grn.bias")
|
554 |
+
k = k.replace("grn.gamma", "mlp.grn.weight")
|
555 |
+
v = v.reshape(v.shape[-1])
|
556 |
+
k = k.replace("head.", "head.fc.")
|
557 |
+
if k.startswith("norm."):
|
558 |
+
k = k.replace("norm", "head.norm")
|
559 |
+
if v.ndim == 2 and "head" not in k:
|
560 |
+
model_shape = model.state_dict()[k].shape
|
561 |
+
v = v.reshape(model_shape)
|
562 |
+
out_dict[k] = v
|
563 |
+
|
564 |
+
return out_dict
|
565 |
+
|
566 |
+
|
567 |
+
HF_URL = {
|
568 |
+
"convnext_xxlarge_pt": (
|
569 |
+
"laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
|
570 |
+
"open_clip_pytorch_model.bin",
|
571 |
+
),
|
572 |
+
"convnext_large_pt": (
|
573 |
+
"laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
|
574 |
+
"open_clip_pytorch_model.bin",
|
575 |
+
),
|
576 |
+
"convnext_large": (
|
577 |
+
"timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
|
578 |
+
"pytorch_model.bin",
|
579 |
+
),
|
580 |
+
}
|
unidepth/models/backbones/convnext2.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from timm.models.layers import DropPath, trunc_normal_
|
5 |
+
|
6 |
+
|
7 |
+
def get_num_layer_for_convnext_single(var_name, depths):
|
8 |
+
"""
|
9 |
+
Each layer is assigned distinctive layer ids
|
10 |
+
"""
|
11 |
+
if var_name.startswith("downsample_layers"):
|
12 |
+
stage_id = int(var_name.split(".")[1])
|
13 |
+
layer_id = sum(depths[:stage_id]) + 1
|
14 |
+
return layer_id
|
15 |
+
|
16 |
+
elif var_name.startswith("stages"):
|
17 |
+
stage_id = int(var_name.split(".")[1])
|
18 |
+
block_id = int(var_name.split(".")[2])
|
19 |
+
layer_id = sum(depths[:stage_id]) + block_id + 1
|
20 |
+
return layer_id
|
21 |
+
|
22 |
+
else:
|
23 |
+
return sum(depths) + 1
|
24 |
+
|
25 |
+
|
26 |
+
def get_num_layer_for_convnext(var_name):
|
27 |
+
"""
|
28 |
+
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
|
29 |
+
consecutive blocks, including possible neighboring downsample layers;
|
30 |
+
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
|
31 |
+
"""
|
32 |
+
num_max_layer = 12
|
33 |
+
if var_name.startswith("downsample_layers"):
|
34 |
+
stage_id = int(var_name.split(".")[1])
|
35 |
+
if stage_id == 0:
|
36 |
+
layer_id = 0
|
37 |
+
elif stage_id == 1 or stage_id == 2:
|
38 |
+
layer_id = stage_id + 1
|
39 |
+
elif stage_id == 3:
|
40 |
+
layer_id = 12
|
41 |
+
return layer_id
|
42 |
+
|
43 |
+
elif var_name.startswith("stages"):
|
44 |
+
stage_id = int(var_name.split(".")[1])
|
45 |
+
block_id = int(var_name.split(".")[2])
|
46 |
+
if stage_id == 0 or stage_id == 1:
|
47 |
+
layer_id = stage_id + 1
|
48 |
+
elif stage_id == 2:
|
49 |
+
layer_id = 3 + block_id // 3
|
50 |
+
elif stage_id == 3:
|
51 |
+
layer_id = 12
|
52 |
+
return layer_id
|
53 |
+
else:
|
54 |
+
return num_max_layer + 1
|
55 |
+
|
56 |
+
|
57 |
+
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
|
58 |
+
parameter_group_names = {}
|
59 |
+
parameter_group_vars = {}
|
60 |
+
skip = {}
|
61 |
+
if skip_list is not None:
|
62 |
+
skip = skip_list
|
63 |
+
elif hasattr(model, "no_weight_decay"):
|
64 |
+
skip = model.no_weight_decay()
|
65 |
+
num_layers = 12 # sum(model.depths)
|
66 |
+
layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
|
67 |
+
for name, param in model.named_parameters():
|
68 |
+
if not param.requires_grad:
|
69 |
+
continue # frozen weights
|
70 |
+
if (
|
71 |
+
len(param.shape) == 1
|
72 |
+
or name.endswith(".bias")
|
73 |
+
or name in skip
|
74 |
+
or name.endswith(".gamma")
|
75 |
+
or name.endswith(".beta")
|
76 |
+
):
|
77 |
+
group_name = "no_decay"
|
78 |
+
this_weight_decay = 0.0
|
79 |
+
else:
|
80 |
+
group_name = "decay"
|
81 |
+
this_weight_decay = wd
|
82 |
+
|
83 |
+
# layer_id = get_num_layer_for_convnext_single(name, model.depths)
|
84 |
+
layer_id = get_num_layer_for_convnext(name)
|
85 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
86 |
+
|
87 |
+
if group_name not in parameter_group_names:
|
88 |
+
scale = layer_scale[layer_id]
|
89 |
+
cur_lr = lr * scale
|
90 |
+
|
91 |
+
parameter_group_names[group_name] = {
|
92 |
+
"weight_decay": this_weight_decay,
|
93 |
+
"params": [],
|
94 |
+
"lr_scale": scale,
|
95 |
+
"lr": cur_lr,
|
96 |
+
}
|
97 |
+
parameter_group_vars[group_name] = {
|
98 |
+
"weight_decay": this_weight_decay,
|
99 |
+
"params": [],
|
100 |
+
"lr_scale": scale,
|
101 |
+
"lr": cur_lr,
|
102 |
+
}
|
103 |
+
parameter_group_vars[group_name]["params"].append(param)
|
104 |
+
parameter_group_names[group_name]["params"].append(name)
|
105 |
+
# if is_main_process():
|
106 |
+
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
107 |
+
return list(parameter_group_vars.values()), [
|
108 |
+
v["lr"] for k, v in parameter_group_vars.items()
|
109 |
+
]
|
110 |
+
|
111 |
+
|
112 |
+
class LayerNorm(nn.Module):
|
113 |
+
"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
114 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
115 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
116 |
+
with shape (batch_size, channels, height, width).
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
120 |
+
super().__init__()
|
121 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
122 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
123 |
+
self.eps = eps
|
124 |
+
self.data_format = data_format
|
125 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
126 |
+
raise NotImplementedError
|
127 |
+
self.normalized_shape = (normalized_shape,)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
if self.data_format == "channels_last":
|
131 |
+
return F.layer_norm(
|
132 |
+
x, self.normalized_shape, self.weight, self.bias, self.eps
|
133 |
+
)
|
134 |
+
elif self.data_format == "channels_first":
|
135 |
+
u = x.mean(1, keepdim=True)
|
136 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
137 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
138 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
class GRN(nn.Module):
|
143 |
+
"""GRN (Global Response Normalization) layer"""
|
144 |
+
|
145 |
+
def __init__(self, dim):
|
146 |
+
super().__init__()
|
147 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
148 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
152 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
153 |
+
return self.gamma * (x * Nx) + self.beta + x
|
154 |
+
|
155 |
+
|
156 |
+
class Block(nn.Module):
|
157 |
+
"""ConvNeXtV2 Block.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
dim (int): Number of input channels.
|
161 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
162 |
+
"""
|
163 |
+
|
164 |
+
def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
|
165 |
+
super().__init__()
|
166 |
+
self.dwconv = nn.Conv2d(
|
167 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
168 |
+
) # depthwise conv
|
169 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
170 |
+
self.pwconv1 = nn.Linear(
|
171 |
+
dim, mult * dim
|
172 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
173 |
+
self.act = nn.GELU()
|
174 |
+
self.grn = GRN(mult * dim)
|
175 |
+
self.pwconv2 = nn.Linear(mult * dim, dim)
|
176 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
177 |
+
self.use_checkpoint = use_checkpoint
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
input = x
|
181 |
+
x = self.dwconv(x)
|
182 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
183 |
+
x = self.norm(x)
|
184 |
+
x = self.pwconv1(x)
|
185 |
+
x = self.act(x)
|
186 |
+
x = self.grn(x)
|
187 |
+
x = self.pwconv2(x)
|
188 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
189 |
+
|
190 |
+
x = input + self.drop_path(x)
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
class ConvNeXtV2(nn.Module):
|
195 |
+
"""ConvNeXt V2
|
196 |
+
|
197 |
+
Args:
|
198 |
+
in_chans (int): Number of input image channels. Default: 3
|
199 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
200 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
201 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
202 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
203 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
204 |
+
"""
|
205 |
+
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
in_chans=3,
|
209 |
+
depths=[3, 3, 9, 3],
|
210 |
+
dims=96,
|
211 |
+
drop_path_rate=0.0,
|
212 |
+
output_idx=[],
|
213 |
+
use_checkpoint=False,
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
self.num_layers = len(depths)
|
217 |
+
self.depths = output_idx
|
218 |
+
self.embed_dims = [
|
219 |
+
int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
|
220 |
+
]
|
221 |
+
self.embed_dim = dims[0]
|
222 |
+
|
223 |
+
self.downsample_layers = (
|
224 |
+
nn.ModuleList()
|
225 |
+
) # stem and 3 intermediate downsampling conv layers
|
226 |
+
stem = nn.Sequential(
|
227 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
228 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
229 |
+
)
|
230 |
+
self.downsample_layers.append(stem)
|
231 |
+
for i in range(3):
|
232 |
+
downsample_layer = nn.Sequential(
|
233 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
234 |
+
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
|
235 |
+
)
|
236 |
+
self.downsample_layers.append(downsample_layer)
|
237 |
+
|
238 |
+
self.stages = (
|
239 |
+
nn.ModuleList()
|
240 |
+
) # 4 feature resolution stages, each consisting of multiple residual blocks
|
241 |
+
self.out_norms = nn.ModuleList()
|
242 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
243 |
+
cur = 0
|
244 |
+
for i in range(4):
|
245 |
+
stage = nn.ModuleList(
|
246 |
+
[
|
247 |
+
Block(
|
248 |
+
dim=dims[i],
|
249 |
+
drop_path=dp_rates[cur + j],
|
250 |
+
use_checkpoint=use_checkpoint,
|
251 |
+
)
|
252 |
+
for j in range(depths[i])
|
253 |
+
]
|
254 |
+
)
|
255 |
+
self.stages.append(stage)
|
256 |
+
cur += depths[i]
|
257 |
+
|
258 |
+
self.apply(self._init_weights)
|
259 |
+
|
260 |
+
def _init_weights(self, m):
|
261 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
262 |
+
trunc_normal_(m.weight, std=0.02)
|
263 |
+
nn.init.constant_(m.bias, 0)
|
264 |
+
|
265 |
+
def forward(self, x):
|
266 |
+
outs = []
|
267 |
+
for i in range(4):
|
268 |
+
x = self.downsample_layers[i](x)
|
269 |
+
for stage in self.stages[i]:
|
270 |
+
x = stage(x)
|
271 |
+
outs.append(x.permute(0, 2, 3, 1))
|
272 |
+
cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
|
273 |
+
return outs, cls_tokens
|
274 |
+
|
275 |
+
def get_params(self, lr, wd, ld, *args, **kwargs):
|
276 |
+
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
|
277 |
+
return encoder_p, encoder_lr
|
278 |
+
|
279 |
+
def freeze(self) -> None:
|
280 |
+
for module in self.modules():
|
281 |
+
module.eval()
|
282 |
+
for parameters in self.parameters():
|
283 |
+
parameters.requires_grad = False
|
284 |
+
|
285 |
+
@classmethod
|
286 |
+
def build(cls, config):
|
287 |
+
obj = globals()[config["model"]["encoder"]["name"]](config)
|
288 |
+
return obj
|
unidepth/models/backbones/dinov2.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from functools import partial
|
4 |
+
from typing import Callable, Sequence
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn.init import trunc_normal_
|
9 |
+
|
10 |
+
from .metadinov2 import Attention, MemEffAttention, Mlp
|
11 |
+
from .metadinov2 import NestedTensorBlock as Block
|
12 |
+
from .metadinov2 import PatchEmbed, SwiGLUFFNFused
|
13 |
+
|
14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
15 |
+
logger = logging.getLogger("dinov2")
|
16 |
+
|
17 |
+
|
18 |
+
def named_apply(
|
19 |
+
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
|
20 |
+
) -> nn.Module:
|
21 |
+
if not depth_first and include_root:
|
22 |
+
fn(module=module, name=name)
|
23 |
+
for child_name, child_module in module.named_children():
|
24 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
25 |
+
named_apply(
|
26 |
+
fn=fn,
|
27 |
+
module=child_module,
|
28 |
+
name=child_name,
|
29 |
+
depth_first=depth_first,
|
30 |
+
include_root=True,
|
31 |
+
)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
|
38 |
+
parameter_group_names = {}
|
39 |
+
parameter_group_vars = {}
|
40 |
+
skip = {}
|
41 |
+
if skip_list is not None:
|
42 |
+
skip = skip_list
|
43 |
+
elif hasattr(model, "no_weight_decay"):
|
44 |
+
skip = model.no_weight_decay()
|
45 |
+
|
46 |
+
num_layers = model.n_blocks
|
47 |
+
layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
|
48 |
+
|
49 |
+
for name, param in model.named_parameters():
|
50 |
+
if not param.requires_grad:
|
51 |
+
continue
|
52 |
+
|
53 |
+
if len(param.shape) == 1: # norm
|
54 |
+
group_name = "no_decay"
|
55 |
+
this_wd = 0.0
|
56 |
+
# layer scale, bias beta?
|
57 |
+
elif (
|
58 |
+
name in skip
|
59 |
+
or name.endswith(".gamma")
|
60 |
+
or name.endswith(".beta")
|
61 |
+
or name.endswith(".bias")
|
62 |
+
):
|
63 |
+
group_name = "no_decay"
|
64 |
+
this_wd = 0.0
|
65 |
+
elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
|
66 |
+
group_name = "no_decay"
|
67 |
+
this_wd = 0.0
|
68 |
+
else:
|
69 |
+
group_name = "decay"
|
70 |
+
this_wd = wd
|
71 |
+
|
72 |
+
if name.startswith("blocks"):
|
73 |
+
layer_id = int(name.split(".")[1])
|
74 |
+
elif name.startswith("patch_embed"):
|
75 |
+
layer_id = 0
|
76 |
+
else:
|
77 |
+
layer_id = 0
|
78 |
+
|
79 |
+
group_name = f"layer_{layer_id}_{group_name}"
|
80 |
+
|
81 |
+
if group_name not in parameter_group_names:
|
82 |
+
scale = layer_scale[layer_id]
|
83 |
+
cur_lr = lr * scale
|
84 |
+
|
85 |
+
parameter_group_names[group_name] = {
|
86 |
+
"weight_decay": this_wd,
|
87 |
+
"params": [],
|
88 |
+
"lr_init": cur_lr,
|
89 |
+
"lr_base": lr,
|
90 |
+
"lr": cur_lr,
|
91 |
+
}
|
92 |
+
parameter_group_vars[group_name] = {
|
93 |
+
"weight_decay": this_wd,
|
94 |
+
"params": [],
|
95 |
+
"lr_init": cur_lr,
|
96 |
+
"lr_base": lr,
|
97 |
+
"lr": cur_lr,
|
98 |
+
}
|
99 |
+
parameter_group_vars[group_name]["params"].append(param)
|
100 |
+
parameter_group_names[group_name]["params"].append(name)
|
101 |
+
return list(parameter_group_vars.values()), [
|
102 |
+
v["lr"] for k, v in parameter_group_vars.items()
|
103 |
+
]
|
104 |
+
|
105 |
+
|
106 |
+
class BlockChunk(nn.ModuleList):
|
107 |
+
def forward(self, x):
|
108 |
+
for b in self:
|
109 |
+
x = b(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class DinoVisionTransformer(nn.Module):
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
img_size=224,
|
117 |
+
patch_size=16,
|
118 |
+
in_chans=3,
|
119 |
+
embed_dim=768,
|
120 |
+
depth=12,
|
121 |
+
num_heads=12,
|
122 |
+
mlp_ratio=4.0,
|
123 |
+
qkv_bias=True,
|
124 |
+
ffn_bias=True,
|
125 |
+
proj_bias=True,
|
126 |
+
drop_path_rate=0.0,
|
127 |
+
drop_path_uniform=False,
|
128 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
129 |
+
embed_layer=PatchEmbed,
|
130 |
+
act_layer=nn.GELU,
|
131 |
+
block_fn=Block,
|
132 |
+
ffn_layer="mlp",
|
133 |
+
block_chunks=1,
|
134 |
+
output_idx=[5, 12, 18, 24],
|
135 |
+
checkpoint: bool = False,
|
136 |
+
num_register_tokens=0,
|
137 |
+
interpolate_antialias=False,
|
138 |
+
interpolate_offset=0.0,
|
139 |
+
use_norm=False,
|
140 |
+
):
|
141 |
+
"""
|
142 |
+
Args:
|
143 |
+
img_size (int, tuple): input image size
|
144 |
+
patch_size (int, tuple): patch size
|
145 |
+
in_chans (int): number of input channels
|
146 |
+
embed_dim (int): embedding dimension
|
147 |
+
depth (int): depth of transformer
|
148 |
+
num_heads (int): number of attention heads
|
149 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
150 |
+
qkv_bias (bool): enable bias for qkv if True
|
151 |
+
proj_bias (bool): enable bias for proj in attn if True
|
152 |
+
ffn_bias (bool): enable bias for ffn if True
|
153 |
+
drop_path_rate (float): stochastic depth rate
|
154 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
155 |
+
weight_init (str): weight init scheme
|
156 |
+
init_values (float): layer-scale init values
|
157 |
+
embed_layer (nn.Module): patch embedding layer
|
158 |
+
act_layer (nn.Module): MLP activation layer
|
159 |
+
block_fn (nn.Module): transformer block class
|
160 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
161 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
162 |
+
"""
|
163 |
+
super().__init__()
|
164 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
165 |
+
|
166 |
+
self.num_features = self.embed_dim = (
|
167 |
+
embed_dim # num_features for consistency with other models
|
168 |
+
)
|
169 |
+
self.embed_dims = [embed_dim] * output_idx[-1]
|
170 |
+
self.num_tokens = 1
|
171 |
+
self.n_blocks = depth
|
172 |
+
self.num_heads = num_heads
|
173 |
+
self.patch_size = patch_size
|
174 |
+
self.depths = output_idx
|
175 |
+
self.checkpoint = checkpoint
|
176 |
+
self.num_register_tokens = num_register_tokens
|
177 |
+
self.interpolate_antialias = interpolate_antialias
|
178 |
+
self.interpolate_offset = interpolate_offset
|
179 |
+
|
180 |
+
self.patch_embed = embed_layer(
|
181 |
+
img_size=img_size,
|
182 |
+
patch_size=patch_size,
|
183 |
+
in_chans=in_chans,
|
184 |
+
embed_dim=embed_dim,
|
185 |
+
)
|
186 |
+
num_patches = self.patch_embed.num_patches
|
187 |
+
|
188 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
189 |
+
self.pos_embed = nn.Parameter(
|
190 |
+
torch.zeros(1, num_patches + self.num_tokens, embed_dim)
|
191 |
+
)
|
192 |
+
assert num_register_tokens >= 0
|
193 |
+
self.register_tokens = nn.Parameter(
|
194 |
+
torch.zeros(1, max(1, num_register_tokens), embed_dim)
|
195 |
+
)
|
196 |
+
|
197 |
+
if drop_path_uniform is True:
|
198 |
+
dpr = [drop_path_rate] * depth
|
199 |
+
else:
|
200 |
+
dpr = [
|
201 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
202 |
+
] # stochastic depth decay rule
|
203 |
+
|
204 |
+
if ffn_layer == "mlp":
|
205 |
+
logger.info("using MLP layer as FFN")
|
206 |
+
ffn_layer = Mlp
|
207 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
208 |
+
logger.info("using SwiGLU layer as FFN")
|
209 |
+
ffn_layer = SwiGLUFFNFused
|
210 |
+
elif ffn_layer == "identity":
|
211 |
+
logger.info("using Identity layer as FFN")
|
212 |
+
|
213 |
+
def f(*args, **kwargs):
|
214 |
+
return nn.Identity()
|
215 |
+
|
216 |
+
ffn_layer = f
|
217 |
+
else:
|
218 |
+
raise NotImplementedError
|
219 |
+
|
220 |
+
blocks_list = [
|
221 |
+
block_fn(
|
222 |
+
dim=embed_dim,
|
223 |
+
num_heads=num_heads,
|
224 |
+
mlp_ratio=mlp_ratio,
|
225 |
+
qkv_bias=qkv_bias,
|
226 |
+
proj_bias=proj_bias,
|
227 |
+
ffn_bias=ffn_bias,
|
228 |
+
drop_path=dpr[i],
|
229 |
+
norm_layer=norm_layer,
|
230 |
+
act_layer=act_layer,
|
231 |
+
ffn_layer=ffn_layer,
|
232 |
+
init_values=init_values,
|
233 |
+
)
|
234 |
+
for i in range(depth)
|
235 |
+
]
|
236 |
+
if block_chunks > 0:
|
237 |
+
self.chunked_blocks = True
|
238 |
+
chunked_blocks = []
|
239 |
+
chunksize = depth // block_chunks
|
240 |
+
for i in range(0, depth, chunksize):
|
241 |
+
# this is to keep the block index consistent if we chunk the block list
|
242 |
+
chunked_blocks.append(
|
243 |
+
[nn.Identity()] * i + blocks_list[i : i + chunksize]
|
244 |
+
)
|
245 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
246 |
+
else:
|
247 |
+
self.chunked_blocks = False
|
248 |
+
self.blocks = nn.ModuleList(blocks_list)
|
249 |
+
|
250 |
+
self.norm = norm_layer(embed_dim)
|
251 |
+
self.use_norm = use_norm
|
252 |
+
self.head = nn.Identity()
|
253 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
254 |
+
self.init_weights()
|
255 |
+
|
256 |
+
def init_weights(self):
|
257 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
258 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
259 |
+
if self.num_register_tokens:
|
260 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
261 |
+
named_apply(init_weights_vit_timm, self)
|
262 |
+
|
263 |
+
def interpolate_pos_encoding(self, x, w, h):
|
264 |
+
previous_dtype = x.dtype
|
265 |
+
npatch = x.shape[1] - 1
|
266 |
+
N = self.pos_embed.shape[1] - 1
|
267 |
+
if npatch == N and w == h:
|
268 |
+
return self.pos_embed
|
269 |
+
pos_embed = self.pos_embed.float()
|
270 |
+
class_pos_embed = pos_embed[:, 0]
|
271 |
+
patch_pos_embed = pos_embed[:, 1:]
|
272 |
+
dim = x.shape[-1]
|
273 |
+
w0 = w // self.patch_size
|
274 |
+
h0 = h // self.patch_size
|
275 |
+
|
276 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
277 |
+
assert N == M * M
|
278 |
+
kwargs = {}
|
279 |
+
if self.interpolate_offset:
|
280 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
281 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
282 |
+
sx = float(w0 + self.interpolate_offset) / M
|
283 |
+
sy = float(h0 + self.interpolate_offset) / M
|
284 |
+
kwargs["scale_factor"] = (sx, sy)
|
285 |
+
else:
|
286 |
+
# Simply specify an output size instead of a scale factor
|
287 |
+
kwargs["size"] = (w0, h0)
|
288 |
+
|
289 |
+
patch_pos_embed = nn.functional.interpolate(
|
290 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
291 |
+
mode="bicubic",
|
292 |
+
antialias=self.interpolate_antialias,
|
293 |
+
**kwargs,
|
294 |
+
)
|
295 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
296 |
+
|
297 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
298 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
|
299 |
+
previous_dtype
|
300 |
+
)
|
301 |
+
|
302 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
303 |
+
B, nc, w, h = x.shape
|
304 |
+
x = self.patch_embed(x)
|
305 |
+
if masks is not None:
|
306 |
+
masks = masks.bool().view(B, -1, 1)
|
307 |
+
x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
|
308 |
+
|
309 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
310 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
311 |
+
|
312 |
+
if self.num_register_tokens:
|
313 |
+
x = torch.cat(
|
314 |
+
(x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
|
315 |
+
dim=1,
|
316 |
+
)
|
317 |
+
return x
|
318 |
+
|
319 |
+
def forward(self, x, masks=None):
|
320 |
+
shapes = [val // self.patch_size for val in x.shape[-2:]]
|
321 |
+
batch_size = x.shape[0]
|
322 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
323 |
+
outputs = []
|
324 |
+
for i, blk in enumerate(self.blocks):
|
325 |
+
x = blk(x)
|
326 |
+
outputs.append(x)
|
327 |
+
|
328 |
+
if self.use_norm:
|
329 |
+
outputs = [self.norm(out) for out in outputs]
|
330 |
+
class_tokens = [out[:, :1] for out in outputs]
|
331 |
+
outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs]
|
332 |
+
outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs]
|
333 |
+
|
334 |
+
return (outputs, class_tokens)
|
335 |
+
|
336 |
+
def get_params(self, lr, wd, ld, *args, **kwargs):
|
337 |
+
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
|
338 |
+
return encoder_p, encoder_lr
|
339 |
+
|
340 |
+
def freeze(self) -> None:
|
341 |
+
for module in self.modules():
|
342 |
+
module.eval()
|
343 |
+
for parameters in self.parameters():
|
344 |
+
parameters.requires_grad = False
|
345 |
+
|
346 |
+
def train(self, mode=True):
|
347 |
+
super().train(mode)
|
348 |
+
self.mask_token.requires_grad = False
|
349 |
+
self.register_tokens.requires_grad = False
|
350 |
+
|
351 |
+
|
352 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
353 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
354 |
+
if isinstance(module, nn.Linear):
|
355 |
+
trunc_normal_(module.weight, std=0.02)
|
356 |
+
if module.bias is not None:
|
357 |
+
nn.init.zeros_(module.bias)
|
358 |
+
|
359 |
+
|
360 |
+
def vit_small(patch_size=16, num_register_tokens=0, export=False, **kwargs):
|
361 |
+
model = DinoVisionTransformer(
|
362 |
+
patch_size=patch_size,
|
363 |
+
embed_dim=384,
|
364 |
+
depth=12,
|
365 |
+
num_heads=6,
|
366 |
+
mlp_ratio=4,
|
367 |
+
num_register_tokens=num_register_tokens,
|
368 |
+
block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
|
369 |
+
**kwargs,
|
370 |
+
)
|
371 |
+
return model
|
372 |
+
|
373 |
+
|
374 |
+
def vit_base(patch_size=16, num_register_tokens=0, export=False, **kwargs):
|
375 |
+
model = DinoVisionTransformer(
|
376 |
+
patch_size=patch_size,
|
377 |
+
embed_dim=768,
|
378 |
+
depth=12,
|
379 |
+
num_heads=12,
|
380 |
+
mlp_ratio=4,
|
381 |
+
num_register_tokens=num_register_tokens,
|
382 |
+
block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
|
383 |
+
**kwargs,
|
384 |
+
)
|
385 |
+
return model
|
386 |
+
|
387 |
+
|
388 |
+
def vit_large(patch_size=16, num_register_tokens=0, export=False, **kwargs):
|
389 |
+
model = DinoVisionTransformer(
|
390 |
+
patch_size=patch_size,
|
391 |
+
embed_dim=1024,
|
392 |
+
depth=24,
|
393 |
+
num_heads=16,
|
394 |
+
mlp_ratio=4,
|
395 |
+
num_register_tokens=num_register_tokens,
|
396 |
+
block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
|
397 |
+
**kwargs,
|
398 |
+
)
|
399 |
+
return model
|
400 |
+
|
401 |
+
|
402 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
|
403 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
404 |
+
return f"dinov2_{compact_arch_name}{patch_size}"
|
405 |
+
|
406 |
+
|
407 |
+
def _make_dinov2_model(
|
408 |
+
*,
|
409 |
+
arch_name: str = "vit_large",
|
410 |
+
img_size: int = 518,
|
411 |
+
patch_size: int = 14,
|
412 |
+
init_values: float = 1.0,
|
413 |
+
ffn_layer: str = "mlp",
|
414 |
+
block_chunks: int = 0,
|
415 |
+
pretrained: str = "",
|
416 |
+
output_idx: Sequence[int] = [],
|
417 |
+
num_register_tokens: int = 0,
|
418 |
+
drop_path_rate: float = 0.0,
|
419 |
+
use_norm: bool = False,
|
420 |
+
export: bool = False,
|
421 |
+
interpolate_offset: float = 0.0,
|
422 |
+
**kwargs,
|
423 |
+
):
|
424 |
+
model_name = _make_dinov2_model_name(arch_name, patch_size)
|
425 |
+
|
426 |
+
vit_kwargs = dict(
|
427 |
+
img_size=img_size,
|
428 |
+
patch_size=patch_size,
|
429 |
+
init_values=init_values,
|
430 |
+
ffn_layer=ffn_layer,
|
431 |
+
block_chunks=block_chunks,
|
432 |
+
output_idx=output_idx,
|
433 |
+
drop_path_rate=drop_path_rate,
|
434 |
+
num_register_tokens=num_register_tokens,
|
435 |
+
use_norm=use_norm,
|
436 |
+
export=export,
|
437 |
+
interpolate_offset=interpolate_offset,
|
438 |
+
)
|
439 |
+
vit_kwargs.update(**kwargs)
|
440 |
+
model = eval(arch_name)(**vit_kwargs)
|
441 |
+
if pretrained == "":
|
442 |
+
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
|
443 |
+
if num_register_tokens > 0:
|
444 |
+
url += "_reg4"
|
445 |
+
url += "_pretrain.pth"
|
446 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
447 |
+
url, map_location="cpu", progress=False
|
448 |
+
)
|
449 |
+
info = model.load_state_dict(state_dict, strict=False)
|
450 |
+
print(info)
|
451 |
+
elif pretrained is not None:
|
452 |
+
state_dict = torch.load(pretrained, map_location="cpu")
|
453 |
+
info = model.load_state_dict(state_dict, strict=False)
|
454 |
+
print(f"loading from {pretrained} with:", info)
|
455 |
+
return model
|
unidepth/models/backbones/metadinov2/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .attention import Attention, MemEffAttention
|
8 |
+
from .block import NestedTensorBlock
|
9 |
+
from .dino_head import DINOHead
|
10 |
+
from .mlp import Mlp
|
11 |
+
from .patch_embed import PatchEmbed
|
12 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (592 Bytes). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc
ADDED
Binary file (4.46 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc
ADDED
Binary file (15.9 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc
ADDED
Binary file (3.94 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc
ADDED
Binary file (1.86 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc
ADDED
Binary file (1.62 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc
ADDED
Binary file (2.08 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc
ADDED
Binary file (4.49 kB). View file
|
|
unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc
ADDED
Binary file (3.29 kB). View file
|
|
unidepth/models/backbones/metadinov2/attention.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
logger = logging.getLogger("dinov2")
|
17 |
+
|
18 |
+
|
19 |
+
try:
|
20 |
+
from xformers.ops import fmha, memory_efficient_attention, unbind
|
21 |
+
|
22 |
+
XFORMERS_AVAILABLE = True
|
23 |
+
except ImportError:
|
24 |
+
logger.warning("xFormers not available")
|
25 |
+
XFORMERS_AVAILABLE = False
|
26 |
+
|
27 |
+
|
28 |
+
class Attention(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
dim: int,
|
32 |
+
num_heads: int = 8,
|
33 |
+
qkv_bias: bool = False,
|
34 |
+
proj_bias: bool = True,
|
35 |
+
attn_drop: float = 0.0,
|
36 |
+
proj_drop: float = 0.0,
|
37 |
+
) -> None:
|
38 |
+
super().__init__()
|
39 |
+
self.num_heads = num_heads
|
40 |
+
head_dim = dim // num_heads
|
41 |
+
self.scale = head_dim**-0.5
|
42 |
+
|
43 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
44 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
45 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
46 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
47 |
+
|
48 |
+
def forward(self, x: Tensor) -> Tensor:
|
49 |
+
B, N, C = x.shape
|
50 |
+
qkv = (
|
51 |
+
self.qkv(x)
|
52 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
53 |
+
.permute(2, 0, 3, 1, 4)
|
54 |
+
)
|
55 |
+
|
56 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
57 |
+
attn = q @ k.transpose(-2, -1)
|
58 |
+
|
59 |
+
attn = attn.softmax(dim=-1)
|
60 |
+
attn = self.attn_drop(attn)
|
61 |
+
|
62 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
63 |
+
x = self.proj(x)
|
64 |
+
x = self.proj_drop(x)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class MemEffAttention(Attention):
|
69 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
70 |
+
if not XFORMERS_AVAILABLE:
|
71 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
72 |
+
return super().forward(x)
|
73 |
+
|
74 |
+
B, N, C = x.shape
|
75 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
76 |
+
|
77 |
+
q, k, v = unbind(qkv, 2)
|
78 |
+
|
79 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
80 |
+
x = x.reshape([B, N, C])
|
81 |
+
|
82 |
+
x = self.proj(x)
|
83 |
+
x = self.proj_drop(x)
|
84 |
+
return x
|
unidepth/models/backbones/metadinov2/block.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Any, Callable, Dict, List, Tuple
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
logger = logging.getLogger("dinov2")
|
23 |
+
|
24 |
+
|
25 |
+
try:
|
26 |
+
from xformers.ops import fmha, index_select_cat, scaled_index_add
|
27 |
+
|
28 |
+
XFORMERS_AVAILABLE = True
|
29 |
+
except ImportError:
|
30 |
+
logger.warning("xFormers not available")
|
31 |
+
XFORMERS_AVAILABLE = False
|
32 |
+
|
33 |
+
|
34 |
+
class Block(nn.Module):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
dim: int,
|
38 |
+
num_heads: int,
|
39 |
+
mlp_ratio: float = 4.0,
|
40 |
+
qkv_bias: bool = False,
|
41 |
+
proj_bias: bool = True,
|
42 |
+
ffn_bias: bool = True,
|
43 |
+
drop: float = 0.0,
|
44 |
+
attn_drop: float = 0.0,
|
45 |
+
init_values=None,
|
46 |
+
drop_path: float = 0.0,
|
47 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
48 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
49 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
50 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
51 |
+
) -> None:
|
52 |
+
super().__init__()
|
53 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
54 |
+
self.norm1 = norm_layer(dim)
|
55 |
+
self.attn = attn_class(
|
56 |
+
dim,
|
57 |
+
num_heads=num_heads,
|
58 |
+
qkv_bias=qkv_bias,
|
59 |
+
proj_bias=proj_bias,
|
60 |
+
attn_drop=attn_drop,
|
61 |
+
proj_drop=drop,
|
62 |
+
)
|
63 |
+
self.ls1 = (
|
64 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
65 |
+
)
|
66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
+
|
68 |
+
self.norm2 = norm_layer(dim)
|
69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
70 |
+
self.mlp = ffn_layer(
|
71 |
+
in_features=dim,
|
72 |
+
hidden_features=mlp_hidden_dim,
|
73 |
+
act_layer=act_layer,
|
74 |
+
drop=drop,
|
75 |
+
bias=ffn_bias,
|
76 |
+
)
|
77 |
+
self.ls2 = (
|
78 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
79 |
+
)
|
80 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
81 |
+
|
82 |
+
self.sample_drop_ratio = drop_path
|
83 |
+
|
84 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
85 |
+
def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
|
86 |
+
return self.ls1(self.attn(self.norm1(x)))
|
87 |
+
|
88 |
+
def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
|
89 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
90 |
+
|
91 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
92 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
93 |
+
x = drop_add_residual_stochastic_depth(
|
94 |
+
x,
|
95 |
+
residual_func=attn_residual_func,
|
96 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
97 |
+
)
|
98 |
+
x = drop_add_residual_stochastic_depth(
|
99 |
+
x,
|
100 |
+
residual_func=ffn_residual_func,
|
101 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
102 |
+
)
|
103 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
104 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
105 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
106 |
+
else:
|
107 |
+
x = x + attn_residual_func(x)
|
108 |
+
x = x + ffn_residual_func(x)
|
109 |
+
return x
|
110 |
+
|
111 |
+
|
112 |
+
def drop_add_residual_stochastic_depth(
|
113 |
+
x: torch.Tensor,
|
114 |
+
residual_func: Callable[[torch.Tensor], torch.Tensor],
|
115 |
+
sample_drop_ratio: float = 0.0,
|
116 |
+
) -> torch.Tensor:
|
117 |
+
# 1) extract subset using permutation
|
118 |
+
b, n, d = x.shape
|
119 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
120 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
121 |
+
x_subset = x[brange]
|
122 |
+
|
123 |
+
# 2) apply residual_func to get residual
|
124 |
+
residual = residual_func(x_subset)
|
125 |
+
|
126 |
+
x_flat = x.flatten(1)
|
127 |
+
residual = residual.flatten(1)
|
128 |
+
|
129 |
+
residual_scale_factor = b / sample_subset_size
|
130 |
+
|
131 |
+
# 3) add the residual
|
132 |
+
x_plus_residual = torch.index_add(
|
133 |
+
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
134 |
+
)
|
135 |
+
return x_plus_residual.view_as(x)
|
136 |
+
|
137 |
+
|
138 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
139 |
+
b, n, d = x.shape
|
140 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
141 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
142 |
+
residual_scale_factor = b / sample_subset_size
|
143 |
+
return brange, residual_scale_factor
|
144 |
+
|
145 |
+
|
146 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
147 |
+
if scaling_vector is None:
|
148 |
+
x_flat = x.flatten(1)
|
149 |
+
residual = residual.flatten(1)
|
150 |
+
x_plus_residual = torch.index_add(
|
151 |
+
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
x_plus_residual = scaled_index_add(
|
155 |
+
x,
|
156 |
+
brange,
|
157 |
+
residual.to(dtype=x.dtype),
|
158 |
+
scaling=scaling_vector,
|
159 |
+
alpha=residual_scale_factor,
|
160 |
+
)
|
161 |
+
return x_plus_residual
|
162 |
+
|
163 |
+
|
164 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
165 |
+
|
166 |
+
|
167 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
168 |
+
"""
|
169 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
170 |
+
"""
|
171 |
+
batch_sizes = (
|
172 |
+
[b.shape[0] for b in branges]
|
173 |
+
if branges is not None
|
174 |
+
else [x.shape[0] for x in x_list]
|
175 |
+
)
|
176 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
177 |
+
if all_shapes not in attn_bias_cache.keys():
|
178 |
+
seqlens = []
|
179 |
+
for b, x in zip(batch_sizes, x_list):
|
180 |
+
for _ in range(b):
|
181 |
+
seqlens.append(x.shape[1])
|
182 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
183 |
+
attn_bias._batch_sizes = batch_sizes
|
184 |
+
attn_bias_cache[all_shapes] = attn_bias
|
185 |
+
|
186 |
+
if branges is not None:
|
187 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
|
188 |
+
1, -1, x_list[0].shape[-1]
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
192 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
193 |
+
|
194 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
195 |
+
|
196 |
+
|
197 |
+
def drop_add_residual_stochastic_depth_list(
|
198 |
+
x_list: List[torch.Tensor],
|
199 |
+
residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
|
200 |
+
sample_drop_ratio: float = 0.0,
|
201 |
+
scaling_vector=None,
|
202 |
+
) -> torch.Tensor:
|
203 |
+
# 1) generate random set of indices for dropping samples in the batch
|
204 |
+
branges_scales = [
|
205 |
+
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
|
206 |
+
]
|
207 |
+
branges = [s[0] for s in branges_scales]
|
208 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
209 |
+
|
210 |
+
# 2) get attention bias and index+concat the tensors
|
211 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
212 |
+
|
213 |
+
# 3) apply residual_func to get residual, and split the result
|
214 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
215 |
+
|
216 |
+
outputs = []
|
217 |
+
for x, brange, residual, residual_scale_factor in zip(
|
218 |
+
x_list, branges, residual_list, residual_scale_factors
|
219 |
+
):
|
220 |
+
outputs.append(
|
221 |
+
add_residual(
|
222 |
+
x, brange, residual, residual_scale_factor, scaling_vector
|
223 |
+
).view_as(x)
|
224 |
+
)
|
225 |
+
return outputs
|
226 |
+
|
227 |
+
|
228 |
+
class NestedTensorBlock(Block):
|
229 |
+
def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
230 |
+
"""
|
231 |
+
x_list contains a list of tensors to nest together and run
|
232 |
+
"""
|
233 |
+
assert isinstance(self.attn, MemEffAttention)
|
234 |
+
|
235 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
236 |
+
|
237 |
+
def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
238 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
239 |
+
|
240 |
+
def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
241 |
+
return self.mlp(self.norm2(x))
|
242 |
+
|
243 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
244 |
+
x_list,
|
245 |
+
residual_func=attn_residual_func,
|
246 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
247 |
+
scaling_vector=(
|
248 |
+
self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
|
249 |
+
),
|
250 |
+
)
|
251 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
252 |
+
x_list,
|
253 |
+
residual_func=ffn_residual_func,
|
254 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
255 |
+
scaling_vector=(
|
256 |
+
self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
|
257 |
+
),
|
258 |
+
)
|
259 |
+
return x_list
|
260 |
+
else:
|
261 |
+
|
262 |
+
def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
263 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
264 |
+
|
265 |
+
def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
266 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
267 |
+
|
268 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
269 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
270 |
+
x = x + ffn_residual_func(x)
|
271 |
+
return attn_bias.split(x)
|
272 |
+
|
273 |
+
def forward(self, x_or_x_list):
|
274 |
+
if isinstance(x_or_x_list, torch.Tensor):
|
275 |
+
return super(NestedTensorBlock, self).forward(x_or_x_list)
|
276 |
+
elif isinstance(x_or_x_list, list):
|
277 |
+
assert (
|
278 |
+
XFORMERS_AVAILABLE
|
279 |
+
), "Please install xFormers for nested tensors usage"
|
280 |
+
return self.forward_nested(x_or_x_list)
|
281 |
+
else:
|
282 |
+
raise AssertionError
|
unidepth/models/backbones/metadinov2/dino_head.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn.init import trunc_normal_
|
10 |
+
from torch.nn.utils import weight_norm
|
11 |
+
|
12 |
+
|
13 |
+
class DINOHead(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
in_dim,
|
17 |
+
out_dim,
|
18 |
+
use_bn=False,
|
19 |
+
nlayers=3,
|
20 |
+
hidden_dim=2048,
|
21 |
+
bottleneck_dim=256,
|
22 |
+
mlp_bias=True,
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
nlayers = max(nlayers, 1)
|
26 |
+
self.mlp = _build_mlp(
|
27 |
+
nlayers,
|
28 |
+
in_dim,
|
29 |
+
bottleneck_dim,
|
30 |
+
hidden_dim=hidden_dim,
|
31 |
+
use_bn=use_bn,
|
32 |
+
bias=mlp_bias,
|
33 |
+
)
|
34 |
+
self.apply(self._init_weights)
|
35 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
36 |
+
self.last_layer.weight_g.data.fill_(1)
|
37 |
+
|
38 |
+
def _init_weights(self, m):
|
39 |
+
if isinstance(m, nn.Linear):
|
40 |
+
trunc_normal_(m.weight, std=0.02)
|
41 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
42 |
+
nn.init.constant_(m.bias, 0)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.mlp(x)
|
46 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
47 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
48 |
+
x = self.last_layer(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
|
52 |
+
def _build_mlp(
|
53 |
+
nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
|
54 |
+
):
|
55 |
+
if nlayers == 1:
|
56 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
57 |
+
else:
|
58 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
59 |
+
if use_bn:
|
60 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
61 |
+
layers.append(nn.GELU())
|
62 |
+
for _ in range(nlayers - 2):
|
63 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
64 |
+
if use_bn:
|
65 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
66 |
+
layers.append(nn.GELU())
|
67 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
68 |
+
return nn.Sequential(*layers)
|
unidepth/models/backbones/metadinov2/drop_path.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
10 |
+
|
11 |
+
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
|
15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
16 |
+
if drop_prob == 0.0 or not training:
|
17 |
+
return x
|
18 |
+
keep_prob = 1 - drop_prob
|
19 |
+
shape = (x.shape[0],) + (1,) * (
|
20 |
+
x.ndim - 1
|
21 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
22 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
23 |
+
if keep_prob > 0.0:
|
24 |
+
random_tensor.div_(keep_prob)
|
25 |
+
output = x * random_tensor
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
class DropPath(nn.Module):
|
30 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
31 |
+
|
32 |
+
def __init__(self, drop_prob=None):
|
33 |
+
super(DropPath, self).__init__()
|
34 |
+
self.drop_prob = drop_prob
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return drop_path(x, self.drop_prob, self.training)
|
unidepth/models/backbones/metadinov2/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
8 |
+
|
9 |
+
from typing import Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch import Tensor
|
14 |
+
|
15 |
+
|
16 |
+
class LayerScale(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim: int,
|
20 |
+
init_values: Union[float, Tensor] = 1e-5,
|
21 |
+
inplace: bool = False,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
self.inplace = inplace
|
25 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
26 |
+
|
27 |
+
def forward(self, x: Tensor) -> Tensor:
|
28 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|