mixtral_small_dummy / custom_mixtral.py
OsakanaTeishoku's picture
Upload CustomMixtralForCausalLM
3cdcba2 verified
raw
history blame contribute delete
No virus
5.39 kB
from transformers import MixtralForCausalLM, MixtralConfig
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoECausalLMOutputWithPast
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock, load_balancing_loss_func
from .noisy_gate import NoisyGate
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
def router_z_loss_func(
    gate_logits: Optional[Tuple[torch.Tensor, ...]],
    num_experts: torch.Tensor = None,
    top_k=2,
) -> Union[torch.Tensor, int]:
    """Compute the router z-loss used in ST-MoE.

    The per-layer router logits are moved to the device of the first entry,
    concatenated along dim 0, and reduced as
    ``mean(logsumexp(logits, dim=-1) ** 2)``.

    Args:
        gate_logits: Tuple of router-logit tensors (one entry per layer is
            what the caller in this file passes). The entries must be
            concatenable along dim 0.
        num_experts: Unused. Kept so the call signature stays parallel to the
            load-balancing loss helper.
        top_k: Unused. Kept for the same reason as ``num_experts``.

    Returns:
        A scalar tensor holding the z-loss, or the int ``0`` when
        ``gate_logits`` is ``None``, is not a tuple, or is an empty tuple.
    """
    # ``None`` is not a tuple, so one isinstance check covers both cases.
    # An empty tuple is rejected here as well: indexing ``gate_logits[0]``
    # below would otherwise raise IndexError.
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    # Layers may live on different devices under model parallelism; gather
    # everything onto the first layer's device before concatenating.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat(
        [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
    )

    log_z = torch.logsumexp(concatenated_gate_logits, dim=-1)
    return torch.square(log_z).mean()
class CustomMixtralConfig(MixtralConfig):
    """Configuration class for the custom Mixtral model.

    Adds no fields of its own: it only subclasses ``MixtralConfig`` so the
    custom model has a distinct configuration type.
    """

    def __init__(self, **kwargs):
        # Every keyword argument is handed to the parent constructor as-is.
        super(CustomMixtralConfig, self).__init__(**kwargs)
class CustomMixtralForCausalLM(MixtralForCausalLM):
    """Mixtral with z-loss. Gating improvement based on ST-MoE.

    Differences from the parent class visible in this file:

    * every decoder layer's ``block_sparse_moe.gate`` is replaced by a
      ``NoisyGate``;
    * when router logits are requested, a router z-loss is computed next to
      the load-balancing auxiliary loss, and both are added to the language
      modelling loss when ``labels`` are given.
    """

    def __init__(self, config):
        """Build the parent model, then swap in noisy gates.

        Args:
            config: Model configuration. ``hidden_size`` and
                ``num_local_experts`` are read to size the gates. An optional
                ``router_z_loss_coef`` attribute overrides the z-loss weight
                (default ``1e-3``).
        """
        super().__init__(config)
        # Weight of the z-loss term in the total loss. Configurable, with the
        # previous hard-coded value as the default.
        self.router_z_loss_coef = getattr(config, "router_z_loss_coef", 1e-3)
        # NOTE(review): the gates are replaced after the parent constructor
        # has finished, so their parameters are whatever NoisyGate's own
        # constructor produces -- confirm that pretrained gate weights load
        # into NoisyGate as intended.
        for layer in self.model.layers:
            layer.block_sparse_moe.gate = NoisyGate(
                config.hidden_size,
                config.num_local_experts,
                noise_mult=1.0,
                bias=False,
            )

    @staticmethod
    def _add_router_loss(loss, coef, term):
        """Return ``loss + coef * term`` with ``term`` moved to ``loss``'s device.

        ``term`` may be a plain number: ``router_z_loss_func`` returns the int
        ``0`` when it gets no usable logits, and an int has no ``.to``.
        """
        if isinstance(term, torch.Tensor):
            term = term.to(loss.device)  # make sure to reside in the same device
        return loss + coef * term

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoECausalLMOutputWithPast]:
        """Run the decoder and LM head, optionally computing losses.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache: Forwarded unchanged to ``self.model``.
            labels: If given, a next-token cross-entropy loss is computed on
                the logits shifted by one position.
            output_attentions, output_hidden_states, output_router_logits,
            return_dict: Fall back to the matching ``self.config`` values when
                ``None`` (``return_dict`` falls back to ``use_return_dict``).

        Returns:
            With ``return_dict``: a ``MoECausalLMOutputWithPast`` carrying
            ``loss``, ``aux_loss``, ``z_loss``, ``logits`` and the decoder
            outputs. Otherwise a tuple
            ``(loss?, aux_loss, router_z_loss, logits, *outputs[1:])`` when
            router logits are requested, or ``(loss?, logits, *outputs[1:])``
            when they are not; ``loss`` is present only if it was computed.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        aux_loss = None
        router_z_loss = None
        if output_router_logits:
            # In tuple mode the router logits are the last decoder output.
            router_logits = outputs.router_logits if return_dict else outputs[-1]
            aux_loss = load_balancing_loss_func(
                router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            # NOTE(review): unlike the load-balancing loss, the z-loss is not
            # given the attention mask, so every position in the router
            # logits contributes to it.
            router_z_loss = router_z_loss_func(
                router_logits,
                self.num_experts,
                self.num_experts_per_tok,
            )
            if labels is not None:
                loss = self._add_router_loss(loss, self.router_aux_loss_coef, aux_loss)
                loss = self._add_router_loss(loss, self.router_z_loss_coef, router_z_loss)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                # Prepend so the final order is (aux_loss, router_z_loss, logits, ...).
                output = (router_z_loss,) + output
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoECausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            z_loss=router_z_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )