euiyulsong committed
Commit
eac156b
1 Parent(s): 7368bdd

Create fid.py

Files changed (1)
  1. fid.py +357 -0
fid.py ADDED
@@ -0,0 +1,357 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import types
import torch
import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint.checkpoint is always available
import transformers
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
import numpy as np

class FiDT5(transformers.T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.wrap_encoder()

    def forward_(self, **kwargs):
        if 'input_ids' in kwargs:
            kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1)
        if 'attention_mask' in kwargs:
            kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1)

        return super(FiDT5, self).forward(
            **kwargs
        )

    # We need to resize as B x (N * L) instead of (B * N) x L here
    # because the T5 forward method uses the input tensors to infer
    # dimensions used in the decoder.
    # EncoderWrapper resizes the inputs as (B * N) x L.
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        if input_ids is not None:
            # inputs might already have been resized in the generate method
            if input_ids.dim() == 3:
                self.encoder.n_passages = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    # We need to resize the inputs here, as the generate method expects 2D tensors
    def generate(self, input_ids, attention_mask, max_length):
        self.encoder.n_passages = input_ids.size(1)
        return super().generate(
            input_ids=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            max_length=max_length
        )
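    # Usage sketch (illustrative note, not part of the original file): a FiD batch
    # is shaped (batch_size, n_passages, passage_length), e.g. each row obtained by
    # tokenizing "question: ... context: <passage>" strings and stacking them.
    # generate() records n_passages and flattens to 2D before delegating to T5:
    #   output_ids = model.generate(input_ids, attention_mask, max_length=50)
    #   answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # where `model` and `tokenizer` are assumed to be an FiDT5 instance and its T5 tokenizer.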

    def wrap_encoder(self, use_checkpoint=False):
        """
        Wrap T5 encoder to obtain a Fusion-in-Decoder model.
        """
        self.encoder = EncoderWrapper(self.encoder, use_checkpoint=use_checkpoint)

    def unwrap_encoder(self):
        """
        Unwrap Fusion-in-Decoder encoder, useful to load T5 weights.
        """
        self.encoder = self.encoder.encoder
        block = []
        for mod in self.encoder.block:
            block.append(mod.module)
        block = nn.ModuleList(block)
        self.encoder.block = block

    def load_t5(self, state_dict):
        self.unwrap_encoder()
        self.load_state_dict(state_dict)
        self.wrap_encoder()

    def set_checkpoint(self, use_checkpoint):
        """
        Enable or disable checkpointing in the encoder.
        See https://pytorch.org/docs/stable/checkpoint.html
        """
        for mod in self.encoder.encoder.block:
            mod.use_checkpoint = use_checkpoint

    def reset_score_storage(self):
        """
        Reset score storage, only used when cross-attention scores are saved
        to train a retriever.
        """
        for mod in self.decoder.block:
            mod.layer[1].EncDecAttention.score_storage = None

    def get_crossattention_scores(self, context_mask):
        """
        Cross-attention scores are aggregated to obtain a single scalar per
        passage. This scalar can be seen as a similarity score between the
        question and the input passage. It is obtained by averaging the
        cross-attention scores obtained on the first decoded token over heads,
        layers, and tokens of the input passage.
        More details in Distilling Knowledge from Reader to Retriever:
        https://arxiv.org/abs/2012.04584.
        """
        scores = []
        n_passages = context_mask.size(1)
        for mod in self.decoder.block:
            scores.append(mod.layer[1].EncDecAttention.score_storage)
        scores = torch.cat(scores, dim=2)
        bsz, n_heads, n_layers, _ = scores.size()
        # batch_size, n_head, n_layers, n_passages, text_maxlength
        scores = scores.view(bsz, n_heads, n_layers, n_passages, -1)
        scores = scores.masked_fill(~context_mask[:, None, None], 0.)
        scores = scores.sum(dim=[1, 2, 4])
        ntokens = context_mask.sum(dim=[2]) * n_layers * n_heads
        scores = scores / ntokens
        return scores

    def overwrite_forward_crossattention(self):
        """
        Replace cross-attention forward function, only used to save
        cross-attention scores.
        """
        for mod in self.decoder.block:
            attn = mod.layer[1].EncDecAttention
            attn.forward = types.MethodType(cross_attention_forward, attn)
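
# Illustrative sketch (not part of the original file): the intended sequence for
# collecting per-passage scores with the three helpers above, mirroring how a
# reader can supervise a retriever. `model` is assumed to be an FiDT5 instance;
# `input_ids` and `attention_mask` are (batch_size, n_passages, passage_length)
# tensors, and `context_mask` is the boolean mask over passage tokens of the same shape.
def _example_crossattention_scores(model, input_ids, attention_mask, context_mask, max_length=50):
    model.overwrite_forward_crossattention()  # patch EncDecAttention.forward to record scores
    model.reset_score_storage()               # clear scores left over from a previous batch
    model.generate(input_ids, attention_mask, max_length)  # stores scores of the first decoded token
    return model.get_crossattention_scores(context_mask)   # (batch_size, n_passages)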

class EncoderWrapper(torch.nn.Module):
    """
    Encoder wrapper for the T5 encoder, used to obtain a Fusion-in-Decoder model.
    """
    def __init__(self, encoder, use_checkpoint=False):
        super().__init__()

        self.encoder = encoder
        apply_checkpoint_wrapper(self.encoder, use_checkpoint)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # total_length = n_passages * passage_length
        bsz, total_length = input_ids.shape
        passage_length = total_length // self.n_passages
        input_ids = input_ids.view(bsz * self.n_passages, passage_length)
        attention_mask = attention_mask.view(bsz * self.n_passages, passage_length)
        outputs = self.encoder(input_ids, attention_mask, **kwargs)
        outputs = (outputs[0].view(bsz, self.n_passages * passage_length, -1),) + outputs[1:]
        return outputs
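
    # Shape walk-through (illustrative note, not part of the original file), assuming
    # batch_size=2, n_passages=10, passage_length=200, hidden_size=768:
    #   forward() receives input_ids of shape (2, 10 * 200) = (2, 2000),
    #   reshapes them to (2 * 10, 200) so every passage is encoded independently,
    #   then concatenates the encoder states back to (2, 2000, 768) so the decoder
    #   attends jointly over all passages (the "fusion" in Fusion-in-Decoder).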

class CheckpointWrapper(torch.nn.Module):
    """
    Wrapper replacing None outputs by empty tensors, which allows the use of
    checkpointing.
    """
    def __init__(self, module, use_checkpoint=False):
        super().__init__()
        self.module = module
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
        if self.use_checkpoint and self.training:
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            def custom_forward(*inputs):
                output = self.module(*inputs, **kwargs)
                # torch.utils.checkpoint cannot handle None outputs, so replace
                # them with empty tensors that still participate in autograd.
                empty = torch.tensor(
                    [],
                    dtype=torch.float,
                    device=output[0].device,
                    requires_grad=True)
                output = tuple(x if x is not None else empty for x in output)
                return output

            output = torch.utils.checkpoint.checkpoint(
                custom_forward,
                hidden_states,
                attention_mask,
                position_bias
            )
            # Restore None for the empty placeholder tensors created above.
            output = tuple(x if x.numel() != 0 else None for x in output)
        else:
            output = self.module(hidden_states, attention_mask, position_bias, **kwargs)
        return output

def apply_checkpoint_wrapper(t5stack, use_checkpoint):
    """
    Wrap each block of the encoder to enable checkpointing.
    """
    block = []
    for mod in t5stack.block:
        wrapped_mod = CheckpointWrapper(mod, use_checkpoint)
        block.append(wrapped_mod)
    block = nn.ModuleList(block)
    t5stack.block = block
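
# Illustrative sketch (not part of the original file): gradient checkpointing in the
# wrapped encoder is toggled per block through FiDT5.set_checkpoint and only takes
# effect while the model is in training mode (see CheckpointWrapper.forward above).
def _example_enable_encoder_checkpointing(model):
    model.set_checkpoint(True)  # flips use_checkpoint on every wrapped encoder block
    model.train()               # checkpointing is skipped in eval mode
    return model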

def cross_attention_forward(
    self,
    input,
    mask=None,
    kv=None,
    position_bias=None,
    past_key_value_state=None,
    head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
):
    """
    This only works for computing cross-attention over the input.
    """
    assert kv is not None
    assert head_mask is None
    assert position_bias is not None or self.has_relative_attention_bias

    bsz, qlen, dim = input.size()
    n_heads, d_heads = self.n_heads, self.d_kv
    klen = kv.size(1)

    q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
    if past_key_value_state is None:
        k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
        v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
    else:
        k, v = past_key_value_state

    scores = torch.einsum("bnqd,bnkd->bnqk", q, k)

    if mask is not None:
        scores += mask

    if position_bias is None:
        position_bias = self.compute_bias(qlen, klen)
    scores += position_bias

    if self.score_storage is None:
        self.score_storage = scores

    attn = F.softmax(scores.float(), dim=-1).type_as(scores)
    attn = F.dropout(attn, p=self.dropout, training=self.training)

    output = torch.matmul(attn, v)
    output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim)
    output = self.o(output)

    if use_cache:
        output = (output,) + ((k, v),)
    else:
        output = (output,) + (None,)

    if output_attentions:
        output = output + (attn,)

    if self.has_relative_attention_bias:
        output = output + (position_bias,)

    return output

class RetrieverConfig(transformers.BertConfig):

    def __init__(self,
                 indexing_dimension=768,
                 apply_question_mask=False,
                 apply_passage_mask=False,
                 extract_cls=False,
                 passage_maxlength=200,
                 question_maxlength=40,
                 projection=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.indexing_dimension = indexing_dimension
        self.apply_question_mask = apply_question_mask
        self.apply_passage_mask = apply_passage_mask
        self.extract_cls = extract_cls
        self.passage_maxlength = passage_maxlength
        self.question_maxlength = question_maxlength
        self.projection = projection

class Retriever(transformers.PreTrainedModel):

    config_class = RetrieverConfig
    base_model_prefix = "retriever"

    def __init__(self, config, initialize_wBERT=False):
        super().__init__(config)
        assert config.projection or config.indexing_dimension == 768, \
            'If no projection then indexing dimension must be equal to 768'
        self.config = config
        if initialize_wBERT:
            self.model = transformers.BertModel.from_pretrained('bert-base-uncased')
        else:
            self.model = transformers.BertModel(config)
        if self.config.projection:
            self.proj = nn.Linear(
                self.model.config.hidden_size,
                self.config.indexing_dimension
            )
            self.norm = nn.LayerNorm(self.config.indexing_dimension)
        self.loss_fct = torch.nn.KLDivLoss()

    def forward(self,
                question_ids,
                question_mask,
                passage_ids,
                passage_mask,
                gold_score=None):
        question_output = self.embed_text(
            text_ids=question_ids,
            text_mask=question_mask,
            apply_mask=self.config.apply_question_mask,
            extract_cls=self.config.extract_cls,
        )
        bsz, n_passages, plen = passage_ids.size()
        passage_ids = passage_ids.view(bsz * n_passages, plen)
        passage_mask = passage_mask.view(bsz * n_passages, plen)
        passage_output = self.embed_text(
            text_ids=passage_ids,
            text_mask=passage_mask,
            apply_mask=self.config.apply_passage_mask,
            extract_cls=self.config.extract_cls,
        )

        score = torch.einsum(
            'bd,bid->bi',
            question_output,
            passage_output.view(bsz, n_passages, -1)
        )
        score = score / np.sqrt(question_output.size(-1))
        if gold_score is not None:
            loss = self.kldivloss(score, gold_score)
        else:
            loss = None

        return question_output, passage_output, score, loss

    def embed_text(self, text_ids, text_mask, apply_mask=False, extract_cls=False):
        text_output = self.model(
            input_ids=text_ids,
            attention_mask=text_mask if apply_mask else None
        )
        if type(text_output) is not tuple:
            text_output = text_output.to_tuple()
        text_output = text_output[0]
        if self.config.projection:
            text_output = self.proj(text_output)
            text_output = self.norm(text_output)

        if extract_cls:
            text_output = text_output[:, 0]
        else:
            if apply_mask:
                text_output = text_output.masked_fill(~text_mask[:, :, None], 0.)
                text_output = torch.sum(text_output, dim=1) / torch.sum(text_mask, dim=1)[:, None]
            else:
                text_output = torch.mean(text_output, dim=1)
        return text_output

    def kldivloss(self, score, gold_score):
        gold_score = torch.softmax(gold_score, dim=-1)
        score = torch.nn.functional.log_softmax(score, dim=-1)
        return self.loss_fct(score, gold_score)
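
# Illustrative sketch (not part of the original file): a training-style call of the
# retriever with distillation targets. `question_ids`/`question_mask` are
# (batch_size, question_len); `passage_ids`/`passage_mask` are
# (batch_size, n_passages, passage_len); `gold_score` is (batch_size, n_passages),
# e.g. the reader scores returned by FiDT5.get_crossattention_scores.
def _example_retriever_step(retriever, question_ids, question_mask,
                            passage_ids, passage_mask, gold_score):
    _, _, score, loss = retriever(
        question_ids=question_ids,
        question_mask=question_mask,
        passage_ids=passage_ids,
        passage_mask=passage_mask,
        gold_score=gold_score,
    )
    # score: (batch_size, n_passages) scaled dot products between question and
    # passage embeddings; loss: KL divergence between softmax(gold_score) and
    # softmax(score), as computed by kldivloss above.
    return score, loss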