File size: 2,853 Bytes
54a7220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch import nn
from einops import rearrange


class FourierEmbedder:
    """Expand a tensor into sine/cosine features at geometrically spaced frequencies.

    Band ``i`` has frequency ``temperature ** (i / num_freqs)`` for
    ``i in range(num_freqs)``, so the first band is always 1 and the
    frequencies grow towards (but never reach) ``temperature``.

    Args:
        num_freqs: number of frequency bands.
        temperature: base of the geometric frequency progression.
    """

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        # Fractions i / num_freqs in [0, 1) used as exponents of the temperature.
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Return the Fourier features of ``x`` concatenated along ``cat_dim``.

        For every band, ``sin(freq * x)`` is followed by ``cos(freq * x)``, so
        the output is ``2 * num_freqs`` times larger than ``x`` along
        ``cat_dim``. The computation runs under ``torch.no_grad``, so the
        result does not carry gradients back to ``x``.

        Args:
            x: tensor of arbitrary shape.
            cat_dim: dimension along which the per-band features are joined.
        """
        features = [
            trig(band * x)
            for band in self.freq_bands
            for trig in (torch.sin, torch.cos)
        ]
        return torch.cat(features, cat_dim)


class FourierColorEmbedder(nn.Module):
    """Turn a feature vector into a short sequence of conditioning tokens.

    The input (a color descriptor, judging by the class name) is multiplied by
    ``scale``, expanded with sine/cosine Fourier features, blended with a
    learned "null" feature vector according to ``mask``, and projected by an
    MLP into ``num_tokens`` tokens of width ``out_dim``.

    Args:
        in_dim: size of the last dimension of the raw input.
        out_dim: width of each output token.
        num_tokens: number of tokens produced per sample.
        fourier_freqs: number of frequency bands of the Fourier expansion.
        temperature: base of the geometric frequency progression.
        scale: factor applied to the input before the Fourier expansion.

    Note:
        After construction ``self.in_dim`` holds the *expanded* feature size,
        ``in_dim * fourier_freqs * 2``, not the raw input size.
    """

    def __init__(self, in_dim=180, out_dim=768, num_tokens=4, fourier_freqs=4, temperature=100, scale=100):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.fourier_freqs = fourier_freqs
        self.num_tokens = num_tokens

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs, temperature=temperature)
        # Each input feature yields one sin and one cos value per frequency band.
        self.in_dim *= (fourier_freqs * 2)
        self.mlp = nn.Sequential(
            nn.Linear(self.in_dim, 512),
            nn.LayerNorm(512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.SiLU(),
            nn.Linear(512, out_dim*self.num_tokens),
        )

        # Learned replacement for the Fourier features of masked-out samples.
        self.null_features = torch.nn.Parameter(torch.zeros([self.in_dim]))
        self.scale = scale

    def forward(self, x, mask=None):
        """Embed ``x`` into ``(batch, num_tokens, out_dim)`` tokens.

        Args:
            x: tensor of shape ``(batch, in_dim)``, or ``(batch, 1, in_dim)``
                in which case the singleton dimension is squeezed away.
            mask: optional tensor broadcastable against the ``(batch, features)``
                Fourier features; where it is 1 the real features are kept,
                where it is 0 the learned null features are used instead.
                Defaults to all ones (keep everything).

        Returns:
            Tensor of shape ``(batch, num_tokens, out_dim)``.

        Raises:
            AssertionError: if ``x`` is 3-D and its second dimension is not 1.
        """
        if x.ndim == 3:
            assert x.size(1) == 1
            x = x.squeeze(1)
        bs = x.shape[0]
        x = self.fourier_embedder(x * self.scale)
        # Keep the mask in the dtype of the embedded features: a float32 mask
        # would silently promote half/bfloat16 features to float32 and then
        # clash with reduced-precision MLP weights.
        if mask is None:
            mask = torch.ones(bs, 1, device=x.device, dtype=x.dtype)
        else:
            mask = mask.to(dtype=x.dtype)
        x = mask * x + (1 - mask) * self.null_features.view(1, -1)
        x = self.mlp(x).view(bs, self.num_tokens, self.out_dim)  # B * num_tokens * out_dim
        return x


class GlobalAdapter(nn.Module):
    """Produce cross-attention context tokens from a global embedding and a color input.

    A single linear layer maps the global embedding (presumably a CLIP
    embedding, going by the parameter name) to ``context_tokens`` tokens of
    width ``cross_attention_dim``, which are then layer-normalised. The color
    input is embedded separately by a ``FourierColorEmbedder`` whose token
    width is also ``cross_attention_dim``.

    Args:
        cross_attention_dim: width of every produced token.
        clip_embeddings_dim: size of the last dimension of the global embedding.
        context_tokens: number of tokens generated from the global embedding.
        color_in_dim: raw input size of the color embedder.
        color_num_tokens: number of tokens generated from the color input.
        color_fourier_freqs: frequency bands used by the color embedder.
        color_temperature: frequency base used by the color embedder.
        color_scale: input scale used by the color embedder.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, context_tokens=4,
            color_in_dim=180, color_num_tokens=4, color_fourier_freqs=4, color_temperature=100, color_scale=100):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.context_tokens = context_tokens

        # One projection emits all context tokens at once; forward() splits them.
        projected_width = self.context_tokens * cross_attention_dim
        self.proj = nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = nn.LayerNorm(cross_attention_dim)
        self.color_embed = FourierColorEmbedder(
            in_dim=color_in_dim,
            out_dim=cross_attention_dim,
            num_tokens=color_num_tokens,
            fourier_freqs=color_fourier_freqs,
            temperature=color_temperature,
            scale=color_scale,
        )

    def forward(self, x, x_color, *args, **kwargs):
        """Return ``(context_tokens, color_tokens)``.

        Args:
            x: global embedding whose last dimension is ``clip_embeddings_dim``.
            x_color: input forwarded unchanged to the color embedder.
            *args, **kwargs: accepted and ignored.

        Returns:
            A tuple of the normalised context tokens, shaped
            ``(-1, context_tokens, cross_attention_dim)``, and the tokens
            returned by the color embedder.
        """
        token_shape = (-1, self.context_tokens, self.cross_attention_dim)
        global_tokens = self.norm(self.proj(x).reshape(token_shape))
        palette_tokens = self.color_embed(x_color)
        return global_tokens, palette_tokens