# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, Optional

from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import FairseqDecoder
from torch import Tensor

logger = logging.getLogger(__name__)


@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the incremental
    decoder interface allows :func:`forward` functions to take an extra keyword
    argument (*incremental_state*) that can be used to cache state across
    time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of beams.

    To learn more about how incremental decoding works, refer to `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
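
    A minimal usage sketch of the incremental mode. Here ``decoder``,
    ``encoder_out``, ``tokens``, and ``max_len`` are illustrative names for a
    concrete decoder instance, its encoder output, the tokens generated so
    far, and a step budget; they are not part of this API, and ``torch`` is
    assumed to be imported::

        incremental_state = {}
        for _ in range(max_len):
            # each call only consumes the latest token and reuses cached state
            logits, _ = decoder(
                tokens, encoder_out=encoder_out, incremental_state=incremental_state
            )
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)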
""" | |

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
""" | |
Args: | |
prev_output_tokens (LongTensor): shifted output tokens of shape | |
`(batch, tgt_len)`, for teacher forcing | |
encoder_out (dict, optional): output from the encoder, used for | |
encoder-side attention | |
incremental_state (dict, optional): dictionary used for storing | |
state during :ref:`Incremental decoding` | |
Returns: | |
tuple: | |
- the decoder's output of shape `(batch, tgt_len, vocab)` | |
- a dictionary with any model-specific outputs | |
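
        A sketch of the caching pattern a subclass's ``forward`` might use
        when *incremental_state* is provided. The ``"prev_tokens"`` cache key
        is an illustrative assumption, not a fairseq convention, and ``torch``
        is assumed to be imported::

            if incremental_state is not None:
                # only the newest token needs to be processed at this step
                prev_output_tokens = prev_output_tokens[:, -1:]
                cache = incremental_state.setdefault("prev_tokens", {})
                prev = cache.get("tokens")
                cache["tokens"] = (
                    prev_output_tokens
                    if prev is None
                    else torch.cat([prev, prev_output_tokens], dim=1)
                )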
""" | |
raise NotImplementedError | |

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
""" | |
Returns: | |
tuple: | |
- the decoder's features of shape `(batch, tgt_len, embed_dim)` | |
- a dictionary with any model-specific outputs | |
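
        A small sketch of how this typically composes with
        :func:`FairseqDecoder.output_layer` (``decoder`` and ``encoder_out``
        are illustrative names)::

            features, extra = decoder.extract_features(
                prev_output_tokens, encoder_out=encoder_out
            )
            # projecting features to the vocabulary is what forward() usually returns
            logits = decoder.output_layer(features)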
""" | |
raise NotImplementedError | |

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
"""Reorder incremental state. | |
This will be called when the order of the input has changed from the | |
previous time step. A typical use case is beam search, where the input | |
order changes between time steps based on the selection of beams. | |
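
        A minimal sketch of a typical override, assuming every cached tensor
        keeps the (batch * beam) dimension first (this is an illustrative
        assumption, not guaranteed for every module)::

            def reorder_incremental_state(self, incremental_state, new_order):
                for state in incremental_state.values():
                    for name, tensor in state.items():
                        if tensor is not None:
                            # keep only the entries belonging to surviving beams,
                            # in the order chosen by the search
                            state[name] = tensor.index_select(0, new_order)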
""" | |
pass | |

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
"""Main entry point for reordering the incremental state. | |
Due to limitations in TorchScript, we call this function in | |
:class:`fairseq.sequence_generator.SequenceGenerator` instead of | |
calling :func:`reorder_incremental_state` directly. | |
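
        A minimal caller-side sketch (``beam_idx`` is an illustrative name for
        the tensor of surviving beam indices chosen at this search step)::

            decoder.reorder_incremental_state_scripting(incremental_state, beam_idx)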
""" | |
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, "_beam_size", -1) != beam_size:
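            # track modules that have already been updated so shared submodules
            # are not visited twice by self.apply()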
            seen = set()

            def apply_set_beam_size(module):
                if (
                    module != self
                    and hasattr(module, "set_beam_size")
                    and module not in seen
                ):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            self.apply(apply_set_beam_size)
            self._beam_size = beam_size