# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import encoders


class RobertaHubInterface(nn.Module):
    """A simple PyTorch Hub interface to RoBERTa.

    Usage: https://github.com/pytorch/fairseq/tree/main/examples/roberta
    """

    def __init__(self, cfg, task, model):
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.model = model

        self.bpe = encoders.build_bpe(cfg.bpe)

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        return self._float_tensor.device

    def encode(
        self, sentence: str, *addl_sentences, no_separator=False
    ) -> torch.LongTensor:
        """
        BPE-encode a sentence (or multiple sentences).

        Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
        Every sentence ends with an end-of-sentence (`</s>`) and we use an
        extra end-of-sentence (`</s>`) as a separator.

        Example (single sentence): `<s> a b c </s>`
        Example (sentence pair): `<s> d e f </s> </s> 1 2 3 </s>`

        The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
        requires leading spaces. For example::

            >>> roberta.encode('Hello world').tolist()
            [0, 31414, 232, 2]
            >>> roberta.encode(' world').tolist()
            [0, 232, 2]
            >>> roberta.encode('world').tolist()
            [0, 8331, 2]
        """
        bpe_sentence = "<s> " + self.bpe.encode(sentence) + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(
            bpe_sentence, append_eos=False, add_if_not_exist=False
        )
        return tokens.long()

    def decode(self, tokens: torch.LongTensor):
        assert tokens.dim() == 1
        tokens = tokens.numpy()
        if tokens[0] == self.task.source_dictionary.bos():
            tokens = tokens[1:]  # remove <s>
        eos_mask = tokens == self.task.source_dictionary.eos()
        # split where two end-of-sentence symbols appear back-to-back (sentence separator)
        doc_mask = eos_mask[1:] & eos_mask[:-1]
        sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
        sentences = [
            self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
        ]
        if len(sentences) == 1:
            return sentences[0]
        return sentences

    def extract_features(
        self, tokens: torch.LongTensor, return_all_hiddens: bool = False
    ) -> torch.Tensor:
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        if tokens.size(-1) > self.model.max_positions():
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    tokens.size(-1), self.model.max_positions()
                )
            )
        features, extra = self.model(
            tokens.to(device=self.device),
            features_only=True,
            return_all_hiddens=return_all_hiddens,
        )
        if return_all_hiddens:
            # convert from T x B x C -> B x T x C
            inner_states = extra["inner_states"]
            return [inner_state.transpose(0, 1) for inner_state in inner_states]
        else:
            return features  # just the last layer's features

    def register_classification_head(
        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
    ):
        self.model.register_classification_head(
            name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
        )

    def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
        features = self.extract_features(tokens.to(device=self.device))
        logits = self.model.classification_heads[head](features)
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)

    def extract_features_aligned_to_words(
        self, sentence: str, return_all_hiddens: bool = False
    ) -> torch.Tensor:
        """Extract RoBERTa features, aligned to spaCy's word-level tokenizer."""
        from fairseq.models.roberta import alignment_utils
        from spacy.tokens import Doc

        nlp = alignment_utils.spacy_nlp()
        tokenizer = alignment_utils.spacy_tokenizer()

        # tokenize both with GPT-2 BPE and spaCy
        bpe_toks = self.encode(sentence)
        spacy_toks = tokenizer(sentence)
        spacy_toks_ws = [t.text_with_ws for t in tokenizer(sentence)]
        alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws)

        # extract features and align them
        features = self.extract_features(
            bpe_toks, return_all_hiddens=return_all_hiddens
        )
        features = features.squeeze(0)
        aligned_feats = alignment_utils.align_features_to_words(
            self, features, alignment
        )

        # wrap in spaCy Doc
        doc = Doc(
            nlp.vocab,
            words=["<s>"] + [x.text for x in spacy_toks] + ["</s>"],
            spaces=[True]
            + [x.endswith(" ") for x in spacy_toks_ws[:-1]]
            + [True, False],
        )
        assert len(doc) == aligned_feats.size(0)
        doc.user_token_hooks["vector"] = lambda token: aligned_feats[token.i]
        return doc

    def fill_mask(self, masked_input: str, topk: int = 5):
        masked_token = "<mask>"
        assert (
            masked_token in masked_input and masked_input.count(masked_token) == 1
        ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(
            masked_token
        )

        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (
            (" {0} ".format(masked_token))
            .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
            .strip()
        )
        tokens = self.task.source_dictionary.encode_line(
            "<s> " + text_spans_bpe + " </s>",
            append_eos=False,
            add_if_not_exist=False,
        )

        masked_index = (tokens == self.task.mask_idx).nonzero(as_tuple=False)
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        with utils.model_eval(self.model):
            features, extra = self.model(
                tokens.long().to(device=self.device),
                features_only=False,
                return_all_hiddens=False,
            )
        logits = features[0, masked_index, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
        topk_predicted_token_bpe = self.task.source_dictionary.string(index)

        topk_filled_outputs = []
        for index, predicted_token_bpe in enumerate(
            topk_predicted_token_bpe.split(" ")
        ):
            predicted_token = self.bpe.decode(predicted_token_bpe)
            # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
            if predicted_token_bpe.startswith("\u2581"):
                predicted_token = " " + predicted_token
            if " {0}".format(masked_token) in masked_input:
                topk_filled_outputs.append(
                    (
                        masked_input.replace(
                            " {0}".format(masked_token), predicted_token
                        ),
                        values[index].item(),
                        predicted_token,
                    )
                )
            else:
                topk_filled_outputs.append(
                    (
                        masked_input.replace(masked_token, predicted_token),
                        values[index].item(),
                        predicted_token,
                    )
                )
        return topk_filled_outputs

    def disambiguate_pronoun(self, sentence: str) -> bool:
        """
        Usage::

            >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
            True

            >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
            'The trophy'
        """
        assert hasattr(
            self.task, "disambiguate_pronoun"
        ), "roberta.disambiguate_pronoun() requires a model trained with the WSC task."
        with utils.model_eval(self.model):
            return self.task.disambiguate_pronoun(
                self.model, sentence, use_cuda=self.device.type == "cuda"
            )
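

if __name__ == "__main__":
    # Minimal usage sketch for this interface, not part of the library API.
    # It assumes network access and that the pretrained 'roberta.base' weights
    # can be fetched through torch.hub, as documented in examples/roberta; any
    # RobertaHubInterface returned by from_pretrained() would work the same way.
    roberta = torch.hub.load("pytorch/fairseq", "roberta.base")
    roberta.eval()  # disable dropout for deterministic outputs

    # Round-trip a sentence through the GPT-2 BPE encoder/decoder.
    tokens = roberta.encode("Hello world!")
    assert roberta.decode(tokens) == "Hello world!"

    # Extract last-layer features: shape is (batch, seq_len, hidden_dim).
    features = roberta.extract_features(tokens)
    print(features.shape)

    # Fill in a masked token; returns (filled sentence, probability, token) tuples.
    for filled, prob, token in roberta.fill_mask(
        "The capital of France is <mask>.", topk=3
    ):
        print(filled, prob, token)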