Spaces:
Runtime error
Runtime error
OFA-OCR-dedao-demo001
/
fairseq
/examples
/translation_moe
/translation_moe_src
/mean_pool_gating_network.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
import torch.nn.functional as F | |
class MeanPoolGatingNetwork(torch.nn.Module):
    """A simple mean-pooling gating network for selecting experts.

    This module applies mean pooling over an encoder's output and returns
    responsibilities (log-probabilities) for each expert. The encoder format
    is expected to match
    :class:`fairseq.models.transformer.TransformerEncoder`.

    Args:
        embed_dim: size of the last (channel) dimension of the encoder output;
            also the width of the hidden layer.
        num_experts: number of experts, i.e. the size of the output
            distribution.
        dropout: dropout probability applied after the hidden layer, or
            ``None`` to disable dropout entirely.
    """

    def __init__(self, embed_dim, num_experts, dropout=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts

        self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
        self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
        self.fc2 = torch.nn.Linear(embed_dim, num_experts)

    def forward(self, encoder_out):
        """Compute per-expert log-probabilities from an encoder output.

        Args:
            encoder_out: mapping with an ``"encoder_out"`` entry (a list whose
                first element is a ``T x B x C`` tensor with
                ``C == embed_dim``) and an ``"encoder_padding_mask"`` entry
                (a list whose first element is a boolean ``B x T`` mask that
                is True at padded positions, or ``None``; an empty list is
                treated like ``None``).

        Returns:
            A ``B x num_experts`` tensor of log-probabilities. The softmax is
            computed in float32 and cast back to the dtype of the logits.

        Raises:
            ValueError: if a required key is missing or the channel dimension
                does not equal ``embed_dim``.
        """
        if not (
            "encoder_out" in encoder_out
            and "encoder_padding_mask" in encoder_out
            and encoder_out["encoder_out"][0].size(2) == self.embed_dim
        ):
            raise ValueError("Unexpected format for encoder_out")

        # An empty mask list is handled as "no padding" instead of raising
        # IndexError on the [0] lookup.
        padding_masks = encoder_out["encoder_padding_mask"]
        encoder_padding_mask = (
            padding_masks[0] if len(padding_masks) > 0 else None
        )  # B x T
        features = encoder_out["encoder_out"][0].transpose(0, 1)  # B x T x C

        # mean pooling over time
        if encoder_padding_mask is not None:
            # Out-of-place fill: zeroes padded positions without mutating the
            # caller's tensor (the transpose above is only a view of it).
            features = features.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)
            ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
            # A fully padded row would otherwise give 0 / 0 = NaN; clamping
            # makes such a row pool to the zero vector instead.
            ntokens = ntokens.clamp(min=1)
            x = torch.sum(features, dim=1) / ntokens.type_as(features)
        else:
            x = torch.mean(features, dim=1)

        x = torch.tanh(self.fc1(x))
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)