# Source code for transformers.tokenization_bart

# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_utils_base import BatchEncoding
from .utils import logging

# Module-level logger, named after this module per the library convention.
logger = logging.get_logger(__name__)

# vocab and merges same as roberta
# BART reuses RoBERTa's byte-level BPE files, so the pretrained file maps
# below point at the roberta-large vocab/merges artifacts.
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
# Checkpoint identifiers that use this tokenizer (deliberately non-exhaustive).
_all_bart_models = [
"yjernite/bart_eli5",
# This is not exhaustive: see https://huggingface.co/models?filter=bart
]

class BartTokenizer(RobertaTokenizer):
    r"""
    Construct a BART tokenizer.

    :class:`~transformers.BartTokenizer` is identical to
    :class:`~transformers.RobertaTokenizer`: BART checkpoints reuse RoBERTa's
    byte-level BPE vocabulary and merges files, so only the pretrained-file
    maps and model max lengths differ.
    """

    # merges and vocab same as Roberta; every listed BART checkpoint accepts
    # up to 1024 input positions.
    max_model_input_sizes = {m: 1024 for m in _all_bart_models}
    pretrained_vocab_files_map = {
        "vocab_file": {m: vocab_url for m in _all_bart_models},
        "merges_file": {m: merges_url for m in _all_bart_models},
    }

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        truncation=True,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        Prepare a batch that can be passed directly to an instance of
        :class:`~transformers.BartModel`.

        Args:
            src_texts (:obj:`List[str]`):
                List of documents to summarize or source language texts.
            tgt_texts (:obj:`List[str]`, `optional`):
                List of summaries or target language texts.
            max_length (:obj:`int`, `optional`):
                Controls the maximum length for encoder inputs. If left unset
                or set to :obj:`None`, this falls back to
                ``self.model_max_length``.
            max_target_length (:obj:`int`, `optional`):
                Controls the maximum length of decoder inputs (target language
                texts or summaries). If left unset or set to :obj:`None`, the
                ``max_length`` value is used.
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of lists of python
                integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'longest_first'`: Truncate to
                  ``max_length`` (or the model maximum), token by token,
                  removing a token from the longest sequence of a pair.
                * :obj:`'only_first'`: Truncate only the first sequence of a
                  pair to ``max_length`` (or the model maximum).
                * :obj:`'only_second'`: Truncate only the second sequence of a
                  pair to ``max_length`` (or the model maximum).
                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation.
            **kwargs:
                Additional keyword arguments passed along to
                :obj:`self.__call__` (e.g. ``padding``).

        Returns:
            :class:`~transformers.BatchEncoding`: A
            :class:`~transformers.BatchEncoding` with the following fields:

            - **input_ids** -- List of token ids to be fed to the encoder.
            - **attention_mask** -- List of indices specifying which tokens
              should be attended to by the model.
            - **labels** -- List of token ids for ``tgt_texts``; only present
              when ``tgt_texts`` is passed.
        """
        # src_lang / tgt_lang only make sense for multilingual tokenizers
        # (e.g. mBART); drop them so they are not forwarded to __call__.
        kwargs.pop("src_lang", None)
        kwargs.pop("tgt_lang", None)
        if max_length is None:
            max_length = self.model_max_length
        model_inputs: BatchEncoding = self(
            src_texts,
            return_tensors=return_tensors,
            max_length=max_length,
            truncation=truncation,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts: targets default to the same max length as sources.
        if max_target_length is None:
            max_target_length = max_length
        labels = self(
            tgt_texts,
            return_tensors=return_tensors,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )["input_ids"]
        model_inputs["labels"] = labels
        return model_inputs

class BartTokenizerFast(RobertaTokenizerFast):
# merges and vocab same as Roberta
max_model_input_sizes = {m: 1024 for m in _all_bart_models}
pretrained_vocab_files_map = {
"vocab_file": {m: vocab_url for m in _all_bart_models},
"merges_file": {m: merges_url for m in _all_bart_models},
}

def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
return_tensors: str = "None",
truncation=True,
**kwargs,
) -> BatchEncoding:
r"""

Prepare a batch that can be passed directly to an instance of :class:~transformers.BartModel.

Args:
src_texts: (:obj:List[str]):
List of documents to summarize or source language texts.
tgt_texts: (:obj:List[str], optional):
List of summaries or target language texts.
max_length (:obj:int, optional):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:None, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:int, optional):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:None, this will use the max_length value.
padding (:obj:bool, :obj:str or :class:~transformers.tokenization_utils_base.PaddingStrategy, optional, defaults to :obj:False):
Activates and controls padding. Accepts the following values:

* :obj:True or :obj:'longest': Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:'max_length': Pad to a maximum length specified with the argument :obj:max_length or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:False or :obj:'do_not_pad' (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:str or :class:~transformers.tokenization_utils_base.TensorType, optional, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:

* :obj:'tf': Return TensorFlow :obj:tf.constant objects.
* :obj:'pt': Return PyTorch :obj:torch.Tensor objects.
* :obj:'np': Return Numpy :obj:np.ndarray objects.
truncation (:obj:bool, :obj:str or :class:~transformers.tokenization_utils_base.TruncationStrategy, optional, defaults to :obj:True):
Activates and controls truncation. Accepts the following values:

* :obj:True or :obj:'longest_first': Truncate to a maximum length specified with the argument
:obj:max_length or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:'only_first': Truncate to a maximum length specified with the argument :obj:max_length or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:'only_second': Truncate to a maximum length specified with the argument :obj:max_length or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:False or :obj:'do_not_truncate' (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:self.__call__.

Returns:
:class:~transformers.BatchEncoding: A :class:~transformers.BatchEncoding with the following fields:

- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.

The full set of keys [input_ids, attention_mask, decoder_input_ids,  decoder_attention_mask],
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if max_length is None:
max_length = self.model_max_length
model_inputs: BatchEncoding = self(
src_texts,
return_tensors=return_tensors,
max_length=max_length,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
labels = self(
tgt_texts,