ybelkada committed
Commit 7708d0d
1 parent: 5b07818

add first files

app.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import gradio as gr
+ import numpy as np
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from torchvision import transforms
+
+ from transformers import AutoConfig, AutoModel
+
+ from focusondepth.model_config import FocusOnDepthConfig
+ from focusondepth.model_definition import FocusOnDepth
+
+ # Register the custom architecture so AutoModel can resolve it.
+ AutoConfig.register("focusondepth", FocusOnDepthConfig)
+ AutoModel.register(FocusOnDepthConfig, FocusOnDepth)
+
+ original_image_cache = {}
+ transform = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
+ model.load_state_dict(torch.load('./focusondepth/FocusOnDepth_vit_base_patch16_384.p', map_location=torch.device('cpu'))['model_state_dict'])
+
+ @torch.no_grad()
+ def inference(input_image):
+     global model, transform
+
+     model.eval()
+     input_image = Image.fromarray(input_image)
+     original_size = input_image.size
+     tensor_image = transform(input_image)
+
+     depth, segmentation = model(tensor_image.unsqueeze(0))
+     depth = 1 - depth
+
+     depth = transforms.ToPILImage()(depth[0, :])
+     segmentation = transforms.ToPILImage()(segmentation.argmax(dim=1).float())
+
+     return [depth.resize(original_size, resample=Image.BICUBIC), segmentation.resize(original_size, resample=Image.NEAREST)]
+
+ iface = gr.Interface(
+     fn=inference,
+     inputs=gr.inputs.Image(label="Input Image"),
+     outputs=[
+         gr.outputs.Image(label="Depth Map:"),
+         gr.outputs.Image(label="Segmentation Map:"),
+     ],
+ )
+ iface.launch()
focusondepth/__init__.py ADDED
File without changes
focusondepth/fusion.py ADDED
@@ -0,0 +1,41 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ class ResidualConvUnit(nn.Module):
+     def __init__(self, features):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(
+             features, features, kernel_size=3, stride=1, padding=1, bias=True)
+         self.conv2 = nn.Conv2d(
+             features, features, kernel_size=3, stride=1, padding=1, bias=True)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         """Forward pass.
+         Args:
+             x (tensor): input
+         Returns:
+             tensor: output
+         """
+         out = self.relu(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         return out + x
+
+ class Fusion(nn.Module):
+     def __init__(self, resample_dim):
+         super(Fusion, self).__init__()
+         self.res_conv1 = ResidualConvUnit(resample_dim)
+         self.res_conv2 = ResidualConvUnit(resample_dim)
+
+     def forward(self, x, previous_stage=None):
+         if previous_stage is None:
+             previous_stage = torch.zeros_like(x)
+         output_stage1 = self.res_conv1(x)
+         output_stage1 += previous_stage
+         output_stage2 = self.res_conv2(output_stage1)
+         output_stage2 = nn.functional.interpolate(output_stage2, scale_factor=2, mode="bilinear", align_corners=True)
+         return output_stage2
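
For reference, a minimal sketch of how a Fusion stage behaves on dummy tensors (the 24x24 input resolution below is an assumption matching a 384x384 image with 16x16 patches): each stage refines the feature map with two residual conv units, adds the previous stage (zeros if there is none), and doubles the spatial resolution.

    import torch
    from focusondepth.fusion import Fusion

    fusion = Fusion(resample_dim=256)
    x = torch.randn(1, 256, 24, 24)                 # reassembled ViT features (assumed shape)
    out = fusion(x)                                 # no previous stage -> zeros are added
    print(out.shape)                                # torch.Size([1, 256, 48, 48])
    out = fusion(torch.randn(1, 256, 48, 48), out)  # reusing the module just to show the shape flow
    print(out.shape)                                # torch.Size([1, 256, 96, 96])
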
focusondepth/head.py ADDED
@@ -0,0 +1,50 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ class Interpolate(nn.Module):
+     def __init__(self, scale_factor, mode, align_corners=False):
+         super(Interpolate, self).__init__()
+         self.interp = nn.functional.interpolate
+         self.scale_factor = scale_factor
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         x = self.interp(
+             x,
+             scale_factor=self.scale_factor,
+             mode=self.mode,
+             align_corners=self.align_corners)
+         return x
+
+ class HeadDepth(nn.Module):
+     def __init__(self, features):
+         super(HeadDepth, self).__init__()
+         self.head = nn.Sequential(
+             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+             Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+             nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+             # nn.ReLU()
+             nn.Sigmoid()
+         )
+     def forward(self, x):
+         x = self.head(x)
+         # x = (x - x.min())/(x.max()-x.min() + 1e-15)
+         return x
+
+ class HeadSeg(nn.Module):
+     def __init__(self, features, nclasses=2):
+         super(HeadSeg, self).__init__()
+         self.head = nn.Sequential(
+             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+             Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+             nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, nclasses, kernel_size=1, stride=1, padding=0)
+         )
+     def forward(self, x):
+         x = self.head(x)
+         return x
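
As a rough illustration (the 192x192 input and resample_dim of 256 below are assumptions matching the default setup), HeadDepth squashes its single-channel output into (0, 1) with the final Sigmoid, HeadSeg returns raw per-class logits, and both upsample by a factor of 2:

    import torch
    from focusondepth.head import HeadDepth, HeadSeg

    x = torch.randn(1, 256, 192, 192)  # fused feature map (assumed shape)
    depth = HeadDepth(256)(x)
    seg = HeadSeg(256, nclasses=2)(x)
    print(depth.shape)                 # torch.Size([1, 1, 384, 384]), values in (0, 1) from the Sigmoid
    print(seg.shape)                   # torch.Size([1, 2, 384, 384]), raw logits
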
focusondepth/model_config.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class FocusOnDepthConfig(PretrainedConfig):
+     model_type = "focusondepth"
+
+     def __init__(
+         self,
+         image_size=(3, 384, 384),
+         patch_size=16,
+         emb_dim=768,
+         resample_dim=256,
+         read='projection',
+         num_layers_encoder=24,
+         hooks=[2, 5, 8, 11],
+         reassemble_s=[4, 8, 16, 32],
+         transformer_dropout=0,
+         nclasses=2,
+         type_="full",
+         model_timm="vit_base_patch16_384",
+         **kwargs,
+     ):
+         if type_ not in ["full", "depth", "segmentation"]:
+             raise ValueError(f"`type_` must be 'full', 'depth' or 'segmentation', got {type_}.")
+         if read not in ["ignore", "add", "projection"]:
+             raise ValueError(f"`read` must be 'ignore', 'add' or 'projection', got {read}.")
+         if image_size[0] != 3 or image_size[1] != 384 or image_size[2] != 384:
+             raise ValueError(f"`image_size` must be (3, 384, 384), got {image_size}.")
+         if patch_size != 16:
+             raise ValueError(f"`patch_size` must be 16, got {patch_size}.")
+
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.emb_dim = emb_dim
+         self.resample_dim = resample_dim
+         self.read = read
+         self.num_layers_encoder = num_layers_encoder
+         self.hooks = hooks
+         self.reassemble_s = reassemble_s
+         self.transformer_dropout = transformer_dropout
+         self.nclasses = nclasses
+         self.type_ = type_
+         self.model_timm = model_timm
+         super().__init__(**kwargs)
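
A small sketch of how this config is meant to be used: the defaults describe the ViT-Base/16 backbone at 384x384, and unsupported values are rejected at construction time.

    from focusondepth.model_config import FocusOnDepthConfig

    config = FocusOnDepthConfig()  # default: full depth + segmentation model
    print(config.model_type, config.model_timm, config.hooks)

    try:
        FocusOnDepthConfig(type_="classification")  # not one of "full", "depth", "segmentation"
    except ValueError as err:
        print(err)
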
focusondepth/model_definition.py ADDED
@@ -0,0 +1,68 @@
+ from transformers import PreTrainedModel
+ import timm
+ import torch.nn as nn
+ import numpy as np
+
+ from .model_config import FocusOnDepthConfig
+ from .reassemble import Reassemble
+ from .fusion import Fusion
+ from .head import HeadDepth, HeadSeg
+
+
+ class FocusOnDepth(PreTrainedModel):
+     config_class = FocusOnDepthConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.transformer_encoders = timm.create_model(config.model_timm, pretrained=True)
+         self.type_ = config.type_
+
+         # Register forward hooks on the ViT blocks listed in config.hooks
+         self.activation = {}
+         self.hooks = config.hooks
+         self._get_layers_from_hooks(self.hooks)
+
+         # Reassemble and Fusion stages, one per hooked block
+         self.reassembles = []
+         self.fusions = []
+         for s in config.reassemble_s:
+             self.reassembles.append(Reassemble(config.image_size, config.read, config.patch_size, s, config.emb_dim, config.resample_dim))
+             self.fusions.append(Fusion(config.resample_dim))
+         self.reassembles = nn.ModuleList(self.reassembles)
+         self.fusions = nn.ModuleList(self.fusions)
+
+         # Output head(s)
+         if self.type_ == "full":
+             self.head_depth = HeadDepth(config.resample_dim)
+             self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
+         elif self.type_ == "depth":
+             self.head_depth = HeadDepth(config.resample_dim)
+             self.head_segmentation = None
+         else:
+             self.head_depth = None
+             self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
+
+     def forward(self, img):
+         _ = self.transformer_encoders(img)
+         previous_stage = None
+         for i in np.arange(len(self.fusions) - 1, -1, -1):
+             hook_to_take = 't' + str(self.hooks[i])
+             activation_result = self.activation[hook_to_take]
+             reassemble_result = self.reassembles[i](activation_result)
+             fusion_result = self.fusions[i](reassemble_result, previous_stage)
+             previous_stage = fusion_result
+         out_depth = None
+         out_segmentation = None
+         if self.head_depth is not None:
+             out_depth = self.head_depth(previous_stage)
+         if self.head_segmentation is not None:
+             out_segmentation = self.head_segmentation(previous_stage)
+         return out_depth, out_segmentation
+
+     def _get_layers_from_hooks(self, hooks):
+         def get_activation(name):
+             def hook(model, input, output):
+                 self.activation[name] = output
+             return hook
+         for h in hooks:
+             self.transformer_encoders.blocks[h].register_forward_hook(get_activation('t' + str(h)))
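
A hedged end-to-end sketch of the forward pass on a random batch (note that instantiating the model downloads the pretrained timm backbone): the hooks on four ViT blocks feed the Reassemble/Fusion pyramid from the deepest hook upward, and the heads return a depth map and segmentation logits at the input resolution.

    import torch
    from focusondepth.model_config import FocusOnDepthConfig
    from focusondepth.model_definition import FocusOnDepth

    model = FocusOnDepth(FocusOnDepthConfig()).eval()  # downloads the timm ViT-Base/16-384 weights
    with torch.no_grad():
        depth, segmentation = model(torch.randn(1, 3, 384, 384))
    print(depth.shape)                                 # torch.Size([1, 1, 384, 384])
    print(segmentation.shape)                          # torch.Size([1, 2, 384, 384])
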
focusondepth/reassemble.py ADDED
@@ -0,0 +1,115 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+ from einops.layers.torch import Rearrange
+
+ class Read_ignore(nn.Module):
+     def __init__(self, start_index=1):
+         super(Read_ignore, self).__init__()
+         self.start_index = start_index
+
+     def forward(self, x):
+         return x[:, self.start_index:]
+
+
+ class Read_add(nn.Module):
+     def __init__(self, start_index=1):
+         super(Read_add, self).__init__()
+         self.start_index = start_index
+
+     def forward(self, x):
+         if self.start_index == 2:
+             readout = (x[:, 0] + x[:, 1]) / 2
+         else:
+             readout = x[:, 0]
+         return x[:, self.start_index:] + readout.unsqueeze(1)
+
+
+ class Read_projection(nn.Module):
+     def __init__(self, in_features, start_index=1):
+         super(Read_projection, self).__init__()
+         self.start_index = start_index
+         self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+     def forward(self, x):
+         readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
+         features = torch.cat((x[:, self.start_index:], readout), -1)
+         return self.project(features)
+
+ class MyConvTranspose2d(nn.Module):
+     def __init__(self, conv, output_size):
+         super(MyConvTranspose2d, self).__init__()
+         self.output_size = output_size
+         self.conv = conv
+
+     def forward(self, x):
+         x = self.conv(x, output_size=self.output_size)
+         return x
+
+ class Resample(nn.Module):
+     def __init__(self, p, s, h, emb_dim, resample_dim):
+         super(Resample, self).__init__()
+         assert (s in [4, 8, 16, 32]), "s must be in [4, 8, 16, 32]"
+         self.conv1 = nn.Conv2d(emb_dim, resample_dim, kernel_size=1, stride=1, padding=0)
+         if s == 4:
+             self.conv2 = nn.ConvTranspose2d(resample_dim,
+                                             resample_dim,
+                                             kernel_size=4,
+                                             stride=4,
+                                             padding=0,
+                                             bias=True,
+                                             dilation=1,
+                                             groups=1)
+         elif s == 8:
+             self.conv2 = nn.ConvTranspose2d(resample_dim,
+                                             resample_dim,
+                                             kernel_size=2,
+                                             stride=2,
+                                             padding=0,
+                                             bias=True,
+                                             dilation=1,
+                                             groups=1)
+         elif s == 16:
+             self.conv2 = nn.Identity()
+         else:
+             self.conv2 = nn.Conv2d(resample_dim, resample_dim, kernel_size=2, stride=2, padding=0, bias=True)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.conv2(x)
+         return x
+
+ class Reassemble(nn.Module):
+     def __init__(self, image_size, read, p, s, emb_dim, resample_dim):
+         """
+         p = patch size
+         s = resample coefficient
+         emb_dim <=> D (in the paper)
+         resample_dim <=> ^D (in the paper)
+         read : {"ignore", "add", "projection"}
+         """
+         super(Reassemble, self).__init__()
+         channels, image_height, image_width = image_size
+
+         # Read
+         self.read = Read_ignore()
+         if read == 'add':
+             self.read = Read_add()
+         elif read == 'projection':
+             self.read = Read_projection(emb_dim)
+
+         # Concat after read
+         self.concat = Rearrange('b (h w) c -> b c h w',
+                                 c=emb_dim,
+                                 h=(image_height // p),
+                                 w=(image_width // p))
+
+         # Projection + Resample
+         self.resample = Resample(p, s, image_height, emb_dim, resample_dim)
+
+     def forward(self, x):
+         x = self.read(x)
+         x = self.concat(x)
+         x = self.resample(x)
+         return x
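
Finally, a small shape check for Reassemble on a dummy ViT token sequence (1 CLS token plus 24x24 patch tokens of dimension 768 is an assumption matching the default backbone): the readout token is folded in via projection, the remaining tokens are rearranged into a 24x24 grid, and s controls the output resolution relative to that grid.

    import torch
    from focusondepth.reassemble import Reassemble

    tokens = torch.randn(1, 1 + 24 * 24, 768)  # CLS token + patch tokens (assumed layout)
    reassemble = Reassemble(image_size=(3, 384, 384), read='projection',
                            p=16, s=4, emb_dim=768, resample_dim=256)
    print(reassemble(tokens).shape)            # torch.Size([1, 256, 96, 96])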