import torch
import torch.nn as nn
import torch.nn.functional as F


class top_k_gating(nn.Module):
    def __init__(
        self,
        input_size,
        num_experts,
        top_k,
    ):
        """
        Initialize the top-k gating mechanism.

        Args:
            input_size (int): Size of the input.
            num_experts (int): Number of experts.
            top_k (int): Number of top experts to select per token.
        """
        super().__init__()
        self.num_experts = num_experts
        self.input_size = input_size
        assert top_k <= num_experts
        self.top_k = top_k

        self.layer = nn.Linear(input_size, num_experts, bias=False)

    def extra_repr(self):
        """
        Return an extra representation string for the module.
        """
        return 'k={}, num_experts={}'.format(self.top_k, self.num_experts)

    def compute_aux_loss(self, probs, logits, gates):
        """
        Compute the auxiliary load-balancing loss from the gating statistics.

        Args:
            probs (torch.Tensor): Softmax probabilities over all experts,
                shape [batch_size, num_experts].
            logits (torch.Tensor): Raw gating logits,
                shape [batch_size, num_experts].
            gates (torch.Tensor): Sparse gate values (zero for unselected
                experts), shape [batch_size, num_experts].

        Returns:
            torch.Tensor: The auxiliary loss (switch loss plus z-loss).
        """
        count = logits.size(0)
        # Total routing probability mass per expert.
        probs = probs.sum(0)
        # Expert usage frequency: how many tokens selected each expert.
        freq = (gates > 0).float().sum(0)
        # Squared log-sum-exp of the logits (z-loss term); logsumexp is
        # numerically stabler than log(exp(logits).sum()).
        lsesq = (torch.logsumexp(logits, dim=-1) ** 2).sum()

        # Switch-style load-balancing loss: dot product of the L1-normalized
        # probability mass and the L1-normalized usage frequency per expert.
        switchloss = self.num_experts * (
            F.normalize(probs, p=1, dim=0) *
            F.normalize(freq, p=1, dim=0)
        ).sum()
        zloss = lsesq / count
        loss = switchloss + 0.1 * zloss

        return loss

    def forward(self, x):
        """
        Compute the top-k gating for the input.

        See paper: https://arxiv.org/abs/1701.06538.

        Args:
            x (torch.Tensor): Input tensor with shape [batch_size, input_size].

        Returns:
            torch.Tensor: Top-k expert indices, shape [batch_size, top_k].
            torch.Tensor: Top-k gating values, shape [batch_size, top_k].
        """
        logits = self.layer(x).float()
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)
        top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x)

        if self.training:
            probs = torch.softmax(logits, dim=1)
            # Scatter the top-k gate values back into a dense
            # [batch_size, num_experts] tensor, zero for unselected experts.
            zeros = torch.zeros_like(probs, dtype=top_k_gates.dtype)
            gates = zeros.scatter(1, top_k_indices, top_k_gates)
            self.loss = self.compute_aux_loss(probs, logits, gates)
        else:
            self.loss = 0

        return top_k_indices, top_k_gates
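

# A minimal usage sketch, not part of the module itself: the sizes below
# (batch_size=4, input_size=16, num_experts=8, top_k=2) are illustrative
# assumptions, not defaults taken from the source.
if __name__ == "__main__":
    gate = top_k_gating(input_size=16, num_experts=8, top_k=2)
    x = torch.randn(4, 16)

    gate.train()
    indices, gates = gate(x)
    print(indices.shape)  # torch.Size([4, 2]) -- chosen expert ids per token
    print(gates.shape)    # torch.Size([4, 2]) -- renormalized gate weights
    print(gate.loss)      # auxiliary loss; only populated in training mode

    # The gate values are a softmax over the selected top-k logits,
    # so each row sums to 1.
    assert torch.allclose(gates.sum(dim=1), torch.ones(4))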