| """Tokenization classes for FLMR.""" |
|
|
|
|
from typing import List, Optional, Union

from transformers.utils import TensorType, logging
from transformers.models.bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
|
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "LinWeizheDragon/PreFLMR_ViT-L": (
            "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
        ),
        "LinWeizheDragon/FLMR": (
            "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
        ),
    },
    "tokenizer_file": {
        "LinWeizheDragon/PreFLMR_ViT-L": (
            "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
        ),
        "LinWeizheDragon/FLMR": (
            "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
        ),
    },
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "LinWeizheDragon/PreFLMR_ViT-L": (
            "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
        ),
        "LinWeizheDragon/FLMR": (
            "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"
        ),
    },
    "tokenizer_file": {
        "LinWeizheDragon/PreFLMR_ViT-L": (
            "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
        ),
        "LinWeizheDragon/FLMR": (
            "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
        ),
    },
}
|
|
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "LinWeizheDragon/PreFLMR_ViT-L": 512,
    "LinWeizheDragon/FLMR": 512,
}
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "LinWeizheDragon/PreFLMR_ViT-L": 512,
    "LinWeizheDragon/FLMR": 512,
}


CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
    "LinWeizheDragon/FLMR": {"do_lower_case": True},
}
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
    "LinWeizheDragon/FLMR": {"do_lower_case": True},
}
|
|
class FLMRContextEncoderTokenizer(BertTokenizer):
    r"""
    Construct a FLMRContextEncoder tokenizer.

    [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
    splitting and wordpiece.

    Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
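
    Example (a usage sketch; the ``flmr`` import path and the ``context_tokenizer`` subfolder are
    assumptions inferred from the checkpoint URLs above):

    ```python
    >>> from flmr import FLMRContextEncoderTokenizer  # assumed import path; adjust to where this module lives

    >>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(
    ...     "LinWeizheDragon/PreFLMR_ViT-L", subfolder="context_tokenizer"
    ... )
    >>> encoding = context_tokenizer(["Paris is the capital of France."], max_length=32)
    >>> # position 1 of every sequence now holds the [D] marker token
    >>> encoding["input_ids"].shape
    torch.Size([1, 32])
    ```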
| """ |
|
|
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
|
    def __init__(
        self,
        doc_maxlen: Optional[int] = 512,
        **kwargs,
    ):
        super().__init__(
            doc_maxlen=doc_maxlen,
            **kwargs,
        )

        self.doc_maxlen = doc_maxlen
        # "[D]" marks a document (context) sequence; it reuses the spare BERT vocab entry "[unused1]".
        self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
|
|
    def __call__(
        self,
        text: List[str],
        padding: Optional[Union[str, bool]] = "max_length",
        truncation: Optional[Union[bool, str]] = "longest_first",
        max_length: Optional[int] = 512,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        **kwargs,
    ):
        # Prepend a placeholder ". " to every document; the token it produces is overwritten below by
        # the [D] marker token.
        text = [". " + x for x in text]

        # Never tokenize beyond the document length this tokenizer was configured with.
        if max_length > self.doc_maxlen:
            max_length = self.doc_maxlen

        encoding = super().__call__(
            text,
            padding=padding,
            truncation=truncation,
            return_tensors=return_tensors,
            max_length=max_length,
            **kwargs,
        )

        ids, mask = encoding["input_ids"], encoding["attention_mask"]

        # Overwrite the placeholder at position 1 (right after [CLS]) with the [D] marker.
        ids[:, 1] = self.D_marker_token_id

        encoding["input_ids"] = ids
        encoding["attention_mask"] = mask

        return encoding
|
|
class FLMRQueryEncoderTokenizer(BertTokenizer):
    r"""
    Constructs a FLMRQueryEncoder tokenizer.

    [`FLMRQueryEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
    splitting and wordpiece.

    Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
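
    Example (a usage sketch; the ``flmr`` import path and the ``query_tokenizer`` subfolder are
    assumptions inferred from the checkpoint URLs above):

    ```python
    >>> from flmr import FLMRQueryEncoderTokenizer  # assumed import path; adjust to where this module lives

    >>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(
    ...     "LinWeizheDragon/PreFLMR_ViT-L", subfolder="query_tokenizer"
    ... )
    >>> encoding = query_tokenizer("What is the capital of France?")
    >>> # queries are padded to query_maxlen and the padding is replaced with [MASK] tokens
    >>> encoding["input_ids"].shape
    torch.Size([1, 32])
    ```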
| """ |
|
|
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
|
    def __init__(
        self,
        *args,
        query_maxlen: Optional[int] = 32,
        attend_to_mask_tokens: Optional[bool] = False,
        **kwargs,
    ):
        super().__init__(
            *args,
            query_maxlen=query_maxlen,
            attend_to_mask_tokens=attend_to_mask_tokens,
            **kwargs,
        )

        self.query_maxlen = query_maxlen
        self.background_maxlen = 512 - self.query_maxlen + 1
        self.attend_to_mask_tokens = attend_to_mask_tokens

        # "[Q]" marks a query sequence; it reuses the spare BERT vocab entry "[unused0]".
        self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
|
|
    def __call__(
        self,
        text: Union[str, List[str]],
        padding: Optional[Union[str, bool]] = "max_length",
        truncation: Optional[Union[bool, str]] = True,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        **kwargs,
    ):
        if isinstance(text, str):
            text = [text]

        # Prepend a placeholder ". " to every query; the token it produces is overwritten below by
        # the [Q] marker token.
        text = [". " + x for x in text]

        # Fall back to the configured query length when the caller does not pass max_length.
        if max_length is None:
            max_length = self.query_maxlen

        encoding = super().__call__(
            text,
            padding=padding,
            truncation=truncation,
            return_tensors=return_tensors,
            max_length=max_length,
            **kwargs,
        )

        ids, mask = encoding["input_ids"], encoding["attention_mask"]

        # Overwrite the placeholder at position 1 (right after [CLS]) with the [Q] marker, then turn
        # padding into [MASK] tokens (ColBERT-style query augmentation).
        ids[:, 1] = self.Q_marker_token_id
        ids[ids == self.pad_token_id] = self.mask_token_id

        if self.attend_to_mask_tokens:
            # Also attend to the [MASK] augmentation tokens; with "max_length" padding this makes the
            # attention mask all ones.
            mask[ids == self.mask_token_id] = 1
            assert mask.sum().item() == mask.size(0) * mask.size(1), mask

        return {"input_ids": ids, "attention_mask": mask}
|
|