# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional

from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_utils import BatchEncoding
from .tokenization_xlm_roberta import XLMRobertaTokenizer


logger = logging.getLogger(__name__)

# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"

_all_bart_models = [
"facebook/bart-base",
"facebook/bart-large",
"facebook/bart-large-mnli",
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"yjernite/bart_eli5",
]


class BartTokenizer(RobertaTokenizer):
# merges and vocab same as Roberta
max_model_input_sizes = {m: 1024 for m in _all_bart_models}
pretrained_vocab_files_map = {
"vocab_file": {m: vocab_url for m in _all_bart_models},
"merges_file": {m: merges_url for m in _all_bart_models},
}


class BartTokenizerFast(RobertaTokenizerFast):
# merges and vocab same as Roberta
max_model_input_sizes = {m: 1024 for m in _all_bart_models}
pretrained_vocab_files_map = {
"vocab_file": {m: vocab_url for m in _all_bart_models},
"merges_file": {m: merges_url for m in _all_bart_models},
}
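

# Usage sketch for the BART tokenizers (illustrative only; it relies purely on behaviour
# inherited from RobertaTokenizer, since BART reuses RoBERTa's vocab and merges):
#
#     >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
#     >>> ids = tokenizer.encode("Hello world")        # [<s> id, ...BPE token ids..., </s> id]
#     >>> tokenizer.decode(ids, skip_special_tokens=True)
#     'Hello world'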

_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"


class MBartTokenizer(XLMRobertaTokenizer):
"""
This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
Other tokenizer methods like encode do not work properly.
The tokenization method is <tokens> <eos> <language code>. There is no BOS token.
Examples::
>>> from transformers import MBartTokenizer
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> batch: dict = tokenizer.prepare_translation_batch(
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
... )
"""
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
lang_code_to_id = { # NOTE(SS): resize embeddings will break this
"ar_AR": 250001,
"cs_CZ": 250002,
"de_DE": 250003,
"en_XX": 250004,
"es_XX": 250005,
"et_EE": 250006,
"fi_FI": 250007,
"fr_XX": 250008,
"gu_IN": 250009,
"hi_IN": 250010,
"it_IT": 250011,
"ja_XX": 250012,
"kk_KZ": 250013,
"ko_KR": 250014,
"lt_LT": 250015,
"lv_LV": 250016,
"my_MM": 250017,
"ne_NP": 250018,
"nl_XX": 250019,
"ro_RO": 250020,
"ru_RU": 250021,
"si_LK": 250022,
"tr_TR": 250023,
"vi_VN": 250024,
"zh_CN": 250025,
}
id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
cur_lang_code = lang_code_to_id["en_XX"]

    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
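
    # Illustrative consequence of the update above (assuming the fairseq id lookup inherited
    # from XLMRobertaTokenizer, where fairseq_tokens_to_ids is checked before the sentencepiece
    # vocab): the language codes now convert like ordinary tokens, e.g.
    #
    #     >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
    #     >>> tokenizer.convert_tokens_to_ids("ro_RO")
    #     250020
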
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
special_tokens = [self.eos_token_id, self.cur_lang_code]
if token_ids_1 is None:
return token_ids_0 + special_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + special_tokens

    def set_lang(self, lang: str) -> None:
        """Set the current language code, which determines the language code id appended by
        ``build_inputs_with_special_tokens``."""
        self.cur_lang_code = self.lang_code_to_id[lang]
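
    # Sketch of the format produced above (token ids 9 and 10 are placeholders; 2 is the
    # ``</s>`` id in the inherited fairseq mapping, 250004/250020 come from lang_code_to_id):
    #
    #     >>> tokenizer.build_inputs_with_special_tokens([9, 10])    # cur_lang_code defaults to en_XX
    #     [9, 10, 2, 250004]
    #     >>> tokenizer.set_lang("ro_RO")
    #     >>> tokenizer.build_inputs_with_special_tokens([9, 10])
    #     [9, 10, 2, 250020]
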
def prepare_translation_batch(
self,
src_texts: List[str],
src_lang: str = "en_XX",
tgt_texts: Optional[List[str]] = None,
tgt_lang: str = "ro_RO",
max_length: Optional[int] = None,
pad_to_max_length: bool = True,
return_tensors: str = "pt",
) -> BatchEncoding:
"""
Arguments:
src_texts: list of src language texts
src_lang: default en_XX (english)
tgt_texts: list of tgt language texts
tgt_lang: default ro_RO (romanian)
max_length: (None) defer to config (1024 for mbart-large-en-ro)
pad_to_max_length: (bool)
Returns:
dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
"""
if max_length is None:
max_length = self.max_len
self.cur_lang_code = self.lang_code_to_id[src_lang]
model_inputs: BatchEncoding = self.batch_encode_plus(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
truncation=True,
)
if tgt_texts is None:
return model_inputs
self.cur_lang_code = self.lang_code_to_id[tgt_lang]
decoder_inputs: BatchEncoding = self.batch_encode_plus(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
truncation=True,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
self.cur_lang_code = self.lang_code_to_id[src_lang]
return model_inputs
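

# Illustrative end-to-end sketch (assumes torch is available for return_tensors="pt"; the
# exact set of returned keys beyond those documented depends on the base tokenizer defaults):
#
#     >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
#     >>> batch = tokenizer.prepare_translation_batch(
#     ...     ["UN Chief Says There Is No Military Solution in Syria"],
#     ...     src_lang="en_XX",
#     ...     tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
#     ...     tgt_lang="ro_RO",
#     ... )
#     >>> batch["input_ids"].shape[0] == batch["decoder_input_ids"].shape[0]    # one row per example
#     True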