Spaces:
Sleeping
Sleeping
Audio-Deepfake-Detection
/
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1
/fairseq
/models
/composite_encoder.py
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
from .fairseq_encoder import FairseqEncoder | |
class CompositeEncoder(FairseqEncoder):
    """
    A wrapper around a dictionary of :class:`FairseqEncoder` objects.

    We run forward on each encoder and return a dictionary of outputs, keyed
    the same way as the wrapped encoders. The first encoder's dictionary is
    used for initialization.

    Args:
        encoders (dict): a non-empty dictionary mapping names to
            :class:`FairseqEncoder` objects. Each key is passed to
            ``add_module`` as the submodule name.

    Raises:
        ValueError: if *encoders* is empty.
    """

    def __init__(self, encoders):
        if not encoders:
            # Without this guard, ``next(iter(...))`` below raises a bare
            # StopIteration, which hides the real cause from the caller.
            raise ValueError("CompositeEncoder requires at least one encoder")
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        # Register every encoder under its key. This presumably relies on
        # torch's nn.Module.add_module, so the wrapped encoders' parameters
        # become part of this module -- confirm against FairseqEncoder's base.
        for key, encoder in self.encoders.items():
            self.add_module(key, encoder)

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict: maps each encoder's key to that encoder's output for the
            same ``(src_tokens, src_lengths)`` input.
        """
        return {
            key: encoder(src_tokens, src_lengths)
            for key, encoder in self.encoders.items()
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to *new_order*.

        Each entry belonging to a wrapped encoder is reordered by that
        encoder's own ``reorder_encoder_out``. The result is a shallow copy
        of *encoder_out* (same mapping type), so the caller's dict is not
        modified. Keys that do not name a wrapped encoder are carried over
        unchanged.

        Args:
            encoder_out (dict): output previously returned by :meth:`forward`.
            new_order: the new ordering, forwarded unchanged to every wrapped
                encoder.

        Returns:
            dict: the reordered outputs.
        """
        reordered = encoder_out.copy()
        for key, encoder in self.encoders.items():
            reordered[key] = encoder.reorder_encoder_out(encoder_out[key], new_order)
        return reordered

    def max_positions(self):
        """Return the smallest ``max_positions()`` among the wrapped encoders.

        Every encoder sees the same input, so the composite can only accept
        what the most restrictive encoder accepts.
        """
        return min(encoder.max_positions() for encoder in self.encoders.values())

    def upgrade_state_dict(self, state_dict):
        """Let each wrapped encoder upgrade *state_dict*, then return it.

        The wrapped encoders' return values are ignored. This assumes they
        upgrade *state_dict* in place (TODO confirm for every encoder type).
        """
        for encoder in self.encoders.values():
            encoder.upgrade_state_dict(state_dict)
        return state_dict