iakarshu commited on
Commit
15dee1b
1 Parent(s): 2f4acb8

Upload modeling.py

Browse files

Added modeling file

Files changed (1) hide show
  1. modeling.py +545 -0
modeling.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.models as models
6
+ from einops import rearrange
7
+ from torch import Tensor
8
+
9
+ class PositionalEncoding(nn.Module):
10
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
11
+ super().__init__()
12
+ self.dropout = nn.Dropout(p=dropout)
13
+ self.max_len = max_len
14
+ self.d_model = d_model
15
+ position = torch.arange(max_len).unsqueeze(1)
16
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
17
+ pe = torch.zeros(1, max_len, d_model)
18
+ pe[0, :, 0::2] = torch.sin(position * div_term)
19
+ pe[0, :, 1::2] = torch.cos(position * div_term)
20
+ self.register_buffer("pe", pe)
21
+
22
+
23
+ def forward(self) -> Tensor:
24
+ x = self.pe[0, : self.max_len]
25
+ return self.dropout(x).unsqueeze(0)
26
+
27
+
28
+ class ResNetFeatureExtractor(nn.Module):
29
+ def __init__(self, hidden_dim = 512):
30
+ super().__init__()
31
+
32
+ # Making the resnet 50 model, which was used in the docformer for the purpose of visual feature extraction
33
+
34
+ resnet50 = models.resnet50(pretrained=False)
35
+ modules = list(resnet50.children())[:-2]
36
+ self.resnet50 = nn.Sequential(*modules)
37
+
38
+ # Applying convolution and linear layer
39
+
40
+ self.conv1 = nn.Conv2d(2048, 768, 1)
41
+ self.relu1 = F.relu
42
+ self.linear1 = nn.Linear(192, hidden_dim)
43
+
44
+ def forward(self, x):
45
+ x = self.resnet50(x)
46
+ x = self.conv1(x)
47
+ x = self.relu1(x)
48
+ x = rearrange(x, "b e w h -> b e (w h)") # b -> batch, e -> embedding dim, w -> width, h -> height
49
+ x = self.linear1(x)
50
+ x = rearrange(x, "b e s -> b s e") # b -> batch, e -> embedding dim, s -> sequence length
51
+ return x
52
+
53
+ class DocFormerEmbeddings(nn.Module):
54
+ """Construct the embeddings from word, position and token_type embeddings."""
55
+
56
+ def __init__(self, config):
57
+ super(DocFormerEmbeddings, self).__init__()
58
+
59
+ self.config = config
60
+
61
+ self.position_embeddings_v = PositionalEncoding(
62
+ d_model=config["hidden_size"],
63
+ dropout=0.1,
64
+ max_len=config["max_position_embeddings"],
65
+ )
66
+
67
+ self.x_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
68
+ self.x_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
69
+ self.w_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
70
+ self.x_topleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
71
+ self.x_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
72
+ self.x_topright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
73
+ self.x_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
74
+ self.x_centroid_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
75
+
76
+ self.y_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
77
+ self.y_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
78
+ self.h_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
79
+ self.y_topleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
80
+ self.y_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
81
+ self.y_topright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
82
+ self.y_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
83
+ self.y_centroid_distance_to_prev_embeddings_v = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
84
+
85
+ self.position_embeddings_t = PositionalEncoding(
86
+ d_model=config["hidden_size"],
87
+ dropout=0.1,
88
+ max_len=config["max_position_embeddings"],
89
+ )
90
+
91
+ self.x_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
92
+ self.x_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
93
+ self.w_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
94
+ self.x_topleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"]+1, config["shape_size"])
95
+ self.x_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"]+1, config["shape_size"])
96
+ self.x_topright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
97
+ self.x_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
98
+ self.x_centroid_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
99
+
100
+ self.y_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
101
+ self.y_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
102
+ self.h_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
103
+ self.y_topleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
104
+ self.y_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
105
+ self.y_topright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
106
+ self.y_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
107
+ self.y_centroid_distance_to_prev_embeddings_t = nn.Embedding(2*config["max_2d_position_embeddings"] + 1, config["shape_size"])
108
+
109
+ self.LayerNorm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
110
+ self.dropout = nn.Dropout(config["hidden_dropout_prob"])
111
+
112
+
113
+
114
+ def forward(self, x_feature, y_feature):
115
+
116
+ """
117
+ Arguments:
118
+ x_features of shape, (batch size, seq_len, 8)
119
+ y_features of shape, (batch size, seq_len, 8)
120
+ Outputs:
121
+ (V-bar-s, T-bar-s) of shape (batch size, 512,768),(batch size, 512,768)
122
+ What are the features:
123
+ 0 -> top left x/y
124
+ 1 -> bottom right x/y
125
+ 2 -> width/height
126
+ 3 -> diff top left x/y
127
+ 4 -> diff bottom left x/y
128
+ 5 -> diff top right x/y
129
+ 6 -> diff bottom right x/y
130
+ 7 -> centroids diff x/y
131
+ """
132
+
133
+
134
+ batch, seq_len = x_feature.shape[:-1]
135
+ hidden_size = self.config["hidden_size"]
136
+ num_feat = x_feature.shape[-1]
137
+ sub_dim = hidden_size // num_feat
138
+
139
+ # Clamping and adding a bias for handling negative values
140
+ x_feature[:,:,3:] = torch.clamp(x_feature[:,:,3:],-self.config["max_2d_position_embeddings"],self.config["max_2d_position_embeddings"])
141
+ x_feature[:,:,3:]+= self.config["max_2d_position_embeddings"]
142
+
143
+ y_feature[:,:,3:] = torch.clamp(y_feature[:,:,3:],-self.config["max_2d_position_embeddings"],self.config["max_2d_position_embeddings"])
144
+ y_feature[:,:,3:]+= self.config["max_2d_position_embeddings"]
145
+
146
+ x_topleft_position_embeddings_v = self.x_topleft_position_embeddings_v(x_feature[:,:,0])
147
+ x_bottomright_position_embeddings_v = self.x_bottomright_position_embeddings_v(x_feature[:,:,1])
148
+ w_position_embeddings_v = self.w_position_embeddings_v(x_feature[:,:,2])
149
+ x_topleft_distance_to_prev_embeddings_v = self.x_topleft_distance_to_prev_embeddings_v(x_feature[:,:,3])
150
+ x_bottomleft_distance_to_prev_embeddings_v = self.x_bottomleft_distance_to_prev_embeddings_v(x_feature[:,:,4])
151
+ x_topright_distance_to_prev_embeddings_v = self.x_topright_distance_to_prev_embeddings_v(x_feature[:,:,5])
152
+ x_bottomright_distance_to_prev_embeddings_v = self.x_bottomright_distance_to_prev_embeddings_v(x_feature[:,:,6])
153
+ x_centroid_distance_to_prev_embeddings_v = self.x_centroid_distance_to_prev_embeddings_v(x_feature[:,:,7])
154
+
155
+ x_calculated_embedding_v = torch.cat(
156
+ [
157
+ x_topleft_position_embeddings_v,
158
+ x_bottomright_position_embeddings_v,
159
+ w_position_embeddings_v,
160
+ x_topleft_distance_to_prev_embeddings_v,
161
+ x_bottomleft_distance_to_prev_embeddings_v,
162
+ x_topright_distance_to_prev_embeddings_v,
163
+ x_bottomright_distance_to_prev_embeddings_v ,
164
+ x_centroid_distance_to_prev_embeddings_v
165
+ ],
166
+ dim = -1
167
+ )
168
+
169
+ y_topleft_position_embeddings_v = self.y_topleft_position_embeddings_v(y_feature[:,:,0])
170
+ y_bottomright_position_embeddings_v = self.y_bottomright_position_embeddings_v(y_feature[:,:,1])
171
+ h_position_embeddings_v = self.h_position_embeddings_v(y_feature[:,:,2])
172
+ y_topleft_distance_to_prev_embeddings_v = self.y_topleft_distance_to_prev_embeddings_v(y_feature[:,:,3])
173
+ y_bottomleft_distance_to_prev_embeddings_v = self.y_bottomleft_distance_to_prev_embeddings_v(y_feature[:,:,4])
174
+ y_topright_distance_to_prev_embeddings_v = self.y_topright_distance_to_prev_embeddings_v(y_feature[:,:,5])
175
+ y_bottomright_distance_to_prev_embeddings_v = self.y_bottomright_distance_to_prev_embeddings_v(y_feature[:,:,6])
176
+ y_centroid_distance_to_prev_embeddings_v = self.y_centroid_distance_to_prev_embeddings_v(y_feature[:,:,7])
177
+
178
+ x_calculated_embedding_v = torch.cat(
179
+ [
180
+ x_topleft_position_embeddings_v,
181
+ x_bottomright_position_embeddings_v,
182
+ w_position_embeddings_v,
183
+ x_topleft_distance_to_prev_embeddings_v,
184
+ x_bottomleft_distance_to_prev_embeddings_v,
185
+ x_topright_distance_to_prev_embeddings_v,
186
+ x_bottomright_distance_to_prev_embeddings_v ,
187
+ x_centroid_distance_to_prev_embeddings_v
188
+ ],
189
+ dim = -1
190
+ )
191
+
192
+ y_calculated_embedding_v = torch.cat(
193
+ [
194
+ y_topleft_position_embeddings_v,
195
+ y_bottomright_position_embeddings_v,
196
+ h_position_embeddings_v,
197
+ y_topleft_distance_to_prev_embeddings_v,
198
+ y_bottomleft_distance_to_prev_embeddings_v,
199
+ y_topright_distance_to_prev_embeddings_v,
200
+ y_bottomright_distance_to_prev_embeddings_v ,
201
+ y_centroid_distance_to_prev_embeddings_v
202
+ ],
203
+ dim = -1
204
+ )
205
+
206
+ v_bar_s = x_calculated_embedding_v + y_calculated_embedding_v + self.position_embeddings_v()
207
+
208
+
209
+
210
+ x_topleft_position_embeddings_t = self.x_topleft_position_embeddings_t(x_feature[:,:,0])
211
+ x_bottomright_position_embeddings_t = self.x_bottomright_position_embeddings_t(x_feature[:,:,1])
212
+ w_position_embeddings_t = self.w_position_embeddings_t(x_feature[:,:,2])
213
+ x_topleft_distance_to_prev_embeddings_t = self.x_topleft_distance_to_prev_embeddings_t(x_feature[:,:,3])
214
+ x_bottomleft_distance_to_prev_embeddings_t = self.x_bottomleft_distance_to_prev_embeddings_t(x_feature[:,:,4])
215
+ x_topright_distance_to_prev_embeddings_t = self.x_topright_distance_to_prev_embeddings_t(x_feature[:,:,5])
216
+ x_bottomright_distance_to_prev_embeddings_t = self.x_bottomright_distance_to_prev_embeddings_t(x_feature[:,:,6])
217
+ x_centroid_distance_to_prev_embeddings_t = self.x_centroid_distance_to_prev_embeddings_t(x_feature[:,:,7])
218
+
219
+ x_calculated_embedding_t = torch.cat(
220
+ [
221
+ x_topleft_position_embeddings_t,
222
+ x_bottomright_position_embeddings_t,
223
+ w_position_embeddings_t,
224
+ x_topleft_distance_to_prev_embeddings_t,
225
+ x_bottomleft_distance_to_prev_embeddings_t,
226
+ x_topright_distance_to_prev_embeddings_t,
227
+ x_bottomright_distance_to_prev_embeddings_t ,
228
+ x_centroid_distance_to_prev_embeddings_t
229
+ ],
230
+ dim = -1
231
+ )
232
+
233
+ y_topleft_position_embeddings_t = self.y_topleft_position_embeddings_t(y_feature[:,:,0])
234
+ y_bottomright_position_embeddings_t = self.y_bottomright_position_embeddings_t(y_feature[:,:,1])
235
+ h_position_embeddings_t = self.h_position_embeddings_t(y_feature[:,:,2])
236
+ y_topleft_distance_to_prev_embeddings_t = self.y_topleft_distance_to_prev_embeddings_t(y_feature[:,:,3])
237
+ y_bottomleft_distance_to_prev_embeddings_t = self.y_bottomleft_distance_to_prev_embeddings_t(y_feature[:,:,4])
238
+ y_topright_distance_to_prev_embeddings_t = self.y_topright_distance_to_prev_embeddings_t(y_feature[:,:,5])
239
+ y_bottomright_distance_to_prev_embeddings_t = self.y_bottomright_distance_to_prev_embeddings_t(y_feature[:,:,6])
240
+ y_centroid_distance_to_prev_embeddings_t = self.y_centroid_distance_to_prev_embeddings_t(y_feature[:,:,7])
241
+
242
+ x_calculated_embedding_t = torch.cat(
243
+ [
244
+ x_topleft_position_embeddings_t,
245
+ x_bottomright_position_embeddings_t,
246
+ w_position_embeddings_t,
247
+ x_topleft_distance_to_prev_embeddings_t,
248
+ x_bottomleft_distance_to_prev_embeddings_t,
249
+ x_topright_distance_to_prev_embeddings_t,
250
+ x_bottomright_distance_to_prev_embeddings_t ,
251
+ x_centroid_distance_to_prev_embeddings_t
252
+ ],
253
+ dim = -1
254
+ )
255
+
256
+ y_calculated_embedding_t = torch.cat(
257
+ [
258
+ y_topleft_position_embeddings_t,
259
+ y_bottomright_position_embeddings_t,
260
+ h_position_embeddings_t,
261
+ y_topleft_distance_to_prev_embeddings_t,
262
+ y_bottomleft_distance_to_prev_embeddings_t,
263
+ y_topright_distance_to_prev_embeddings_t,
264
+ y_bottomright_distance_to_prev_embeddings_t ,
265
+ y_centroid_distance_to_prev_embeddings_t
266
+ ],
267
+ dim = -1
268
+ )
269
+
270
+ t_bar_s = x_calculated_embedding_t + y_calculated_embedding_t + self.position_embeddings_t()
271
+
272
+ return v_bar_s, t_bar_s
273
+
274
+
275
+
276
+ # fmt: off
277
+ class PreNorm(nn.Module):
278
+ def __init__(self, dim, fn):
279
+ # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
280
+ super().__init__()
281
+ self.norm = nn.LayerNorm(dim)
282
+ self.fn = fn
283
+
284
+ def forward(self, x, **kwargs):
285
+ return self.fn(self.norm(x), **kwargs)
286
+
287
+
288
+ class PreNormAttn(nn.Module):
289
+ def __init__(self, dim, fn):
290
+ # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
291
+ super().__init__()
292
+
293
+ self.norm_t_bar = nn.LayerNorm(dim)
294
+ self.norm_v_bar = nn.LayerNorm(dim)
295
+ self.norm_t_bar_s = nn.LayerNorm(dim)
296
+ self.norm_v_bar_s = nn.LayerNorm(dim)
297
+ self.fn = fn
298
+
299
+ def forward(self, t_bar, v_bar, t_bar_s, v_bar_s, **kwargs):
300
+ return self.fn(self.norm_t_bar(t_bar),
301
+ self.norm_v_bar(v_bar),
302
+ self.norm_t_bar_s(t_bar_s),
303
+ self.norm_v_bar_s(v_bar_s), **kwargs)
304
+
305
+
306
+ class FeedForward(nn.Module):
307
+ def __init__(self, dim, hidden_dim, dropout=0.):
308
+ super().__init__()
309
+ self.net = nn.Sequential(
310
+ nn.Linear(dim, hidden_dim),
311
+ nn.GELU(),
312
+ nn.Dropout(dropout),
313
+ nn.Linear(hidden_dim, dim),
314
+ nn.Dropout(dropout)
315
+ )
316
+
317
+ def forward(self, x):
318
+ return self.net(x)
319
+
320
+
321
+ class RelativePosition(nn.Module):
322
+
323
+ def __init__(self, num_units, max_relative_position, max_seq_length):
324
+ super().__init__()
325
+ self.num_units = num_units
326
+ self.max_relative_position = max_relative_position
327
+ self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
328
+ self.max_length = max_seq_length
329
+ range_vec_q = torch.arange(max_seq_length)
330
+ range_vec_k = torch.arange(max_seq_length)
331
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
332
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
333
+ final_mat = distance_mat_clipped + self.max_relative_position
334
+ self.final_mat = torch.LongTensor(final_mat)
335
+ nn.init.xavier_uniform_(self.embeddings_table)
336
+
337
+ def forward(self, length_q, length_k):
338
+ embeddings = self.embeddings_table[self.final_mat[:length_q, :length_k]]
339
+ return embeddings
340
+
341
+
342
+ class MultiModalAttentionLayer(nn.Module):
343
+ def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout):
344
+ super().__init__()
345
+ assert embed_dim % n_heads == 0
346
+
347
+ self.embed_dim = embed_dim
348
+ self.n_heads = n_heads
349
+ self.head_dim = embed_dim // n_heads
350
+
351
+ self.relative_positions_text = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
352
+ self.relative_positions_img = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
353
+
354
+ # text qkv embeddings
355
+ self.fc_k_text = nn.Linear(embed_dim, embed_dim)
356
+ self.fc_q_text = nn.Linear(embed_dim, embed_dim)
357
+ self.fc_v_text = nn.Linear(embed_dim, embed_dim)
358
+
359
+ # image qkv embeddings
360
+ self.fc_k_img = nn.Linear(embed_dim, embed_dim)
361
+ self.fc_q_img = nn.Linear(embed_dim, embed_dim)
362
+ self.fc_v_img = nn.Linear(embed_dim, embed_dim)
363
+
364
+ # spatial qk embeddings (shared for visual and text)
365
+ self.fc_k_spatial = nn.Linear(embed_dim, embed_dim)
366
+ self.fc_q_spatial = nn.Linear(embed_dim, embed_dim)
367
+
368
+ self.dropout = nn.Dropout(dropout)
369
+
370
+ self.to_out = nn.Sequential(
371
+ nn.Linear(embed_dim, embed_dim),
372
+ nn.Dropout(dropout)
373
+ )
374
+ self.scale = embed_dim**0.5
375
+
376
+ def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
377
+ text_feat = text_feat
378
+ img_feat = img_feat
379
+ text_spatial_feat = text_spatial_feat
380
+ img_spatial_feat = img_spatial_feat
381
+ seq_length = text_feat.shape[1]
382
+
383
+ # self attention of text
384
+ # b -> batch, t -> time steps (l -> length has same meaning), head -> # of heads, k -> head dim.
385
+ key_text_nh = rearrange(self.fc_k_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
386
+ query_text_nh = rearrange(self.fc_q_text(text_feat), 'b l (head k) -> head b l k', head=self.n_heads)
387
+ value_text_nh = rearrange(self.fc_v_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
388
+ dots_text = torch.einsum('hblk,hbtk->hblt', query_text_nh, key_text_nh)
389
+ dots_text = dots_text/ self.scale
390
+
391
+ # 1D relative positions (query, key)
392
+ rel_pos_embed_text = self.relative_positions_text(seq_length, seq_length)
393
+ rel_pos_key_text = torch.einsum('bhrd,lrd->bhlr', key_text_nh, rel_pos_embed_text)
394
+ rel_pos_query_text = torch.einsum('bhld,lrd->bhlr', query_text_nh, rel_pos_embed_text)
395
+
396
+ # shared spatial <-> text hidden features
397
+ key_spatial_text = self.fc_k_spatial(text_spatial_feat)
398
+ query_spatial_text = self.fc_q_spatial(text_spatial_feat)
399
+ key_spatial_text_nh = rearrange(key_spatial_text, 'b t (head k) -> head b t k', head=self.n_heads)
400
+ query_spatial_text_nh = rearrange(query_spatial_text, 'b l (head k) -> head b l k', head=self.n_heads)
401
+ dots_text_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_text_nh, key_spatial_text_nh)
402
+ dots_text_spatial = dots_text_spatial/ self.scale
403
+
404
+ # Line 38 of pseudo-code
405
+ text_attn_scores = dots_text + rel_pos_key_text + rel_pos_query_text + dots_text_spatial
406
+
407
+ # self-attention of image
408
+ key_img_nh = rearrange(self.fc_k_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
409
+ query_img_nh = rearrange(self.fc_q_img(img_feat), 'b l (head k) -> head b l k', head=self.n_heads)
410
+ value_img_nh = rearrange(self.fc_v_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
411
+ dots_img = torch.einsum('hblk,hbtk->hblt', query_img_nh, key_img_nh)
412
+ dots_img = dots_img/ self.scale
413
+
414
+ # 1D relative positions (query, key)
415
+ rel_pos_embed_img = self.relative_positions_img(seq_length, seq_length)
416
+ rel_pos_key_img = torch.einsum('bhrd,lrd->bhlr', key_img_nh, rel_pos_embed_text)
417
+ rel_pos_query_img = torch.einsum('bhld,lrd->bhlr', query_img_nh, rel_pos_embed_text)
418
+
419
+ # shared spatial <-> image features
420
+ key_spatial_img = self.fc_k_spatial(img_spatial_feat)
421
+ query_spatial_img = self.fc_q_spatial(img_spatial_feat)
422
+ key_spatial_img_nh = rearrange(key_spatial_img, 'b t (head k) -> head b t k', head=self.n_heads)
423
+ query_spatial_img_nh = rearrange(query_spatial_img, 'b l (head k) -> head b l k', head=self.n_heads)
424
+ dots_img_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_img_nh, key_spatial_img_nh)
425
+ dots_img_spatial = dots_img_spatial/ self.scale
426
+
427
+ # Line 59 of pseudo-code
428
+ img_attn_scores = dots_img + rel_pos_key_img + rel_pos_query_img + dots_img_spatial
429
+
430
+ text_attn_probs = self.dropout(torch.softmax(text_attn_scores, dim=-1))
431
+ img_attn_probs = self.dropout(torch.softmax(img_attn_scores, dim=-1))
432
+
433
+ text_context = torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_text_nh)
434
+ img_context = torch.einsum('hblt,hbtv->hblv', img_attn_probs, value_img_nh)
435
+
436
+ context = text_context + img_context
437
+
438
+ embeddings = rearrange(context, 'head b t d -> b t (head d)')
439
+ return self.to_out(embeddings)
440
+
441
+ class DocFormerEncoder(nn.Module):
442
+ def __init__(self, config):
443
+ super().__init__()
444
+ self.config = config
445
+ self.layers = nn.ModuleList([])
446
+ for _ in range(config['num_hidden_layers']):
447
+ encoder_block = nn.ModuleList([
448
+ PreNormAttn(config['hidden_size'],
449
+ MultiModalAttentionLayer(config['hidden_size'],
450
+ config['num_attention_heads'],
451
+ config['max_relative_positions'],
452
+ config['max_position_embeddings'],
453
+ config['hidden_dropout_prob'],
454
+ )
455
+ ),
456
+ PreNorm(config['hidden_size'],
457
+ FeedForward(config['hidden_size'],
458
+ config['hidden_size'] * config['intermediate_ff_size_factor'],
459
+ dropout=config['hidden_dropout_prob']))
460
+ ])
461
+ self.layers.append(encoder_block)
462
+
463
+ def forward(
464
+ self,
465
+ text_feat, # text feat or output from last encoder block
466
+ img_feat,
467
+ text_spatial_feat,
468
+ img_spatial_feat,
469
+ ):
470
+ # Fig 1 encoder part (skip conn for both attn & FF): https://arxiv.org/abs/1706.03762
471
+ # TODO: ensure 1st skip conn (var "skip") in such a multimodal setting makes sense (most likely does)
472
+ for attn, ff in self.layers:
473
+ skip = text_feat + img_feat + text_spatial_feat + img_spatial_feat
474
+ x = attn(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + skip
475
+ x = ff(x) + x
476
+ text_feat = x
477
+ return x
478
+
479
+
480
+ class LanguageFeatureExtractor(nn.Module):
481
+ def __init__(self):
482
+ super().__init__()
483
+ from transformers import LayoutLMForTokenClassification
484
+ layoutlm_dummy = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=1)
485
+ self.embedding_vector = nn.Embedding.from_pretrained(layoutlm_dummy.layoutlm.embeddings.word_embeddings.weight)
486
+
487
+ def forward(self, x):
488
+ return self.embedding_vector(x)
489
+
490
+
491
+
492
+ class ExtractFeatures(nn.Module):
493
+
494
+ '''
495
+ Inputs: dictionary
496
+ Output: v_bar, t_bar, v_bar_s, t_bar_s
497
+ '''
498
+
499
+ def __init__(self, config):
500
+ super().__init__()
501
+ self.visual_feature = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])
502
+ self.language_feature = LanguageFeatureExtractor()
503
+ self.spatial_feature = DocFormerEmbeddings(config)
504
+
505
+ def forward(self, encoding):
506
+
507
+ image = encoding['resized_scaled_img']
508
+
509
+ language = encoding['input_ids']
510
+ x_feature = encoding['x_features']
511
+ y_feature = encoding['y_features']
512
+
513
+ v_bar = self.visual_feature(image)
514
+ t_bar = self.language_feature(language)
515
+
516
+ v_bar_s, t_bar_s = self.spatial_feature(x_feature, y_feature)
517
+
518
+ return v_bar, t_bar, v_bar_s, t_bar_s
519
+
520
+
521
+
522
+ class DocFormer(nn.Module):
523
+
524
+ '''
525
+ Easy boiler plate, because this model will just take as an input, the dictionary which is obtained from create_features function
526
+ '''
527
+ def __init__(self, config):
528
+ super().__init__()
529
+ self.config = config
530
+ self.extract_feature = ExtractFeatures(config)
531
+ self.encoder = DocFormerEncoder(config)
532
+ self.dropout = nn.Dropout(config['hidden_dropout_prob'])
533
+
534
+ def forward(self, x ,use_tdi=False):
535
+ v_bar, t_bar, v_bar_s, t_bar_s = self.extract_feature(x,use_tdi)
536
+ features = {'v_bar': v_bar, 't_bar': t_bar, 'v_bar_s': v_bar_s, 't_bar_s': t_bar_s}
537
+ output = self.encoder(features['t_bar'], features['v_bar'], features['t_bar_s'], features['v_bar_s'])
538
+ output = self.dropout(output)
539
+ return output
540
+
541
+
542
+
543
+
544
+
545
+