import torch
import torch.nn as nn

from functools import partial
from timm.models.vision_transformer import Block

# 3D positional encoding (Fourier features), adapted from https://github.com/bmild/nerf.
class Embedder:
    """Concatenates (optionally) the raw input with sin/cos of the input
    taken at a range of frequencies."""
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # Bind p_fn and freq as default arguments so each lambda
                # captures its own frequency rather than the loop variable.
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

def get_embedder(posenc_res, input_dims=3):
    """Builds a positional encoding with posenc_res frequency octaves;
    returns (embed_fn, output_dim)."""
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': posenc_res-1,
        'num_freqs': posenc_res,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    def embed(x, eo=embedder_obj): return eo.embed(x)
    return embed, embedder_obj.out_dim
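
# Illustrative usage (a sketch, not part of the original module): with
# include_input=True, out_dim = input_dims * (1 + 2 * num_freqs), e.g.
# posenc_res=10 and input_dims=3 give 3 + 3*2*10 = 63 features per point:
#   embed_fn, out_dim = get_embedder(10)      # out_dim == 63
#   feats = embed_fn(torch.rand(1024, 3))     # feats.shape == [1024, 63]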

class LayerScale(nn.Module):
    """Learnable per-channel scaling of residual branches (as in timm's LayerScale)."""
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class Bottleneck_Linear(nn.Module):
    """Pre-norm residual MLP block: x + Linear(GELU(Linear(LayerNorm(x))))."""
    def __init__(self, n_channels):
        super().__init__()
        self.linear1 = nn.Linear(n_channels, n_channels)
        self.norm = nn.LayerNorm(n_channels)
        self.linear2 = nn.Linear(n_channels, n_channels)
        self.gelu = nn.GELU()
        
    def forward(self, x):
        x = x + self.linear2(self.gelu(self.linear1(self.norm(x))))
        return x

class Bottleneck_Conv(nn.Module):
    """Residual conv-BN-ReLU block. Accepts [B, C] vectors (treated as 1x1
    feature maps) or [B, C, H, W] feature maps; output shape matches input."""
    def __init__(self, n_channels, kernel_size=1):
        super().__init__()
        self.linear1 = nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.linear2 = nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
        self.bn2 = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        assert len(x.shape) in [2, 4]
        input_dims = len(x.shape)
        if input_dims == 2:
            x = x.unsqueeze(-1).unsqueeze(-1)
        residual = x
        out = self.linear1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        if input_dims == 2:
            out = out.squeeze(-1).squeeze(-1)
        return out
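
# Illustrative shapes (a sketch, not part of the original module):
#   block = Bottleneck_Conv(256)
#   block(torch.randn(8, 256)).shape         # -> [8, 256]
#   block(torch.randn(8, 256, 7, 7)).shape   # -> [8, 256, 7, 7]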

class CLIPFusionBlock_Concat(nn.Module):
    """
    Fuse CLIP and RGB embeddings by concatenating them channel-wise and
    projecting back to n_channels with a small MLP.
    """
    def __init__(self, n_channels=512, n_layers=1, act=True):
        super().__init__()
        proj = [Bottleneck_Linear(2 * n_channels) for _ in range(n_layers)]
        proj.append(nn.Linear(2 * n_channels, n_channels))
        if act: proj.append(nn.GELU())
        self.proj = nn.Sequential(*proj)
    
    def forward(self, sem_latent, clip_latent):
        """ 
        sem_latent: [B, N, C]
        clip_latent: [B, C]
        """
        # [B, N, 2C]
        latent_concat = torch.cat([sem_latent, clip_latent.unsqueeze(1).expand_as(sem_latent)], dim=-1)
        # [B, N, C]
        latent = self.proj(latent_concat)
        return latent

class CLIPFusionBlock_Attn(nn.Module):
    """
    Fuse CLIP and semantic embeddings via multi-layer MHA (transformer) blocks
    """
    def __init__(self, n_channels=512, n_layers=1, act=True):
        super().__init__()
        self.attn_blocks = nn.ModuleList(
            [Block(
                # timm ViT Block: dim, num_heads, mlp_ratio
                n_channels, 8, 4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1
            ) for _ in range(n_layers)]
        )
        # Optional GELU, applied after the attention blocks in forward().
        if act: self.attn_blocks.append(nn.GELU())
    
    def forward(self, sem_latent, clip_latent):
        """ 
        sem_latent: [B, N, C]
        clip_latent: [B, C]
        """
        # [B, 1+N, C], clip first
        latent = torch.cat([clip_latent.unsqueeze(1), sem_latent], dim=1)
        for attn_block in self.attn_blocks:
            latent = attn_block(latent)
        # [B, N, C]
        return latent[:, 1:, :]
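
if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): both fusion blocks consume per-token semantic features
    # [B, N, C] plus one CLIP embedding [B, C] and return [B, N, C].
    B, N, C = 2, 16, 512
    sem = torch.randn(B, N, C)
    clip_emb = torch.randn(B, C)

    concat_fuse = CLIPFusionBlock_Concat(n_channels=C, n_layers=1)
    attn_fuse = CLIPFusionBlock_Attn(n_channels=C, n_layers=1)

    assert concat_fuse(sem, clip_emb).shape == (B, N, C)
    assert attn_fuse(sem, clip_emb).shape == (B, N, C)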