"""
Patches to support multipack for mixtral
"""
from typing import Tuple

import torch


def _mlp_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Single-argument forward for one Mixtral expert MLP.

    Computes ``w2(act_fn(w1(x)) * w3(x))``. Routing weights are not applied
    here; ``_moe_forward`` scales each expert's output itself.

    Args:
        hidden_states: rows routed to this expert (may be an empty batch).

    Returns:
        The expert output for those rows.
    """
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
        hidden_states
    )
    return self.w2(current_hidden_states)


# Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def _moe_forward(
    self, hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Top-k sparse MoE forward that invokes every expert on every call.

    Reads ``self.gate``, ``self.experts``, ``self.top_k`` and
    ``self.num_experts`` from the patched MoE block.

    Args:
        hidden_states: ``(batch, seq_len, hidden_dim)`` input.

    Returns:
        ``(final_hidden_states, router_logits)``. ``final_hidden_states`` has
        the input's shape. ``router_logits`` is the raw gate output for the
        flattened ``(batch * seq_len, ...)`` tokens.
    """
    import torch.nn.functional as F

    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    # Softmax in float32, keep the top_k experts per token, and renormalise
    # the kept weights so they sum to 1 for each token.
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(
        routing_weights, self.top_k, dim=-1, sorted=False
    )
    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    topk_weight = topk_weight.to(hidden_states.dtype)

    # One row per (token, selected expert) pair, token-major: row r belongs to
    # token r // top_k, matching the layout of topk_idx.view(-1).
    expanded = hidden_states.repeat_interleave(self.top_k, dim=0)
    expert_out = torch.empty_like(expanded)
    flat_topk_idx = topk_idx.view(-1)
    # Every expert is called, even when no token selected it (empty slice).
    # NOTE(review): presumably this is what keeps ZeRO-3 parameter gathering
    # in lockstep across ranks -- do not add an "if mask.any()" skip.
    for i in range(self.num_experts):
        mask = flat_topk_idx == i
        expert_out[mask] = self.experts[i](expanded[mask])

    # Weighted sum of each token's top_k expert outputs.
    combined = (
        expert_out.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
    ).sum(dim=1)
    final_hidden_states = combined.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits


def patch_mixtral_moe_forward_zero3() -> None:
    """Monkeypatch transformers' Mixtral MoE classes (ZeRO-3 friendly forward).

    Replaces ``MixtralSparseMoeBlock.forward`` with ``_moe_forward`` and the
    expert MLP class's ``forward`` with ``_mlp_forward``. These are
    class-attribute assignments, so they affect every instance of those
    classes in the process.

    NOTE(review): newer transformers versions may add extra behaviour to the
    stock MoE forward (e.g. router jitter noise). This replacement does not
    reproduce it, so confirm against the installed version.

    Raises:
        ImportError: if transformers or the Mixtral classes are unavailable.
    """
    # The expert class was originally spelled ``MixtralBLockSparseTop2MLP``
    # and later renamed to ``MixtralBlockSparseTop2MLP``. Patch the current
    # name when it exists: after the rename the old name is at most an unused
    # deprecated alias, so patching only it would have no effect.
    try:
        from transformers.models.mixtral.modeling_mixtral import (
            MixtralBlockSparseTop2MLP as mlp_cls,
        )
    except ImportError:
        from transformers.models.mixtral.modeling_mixtral import (
            MixtralBLockSparseTop2MLP as mlp_cls,
        )
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    mlp_cls.forward = _mlp_forward
    MixtralSparseMoeBlock.forward = _moe_forward