| """Pruned FlexOlmo model with variable-width expert 1. |
| |
| This module provides a HuggingFace-compatible model that can be loaded with: |
| AutoModelForCausalLM.from_pretrained("hbfreed/flex-math-8192", trust_remote_code=True) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import FlexOlmoForCausalLM |
| from transformers.models.flex_olmo.modeling_flex_olmo import FlexOlmoMLP |
|
|
| from .configuration_pruned_flex_olmo import PrunedFlexOlmoConfig |
|
|
|
|
class PrunedFlexOlmoMLP(nn.Module):
    """Drop-in replacement for FlexOlmoMLP with a configurable width.

    Computes the same SwiGLU-style feed-forward as FlexOlmoMLP,
    ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``, but takes
    ``intermediate_size`` at construction time so a pruned expert can be
    narrower than the model's default intermediate size.

    Args:
        intermediate_size: Width of the gate/up projections (the pruned size).
        hidden_size: Model hidden dimension (input/output width).
        act_fn: Activation applied to the gate projection.
        dtype: Parameter dtype for all three projections (default bfloat16).
    """

    def __init__(self, intermediate_size: int, hidden_size: int, act_fn, dtype=torch.bfloat16):
        super().__init__()

        # All three projections are bias-free linears in the requested dtype,
        # matching the FlexOlmoMLP parameter layout (and state_dict keys).
        def _linear(fan_in: int, fan_out: int) -> nn.Linear:
            return nn.Linear(fan_in, fan_out, bias=False, dtype=dtype)

        self.gate_proj = _linear(hidden_size, intermediate_size)
        self.up_proj = _linear(hidden_size, intermediate_size)
        self.down_proj = _linear(intermediate_size, hidden_size)
        self.act_fn = act_fn

    def forward(self, x):
        """Apply the gated feed-forward: down(act(gate(x)) * up(x))."""
        gated = self.act_fn(self.gate_proj(x))
        mixed = gated * self.up_proj(x)
        return self.down_proj(mixed)
|
|
|
|
class PrunedFlexOlmoForCausalLM(FlexOlmoForCausalLM):
    """FlexOlmo with expert 1 pruned to a narrower intermediate size.

    Expert 0 keeps the full intermediate size from the base config, while
    expert 1 in every decoder layer is replaced with a
    :class:`PrunedFlexOlmoMLP` whose width is
    ``config.expert_1_intermediate_size``. The swap happens in ``__init__``,
    so when ``from_pretrained`` subsequently loads a pruned checkpoint the
    narrower weight shapes match directly.

    NOTE(review): the previous ``from_pretrained`` override was a pure
    pass-through to ``super().from_pretrained`` (its docstring claimed extra
    local/hub handling it did not perform), so it has been removed; the
    inherited method behaves identically for all callers.
    """

    config_class = PrunedFlexOlmoConfig

    def __init__(self, config: PrunedFlexOlmoConfig):
        super().__init__(config)

        expert_1_width = config.expert_1_intermediate_size
        hidden_size = config.hidden_size

        # Replace expert 1 in every layer with a narrower MLP. The activation
        # is taken from the expert being replaced so the nonlinearity (and
        # therefore the math, apart from width) stays identical.
        for layer in self.model.layers:
            act_fn = layer.mlp.experts[1].act_fn
            layer.mlp.experts[1] = PrunedFlexOlmoMLP(
                intermediate_size=expert_1_width,
                hidden_size=hidden_size,
                act_fn=act_fn,
                dtype=self.dtype,
            )
|
|