|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import torch.nn as nn |
|
from fairseq import utils |
|
from torch import Tensor |
|
|
|
|
|
class FairseqDecoder(nn.Module):
    """Base class for decoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary
        # Flipped on by prepare_for_onnx_export_(); selects ONNX-safe softmax paths.
        self.onnx_trace = False
        # Subclasses that use an adaptive softmax assign it here; None otherwise.
        self.adaptive_softmax = None

    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Run the decoder: extract features, then project them to the vocabulary.

        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return self.output_layer(features), extra

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Compute decoder features for *prev_output_tokens* (no output projection).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def output_layer(self, features, **kwargs):
        """
        Project features to the default output size, e.g., vocabulary size.

        Args:
            features (Tensor): features returned by *extract_features*.
        """
        raise NotImplementedError

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """TorchScript-friendly body behind :func:`get_normalized_probs`."""
        # If an adaptive softmax is configured, normalization is delegated to it.
        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            target = None
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            log_prob = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            if log_probs:
                return log_prob
            # exp_ is in-place: the adaptive softmax output is ours to reuse.
            return log_prob.exp_()

        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        """Maximum input length supported by the decoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code."""
        return state_dict

    def prepare_for_onnx_export_(self):
        # Switch softmax helpers into their ONNX-traceable variants.
        self.onnx_trace = True
|