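# NOTE: the next lines are web-page metadata captured along with this file;
# they are commented out so the module parses as Python.
# JustinLin610
# update
# 8437114
# raw history blame
# No virus
# 2.93 kB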
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, NamedTuple, Optional
import torch
import torch.nn as nn
from torch import Tensor
class EncoderOut(NamedTuple):
    """Bundle of tensors produced by an encoder's forward pass.

    Field order is significant: positional construction and tuple
    unpacking follow the order declared below.
    """

    encoder_out: Tensor  # T x B x C
    encoder_padding_mask: Optional[Tensor]  # B x T
    encoder_embedding: Optional[Tensor]  # B x T x C
    encoder_states: Optional[List[Tensor]]  # List[T x B x C]
    src_tokens: Optional[Tensor]  # B x T
    src_lengths: Optional[Tensor]  # B x 1
class FairseqEncoder(nn.Module):
    """Base class for encoders.

    Subclasses must implement :meth:`forward` and :meth:`reorder_encoder_out`
    (both raise ``NotImplementedError`` here). The remaining methods provide
    defaults that subclasses may override.
    """

    def __init__(self, dictionary):
        super().__init__()
        # Stored for use by subclasses; this base class never reads it.
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`
        """
        raise NotImplementedError

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.

        Under ``torch.jit.script`` only ``net_input["src_tokens"]`` and
        ``net_input["src_lengths"]`` are passed to :meth:`forward`, and both
        keys must be present. In eager mode the call is delegated to
        :meth:`forward_non_torchscript`.

        Args:
            net_input: mapping of input names to tensors.

        Returns:
            Whatever the subclass's :meth:`forward` returns.
        """
        if torch.jit.is_scripting():
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused
    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
        """Eager-mode forward: pass every input except ``prev_output_tokens``.

        Each remaining entry of ``net_input`` is forwarded to :meth:`forward`
        as a keyword argument. The method is marked ``@torch.jit.unused`` so
        the TorchScript compiler does not try to compile the ``**`` unpacking.
        """
        # ``prev_output_tokens`` is presumably decoder-side input; it is
        # stripped so it never reaches the encoder's forward().
        encoder_input = {
            k: v for k, v in net_input.items() if k != "prev_output_tokens"
        }
        return self.forward(**encoder_input)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        The base implementation returns ``state_dict`` unchanged.
        """
        return state_dict

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update.

        Calls ``set_num_updates(num_updates)`` on every module visited by
        :meth:`nn.Module.apply` that defines such a method, except this
        encoder itself.
        """

        def _apply(m):
            # ``apply`` also visits ``self``; calling our own method again
            # would recurse forever. Use an identity test: ``!=`` would
            # dispatch to a submodule's (possibly overridden) ``__eq__`` and
            # could wrongly skip it.
            if hasattr(m, "set_num_updates") and m is not self:
                m.set_num_updates(num_updates)

        self.apply(_apply)