Commit 649ea66 (parent: 45387f9), committed by qianyuchen

Create resampler.py

Files changed (1): resampler.py

from functools import partial
import numpy as np

import torch
from torch import nn
from torch.nn.init import trunc_normal_


def get_2d_sincos_pos_embed(embed_dim, image_size):
    """
    image_size: int or (image_height, image_width)
    return:
    pos_embed: [image_height, image_width, embed_dim]
    """
    if isinstance(image_size, int):
        grid_h_size, grid_w_size = image_size, image_size
    else:
        grid_h_size, grid_w_size = image_size[0], image_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0])  # (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1])  # (H, W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a grid of positions to be encoded, of shape (H, W)
    out: (H, W, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    out = np.einsum('hw,d->hwd', pos, omega)  # (H, W, D/2), outer product

    emb_sin = np.sin(out)  # (H, W, D/2)
    emb_cos = np.cos(out)  # (H, W, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb
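

# Shape sketch (added for illustration; the numbers are assumed example values, not
# taken from the original commit): get_2d_sincos_pos_embed(64, (3, 5)) returns a
# float32 array of shape (3, 5, 64), where half of the channels encode one spatial
# axis and half the other, each as sin/cos features at geometrically spaced
# frequencies with base 10000.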


class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with a single cross-attention layer,
    using learnable queries and a 2D sincos position embedding.
    Outputs:
        A tensor with the shape of (batch_size, num_queries, embed_dim)
    """

    def __init__(
            self,
            num_queries,
            embed_dim,
            num_heads,
            kv_dim=None,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            adaptive=False,
            max_size=(70, 70),
    ):
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.adaptive = adaptive
        self.max_size = max_size

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))

        self._set_2d_pos_cache(self.max_size)
        self.apply(self._init_weights)

    def _set_2d_pos_cache(self, max_size, device='cpu'):
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes, device):
        # Grow the cached 2D position embedding if any target grid exceeds the current max_size.
        max_h = torch.max(tgt_sizes[:, 0])
        max_w = torch.max(tgt_sizes[:, 1])
        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
            self._set_2d_pos_cache(self.max_size, device)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, tgt_sizes=None):
        # x: (batch_size, max_patch_len, kv_dim) padded vision features
        # tgt_sizes: (batch_size, 2) integer (height, width) patch grid of each sample
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = torch.max(patch_len)
        key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i]
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            key_padding_mask[i, patch_len[i]:] = True

        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2)  # B * L * D => L * B * D

        x = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask)[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)
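

# Hedged usage sketch (added for illustration, not part of the original commit).
# All sizes below (embed_dim=512, num_heads=8, kv_dim=768, num_queries=64, and the
# two patch grids) are assumed example values, not configuration from this repo.
if __name__ == "__main__":
    resampler = Resampler(
        num_queries=64,
        embed_dim=512,
        num_heads=8,
        kv_dim=768,
    )

    # Two images whose vision encoders produced 24x24 and 16x32 patch grids; the
    # feature tensor is padded to the longest sequence, matching how forward()
    # builds its key_padding_mask from tgt_sizes.
    tgt_sizes = torch.tensor([[24, 24], [16, 32]], dtype=torch.int32)
    max_len = int((tgt_sizes[:, 0] * tgt_sizes[:, 1]).max())
    x = torch.randn(2, max_len, 768)

    out = resampler(x, tgt_sizes=tgt_sizes)
    print(out.shape)  # expected: torch.Size([2, 64, 512])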