# Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import types import torch import transformers import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss import numpy as np class FiDT5(transformers.T5ForConditionalGeneration): def __init__(self, config): super().__init__(config) self.wrap_encoder() def forward_(self, **kwargs): if 'input_ids' in kwargs: kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1) if 'attention_mask' in kwargs: kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1) return super(FiDT5, self).forward( **kwargs ) # We need to resize as B x (N * L) instead of (B * N) x L here # because the T5 forward method uses the input tensors to infer # dimensions used in the decoder. # EncoderWrapper resizes the inputs as (B * N) x L. def forward(self, input_ids=None, attention_mask=None, **kwargs): if input_ids != None: # inputs might have already be resized in the generate method if input_ids.dim() == 3: self.encoder.n_passages = input_ids.size(1) input_ids = input_ids.view(input_ids.size(0), -1) if attention_mask != None: attention_mask = attention_mask.view(attention_mask.size(0), -1) return super().forward( input_ids=input_ids, attention_mask=attention_mask, **kwargs ) # We need to resize the inputs here, as the generate method expect 2D tensors def generate(self, input_ids, attention_mask, max_length): self.encoder.n_passages = input_ids.size(1) return super().generate( input_ids=input_ids.view(input_ids.size(0), -1), attention_mask=attention_mask.view(attention_mask.size(0), -1), max_length=max_length ) def wrap_encoder(self, use_checkpoint=False): """ Wrap T5 encoder to obtain a Fusion-in-Decoder model. """ self.encoder = EncoderWrapper(self.encoder, use_checkpoint=use_checkpoint) def unwrap_encoder(self): """ Unwrap Fusion-in-Decoder encoder, useful to load T5 weights. """ self.encoder = self.encoder.encoder block = [] for mod in self.encoder.block: block.append(mod.module) block = nn.ModuleList(block) self.encoder.block = block def load_t5(self, state_dict): self.unwrap_encoder() self.load_state_dict(state_dict) self.wrap_encoder() def set_checkpoint(self, use_checkpoint): """ Enable or disable checkpointing in the encoder. See https://pytorch.org/docs/stable/checkpoint.html """ for mod in self.encoder.encoder.block: mod.use_checkpoint = use_checkpoint def reset_score_storage(self): """ Reset score storage, only used when cross-attention scores are saved to train a retriever. """ for mod in self.decoder.block: mod.layer[1].EncDecAttention.score_storage = None def get_crossattention_scores(self, context_mask): """ Cross-attention scores are aggregated to obtain a single scalar per passage. This scalar can be seen as a similarity score between the question and the input passage. It is obtained by averaging the cross-attention scores obtained on the first decoded token over heads, layers, and tokens of the input passage. More details in Distilling Knowledge from Reader to Retriever: https://arxiv.org/abs/2012.04584. """ scores = [] n_passages = context_mask.size(1) for mod in self.decoder.block: scores.append(mod.layer[1].EncDecAttention.score_storage) scores = torch.cat(scores, dim=2) bsz, n_heads, n_layers, _ = scores.size() # batch_size, n_head, n_layers, n_passages, text_maxlength scores = scores.view(bsz, n_heads, n_layers, n_passages, -1) scores = scores.masked_fill(~context_mask[:, None, None], 0.) scores = scores.sum(dim=[1, 2, 4]) ntokens = context_mask.sum(dim=[2]) * n_layers * n_heads scores = scores/ntokens return scores def overwrite_forward_crossattention(self): """ Replace cross-attention forward function, only used to save cross-attention scores. """ for mod in self.decoder.block: attn = mod.layer[1].EncDecAttention attn.forward = types.MethodType(cross_attention_forward, attn) class EncoderWrapper(torch.nn.Module): """ Encoder Wrapper for T5 Wrapper to obtain a Fusion-in-Decoder model. """ def __init__(self, encoder, use_checkpoint=False): super().__init__() self.encoder = encoder apply_checkpoint_wrapper(self.encoder, use_checkpoint) def forward(self, input_ids=None, attention_mask=None, **kwargs,): # total_length = n_passages * passage_length bsz, total_length = input_ids.shape passage_length = total_length // self.n_passages input_ids = input_ids.view(bsz*self.n_passages, passage_length) attention_mask = attention_mask.view(bsz*self.n_passages, passage_length) outputs = self.encoder(input_ids, attention_mask, **kwargs) outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:] return outputs class CheckpointWrapper(torch.nn.Module): """ Wrapper replacing None outputs by empty tensors, which allows the use of checkpointing. """ def __init__(self, module, use_checkpoint=False): super().__init__() self.module = module self.use_checkpoint = use_checkpoint def forward(self, hidden_states, attention_mask, position_bias, **kwargs): if self.use_checkpoint and self.training: kwargs = {k: v for k, v in kwargs.items() if v is not None} def custom_forward(*inputs): output = self.module(*inputs, **kwargs) empty = torch.tensor( [], dtype=torch.float, device=output[0].device, requires_grad=True) output = tuple(x if x is not None else empty for x in output) return output output = torch.utils.checkpoint.checkpoint( custom_forward, hidden_states, attention_mask, position_bias ) output = tuple(x if x.size() != 0 else None for x in output) else: output = self.module(hidden_states, attention_mask, position_bias, **kwargs) return output def apply_checkpoint_wrapper(t5stack, use_checkpoint): """ Wrap each block of the encoder to enable checkpointing. """ block = [] for mod in t5stack.block: wrapped_mod = CheckpointWrapper(mod, use_checkpoint) block.append(wrapped_mod) block = nn.ModuleList(block) t5stack.block = block def cross_attention_forward( self, input, mask=None, kv=None, position_bias=None, past_key_value_state=None, head_mask=None, query_length=None, use_cache=False, output_attentions=False, ): """ This only works for computing cross attention over the input """ assert(kv != None) assert(head_mask == None) assert(position_bias != None or self.has_relative_attention_bias) bsz, qlen, dim = input.size() n_heads, d_heads = self.n_heads, self.d_kv klen = kv.size(1) q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2) if past_key_value_state == None: k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2) v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2) else: k, v = past_key_value_state scores = torch.einsum("bnqd,bnkd->bnqk", q, k) if mask is not None: scores += mask if position_bias is None: position_bias = self.compute_bias(qlen, klen) scores += position_bias if self.score_storage is None: self.score_storage = scores attn = F.softmax(scores.float(), dim=-1).type_as(scores) attn = F.dropout(attn, p=self.dropout, training=self.training) output = torch.matmul(attn, v) output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim) output = self.o(output) if use_cache: output = (output,) + ((k, v),) else: output = (output,) + (None,) if output_attentions: output = output + (attn,) if self.has_relative_attention_bias: output = output + (position_bias,) return output class RetrieverConfig(transformers.BertConfig): def __init__(self, indexing_dimension=768, apply_question_mask=False, apply_passage_mask=False, extract_cls=False, passage_maxlength=200, question_maxlength=40, projection=True, **kwargs): super().__init__(**kwargs) self.indexing_dimension = indexing_dimension self.apply_question_mask = apply_question_mask self.apply_passage_mask = apply_passage_mask self.extract_cls=extract_cls self.passage_maxlength = passage_maxlength self.question_maxlength = question_maxlength self.projection = projection class Retriever(transformers.PreTrainedModel): config_class = RetrieverConfig base_model_prefix = "retriever" def __init__(self, config, initialize_wBERT=False): super().__init__(config) assert config.projection or config.indexing_dimension == 768, \ 'If no projection then indexing dimension must be equal to 768' self.config = config if initialize_wBERT: self.model = transformers.BertModel.from_pretrained('bert-base-uncased') else: self.model = transformers.BertModel(config) if self.config.projection: self.proj = nn.Linear( self.model.config.hidden_size, self.config.indexing_dimension ) self.norm = nn.LayerNorm(self.config.indexing_dimension) self.loss_fct = torch.nn.KLDivLoss() def forward(self, question_ids, question_mask, passage_ids, passage_mask, gold_score=None): question_output = self.embed_text( text_ids=question_ids, text_mask=question_mask, apply_mask=self.config.apply_question_mask, extract_cls=self.config.extract_cls, ) bsz, n_passages, plen = passage_ids.size() passage_ids = passage_ids.view(bsz * n_passages, plen) passage_mask = passage_mask.view(bsz * n_passages, plen) passage_output = self.embed_text( text_ids=passage_ids, text_mask=passage_mask, apply_mask=self.config.apply_passage_mask, extract_cls=self.config.extract_cls, ) score = torch.einsum( 'bd,bid->bi', question_output, passage_output.view(bsz, n_passages, -1) ) score = score / np.sqrt(question_output.size(-1)) if gold_score is not None: loss = self.kldivloss(score, gold_score) else: loss = None return question_output, passage_output, score, loss def embed_text(self, text_ids, text_mask, apply_mask=False, extract_cls=False): text_output = self.model( input_ids=text_ids, attention_mask=text_mask if apply_mask else None ) if type(text_output) is not tuple: text_output.to_tuple() text_output = text_output[0] if self.config.projection: text_output = self.proj(text_output) text_output = self.norm(text_output) if extract_cls: text_output = text_output[:, 0] else: if apply_mask: text_output = text_output.masked_fill(~text_mask[:, :, None], 0.) text_output = torch.sum(text_output, dim=1) / torch.sum(text_mask, dim=1)[:, None] else: text_output = torch.mean(text_output, dim=1) return text_output def kldivloss(self, score, gold_score): gold_score = torch.softmax(gold_score, dim=-1) score = torch.nn.functional.log_softmax(score, dim=-1) return self.loss_fct(score, gold_score)