# coding=utf-8
# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM-ProphetNet model."""
from .configuration_xlm_prophetnet import XLMProphetNetConfig
from .modeling_prophetnet import (
ProphetNetDecoder,
ProphetNetEncoder,
ProphetNetForCausalLM,
ProphetNetForConditionalGeneration,
ProphetNetModel,
)
from .utils import logging
# Tokenizer class name for this model family; presumably consumed by doc helpers elsewhere in the package.
_TOKENIZER_FOR_DOC = "XLMProphetNetTokenizer"

# Module-level logger obtained from the package's logging utilities.
logger = logging.get_logger(__name__)

# Known pretrained XLM-ProphetNet checkpoint identifiers.
XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/xprophetnet-large-wiki100-cased",
    # See all XLM-ProphetNet models at https://huggingface.co/models?filter=xprophetnet
]


class XLMProphetNetEncoder(ProphetNetEncoder):
    r"""
    This class overrides :class:`~transformers.ProphetNetEncoder`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Example::

        >>> from transformers import XLMProphetNetTokenizer, XLMProphetNetEncoder
        >>> import torch

        >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
        >>> model = XLMProphetNetEncoder.from_pretrained('patrickvonplaten/xprophetnet-large-uncased-standalone', return_dict=True)
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
    """

    # Only the configuration class differs from the parent; all modeling code is inherited unchanged.
    config_class = XLMProphetNetConfig


class XLMProphetNetDecoder(ProphetNetDecoder):
    r"""
    This class overrides :class:`~transformers.ProphetNetDecoder`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Example::

        >>> from transformers import XLMProphetNetTokenizer, XLMProphetNetDecoder
        >>> import torch

        >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
        >>> model = XLMProphetNetDecoder.from_pretrained('patrickvonplaten/xprophetnet-large-uncased-standalone', add_cross_attention=False, return_dict=True)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
    """

    # Only the configuration class differs from the parent; all modeling code is inherited unchanged.
    config_class = XLMProphetNetConfig


class XLMProphetNetModel(ProphetNetModel):
    r"""
    This class overrides :class:`~transformers.ProphetNetModel`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Example::

        >>> from transformers import XLMProphetNetTokenizer, XLMProphetNetModel

        >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
        >>> model = XLMProphetNetModel.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')

        >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)

        >>> last_hidden_states = outputs.last_hidden_state  # main stream hidden states
        >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram  # predict hidden states
    """

    # Only the configuration class differs from the parent; all modeling code is inherited unchanged.
    config_class = XLMProphetNetConfig


class XLMProphetNetForConditionalGeneration(ProphetNetForConditionalGeneration):
    r"""
    This class overrides :class:`~transformers.ProphetNetForConditionalGeneration`. Please check the superclass for the
    appropriate documentation alongside usage examples.

    Example::

        >>> from transformers import XLMProphetNetTokenizer, XLMProphetNetForConditionalGeneration

        >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
        >>> model = XLMProphetNetForConditionalGeneration.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')

        >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)

        >>> logits_next_token = outputs.logits  # logits to predict next token as usual
        >>> logits_ngram_next_tokens = outputs.logits_ngram  # logits to predict 2nd, 3rd, ... next tokens
    """

    # Only the configuration class differs from the parent; all modeling code is inherited unchanged.
    config_class = XLMProphetNetConfig


class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
    r"""
    This class overrides :class:`~transformers.ProphetNetForCausalLM`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Example::

        >>> from transformers import XLMProphetNetTokenizer, XLMProphetNetForCausalLM
        >>> import torch

        >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
        >>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased', return_dict=True)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits

        >>> # Model can also be used with EncoderDecoder framework
        >>> from transformers import BertTokenizer, EncoderDecoderModel, XLMProphetNetTokenizer
        >>> import torch

        >>> # The BERT encoder and the XLM-ProphetNet decoder use different vocabularies,
        >>> # so encoder inputs and decoder labels need their own tokenizers.
        >>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
        >>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")

        >>> input_ids = tokenizer_enc("Hello, my dog is cute", return_tensors="pt").input_ids
        >>> labels = tokenizer_dec("Hello, my dog is cute", return_tensors="pt").input_ids
        >>> # Feed the decoder the labels shifted by one position relative to the targets.
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:], return_dict=True)

        >>> loss = outputs.loss
    """

    # Only the configuration class differs from the parent; all modeling code is inherited unchanged.
    config_class = XLMProphetNetConfig