GiuliannoV commited on
Commit
eb27551
·
verified ·
1 Parent(s): 4603fe4

Release Metis-1.4 base checkpoint

Browse files

Upload BF16 safetensors base checkpoint, tokenizer/config assets, training summary, model card, and custom Metis runtime source.

README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - metis
9
+ - causal-lm
10
+ - base-model
11
+ - mixture-of-recursion
12
+ - static-sequence-mor
13
+ - gqa
14
+ - swiglu
15
+ - bf16
16
+ - fp8-training
17
+ datasets:
18
+ - epfml/FineWeb-HQ
19
+ - HuggingFaceFW/fineweb-edu
20
+ - HuggingFaceTB/finemath
21
+ - deepmind/pg19
22
+ - HuggingFaceTB/smollm-corpus
23
+ - math-ai/TemplateGSM
24
+ ---
25
+
26
+ # Metis-1.4 Base
27
+
28
+ Metis-1.4 Base is the pre/post-instruction base checkpoint for the Metis-1.4 line from Lernex. It is a compact ~504M parameter English causal language model built around a Metis Mixture-of-Recursion transformer: a 19-layer shared stack with up to 3 recursive depth passes, grouped-query attention, RMSNorm, RoPE, SwiGLU MLPs, and a 16k tokenizer.
29
+
30
+ This release is the base model artifact, exported before Chat SFT, Reasoning SFT, reward modeling, and DPO. It is meant as the foundation checkpoint for research, continuation, post-training, and reproducibility. It is not expected to behave like a polished assistant by itself.
31
+
32
+ ## Release Snapshot
33
+
34
+ - Repository: `Lernex/Metis-1.4-base`
35
+ - Artifact: `model.safetensors`
36
+ - Export dtype: BF16 weights
37
+ - Parameters: ~503.8M
38
+ - Context length: 1024 tokens
39
+ - Vocabulary: 16,384 tokens
40
+ - Attention: 24 query heads, 8 KV heads, head dimension 64
41
+ - Base width/depth: d_model 1536, 19 shared layers
42
+ - MoR depth: max 3 recursive passes, static sequence MoR selected for the continued-pretraining checkpoint
43
+ - Final base selection: static sequence MoR continued-pretraining checkpoint at step 5000
44
+ - Validation at selection: train 3.6065, val 3.6026, ppl 36.69
45
+
46
+ ## Training Notes
47
+
48
+ The base training path used a speed-first dense pretraining phase followed by a static sequence MoR continued-pretraining phase. Dense pretraining avoided dynamic routing in the expensive hot path so the H100 could run clean fused kernels; MoR was then reintroduced during continued pretraining as a fixed-shape sequence-level route rather than dynamic token packing.
49
+
50
+ The exported checkpoint comes from the winning continued-pretraining branch. A static block MoR probe was also run from the same dense base checkpoint, but the sequence route won at the step-5000 comparison:
51
+
52
+ | Branch | Step | Train | Val | PPL |
53
+ | --- | ---: | ---: | ---: | ---: |
54
+ | static sequence MoR | 5000 | 3.6065 | 3.6026 | 36.69 |
55
+ | static block MoR | 5000 | 3.6090 | 3.6575 | 38.76 |
56
+
57
+ Approximate data mix:
58
+
59
+ - Base pretraining: FineWeb-HQ, FineWeb-Edu, FineMath, PG19, and Cosmopedia-style educational/explanatory text.
60
+ - Continued pretraining: FineMath, TemplateGSM natural-language solutions, PG19, DCLM-Edu-style explainers, FineWeb-HQ explainers, and Cosmopedia-style filtered text.
61
+
62
+ ## Intended Use
63
+
64
+ This checkpoint is intended for:
65
+
66
+ - Continued pretraining and post-training experiments.
67
+ - Small-model research around efficient recursion/routing strategies.
68
+ - Reproducibility for the Metis-1.4 chat and think variants.
69
+ - Educational experiments where a compact base model is more useful than a large opaque artifact.
70
+
71
+ It is not intended as a safety-aligned public assistant without additional post-training.
72
+
73
+ ## Loading
74
+
75
+ Metis-1.4 uses a custom architecture. The model repository includes the `metis_mamba` runtime source used by this release. In the training workspace, the exported model is loaded with:
76
+
77
+ ```python
78
+ import torch
79
+ from tokenizers import Tokenizer
80
+ from metis_mamba.runtime import load_exported_model, generate_completion
81
+
82
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
+ model = load_exported_model("path/to/Metis-1.4-base", device=device)
84
+ tokenizer = Tokenizer.from_file("path/to/Metis-1.4-base/tokenizer.json")
85
+
86
+ print(generate_completion(model, tokenizer, "The fastest way to learn algebra is", device))
87
+ ```
88
+
89
+ For H100 training and fast inference, FlashAttention-3 and NVIDIA Transformer Engine were used in the training stack. CPU and non-Hopper environments fall back to PyTorch attention paths for basic loading/inference, but performance will be much lower.
90
+
91
+ ## Files
92
+
93
+ - `model.safetensors`: BF16 exported weights
94
+ - `config.json`: Metis architecture/configuration
95
+ - `generation_config.json`: default generation settings
96
+ - `tokenizer.json`: tokenizer model
97
+ - `tokenizer_config.json`: tokenizer metadata and chat template
98
+ - `special_tokens_map.json`: special token mapping
99
+ - `training_summary.json`: compact release/training summary
100
+ - `runtime_requirements.txt`: minimal runtime package hints
101
+ - `metis_mamba/`: custom runtime source for this architecture
102
+
103
+ ## Limitations
104
+
105
+ - This is a small ~500M model. It should not be compared to large frontier assistants.
106
+ - It is a base model, not a post-trained chat or reasoning release.
107
+ - The context window is 1024 tokens.
108
+ - It may produce incorrect, biased, unsafe, or nonsensical text.
109
+ - It has not yet been reported against the final Metis-1.4 benchmark suite.
110
+
111
+ ## Citation
112
+
113
+ If you use this checkpoint in a writeup, please cite it as:
114
+
115
+ ```bibtex
116
+ @misc{metis14base2026,
117
+ title = {Metis-1.4 Base},
118
+ author = {Lernex},
119
+ year = {2026},
120
+ publisher = {Hugging Face},
121
+ howpublished = {\url{https://huggingface.co/Lernex/Metis-1.4-base}}
122
+ }
123
+ ```
config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MetisMoRLMHeadModel"
4
+ ],
5
+ "model_type": "metis_mor_transformer",
6
+ "name": "Metis-1.4",
7
+ "architecture": "metis_mor_decoder",
8
+ "vocab_size": 16384,
9
+ "block_size": 1024,
10
+ "d_model": 1536,
11
+ "n_layer": 19,
12
+ "n_heads": 24,
13
+ "n_kv_heads": 8,
14
+ "head_dim": 64,
15
+ "intermediate_size": 4096,
16
+ "hidden_act": "swiglu",
17
+ "attn_cfg": {},
18
+ "bos_token_id": 1,
19
+ "eos_token_id": 2,
20
+ "pad_token_id": 0,
21
+ "unk_token_id": 3,
22
+ "rms_norm": true,
23
+ "residual_in_fp32": false,
24
+ "fused_add_norm": false,
25
+ "pad_vocab_size_multiple": 16,
26
+ "tie_embeddings": true,
27
+ "torch_dtype": "bfloat16",
28
+ "estimated_params": 503772163,
29
+ "attention_bias": false,
30
+ "mlp_bias": false,
31
+ "attention_dropout": 0.0,
32
+ "rope_theta": 10000.0,
33
+ "attention_backend": "flash_attention_3",
34
+ "fp8_pad_multiple": 16,
35
+ "mor_max_depth": 3,
36
+ "mor_router_hidden_dim": 256,
37
+ "mor_router_temperature": 1.0,
38
+ "mor_router_aux_loss_coef": 0.01,
39
+ "mor_target_avg_depth": 1.5,
40
+ "effective_layer_count": 57,
41
+ "target_effective_layer_count": 28.5
42
+ }
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": 2,
4
+ "pad_token_id": 0,
5
+ "do_sample": true,
6
+ "temperature": 0.7,
7
+ "top_p": 0.95,
8
+ "max_new_tokens": 256
9
+ }
metis_mamba/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .config import MetisMambaConfig
2
+ from .model import MetisMoRLMHeadModel
3
+ from .runtime import (
4
+ build_reward_model,
5
+ build_model,
6
+ cosine_lr,
7
+ encode_prompt,
8
+ export_checkpoint_to_dir,
9
+ generate_completion,
10
+ load_checkpoint_model,
11
+ load_exported_model,
12
+ parse_torch_dtype,
13
+ )
14
+
15
+ __all__ = [
16
+ "MetisMambaConfig",
17
+ "MetisMoRLMHeadModel",
18
+ "build_reward_model",
19
+ "build_model",
20
+ "cosine_lr",
21
+ "encode_prompt",
22
+ "export_checkpoint_to_dir",
23
+ "generate_completion",
24
+ "load_checkpoint_model",
25
+ "load_exported_model",
26
+ "parse_torch_dtype",
27
+ ]
metis_mamba/checkpoint_compat.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping
4
+
5
+ import torch
6
+
7
+
8
+ def fuse_legacy_metis_state_dict(
9
+ state_dict: Mapping[str, torch.Tensor],
10
+ *,
11
+ consume_legacy: bool = True,
12
+ ) -> tuple[dict[str, torch.Tensor], list[str]]:
13
+ """Map older unfused Metis checkpoints onto the fused QKV/SwiGLU layout."""
14
+ converted = dict(state_dict)
15
+ conversions: list[str] = []
16
+
17
+ q_suffix = ".self_attn.q_proj.impl.weight"
18
+ for q_key in list(converted):
19
+ if not q_key.endswith(q_suffix):
20
+ continue
21
+ prefix = q_key[: -len("q_proj.impl.weight")]
22
+ k_key = f"{prefix}k_proj.impl.weight"
23
+ v_key = f"{prefix}v_proj.impl.weight"
24
+ fused_key = f"{prefix}qkv_proj.impl.weight"
25
+ if fused_key in converted or k_key not in converted or v_key not in converted:
26
+ continue
27
+ converted[fused_key] = torch.cat([converted[q_key], converted[k_key], converted[v_key]], dim=0)
28
+ conversions.append(f"{prefix}qkv_proj.impl.weight")
29
+
30
+ q_bias_key = f"{prefix}q_proj.impl.bias"
31
+ k_bias_key = f"{prefix}k_proj.impl.bias"
32
+ v_bias_key = f"{prefix}v_proj.impl.bias"
33
+ fused_bias_key = f"{prefix}qkv_proj.impl.bias"
34
+ if (
35
+ fused_bias_key not in converted
36
+ and q_bias_key in converted
37
+ and k_bias_key in converted
38
+ and v_bias_key in converted
39
+ ):
40
+ converted[fused_bias_key] = torch.cat(
41
+ [converted[q_bias_key], converted[k_bias_key], converted[v_bias_key]],
42
+ dim=0,
43
+ )
44
+
45
+ if consume_legacy:
46
+ for key in (q_key, k_key, v_key, q_bias_key, k_bias_key, v_bias_key):
47
+ converted.pop(key, None)
48
+
49
+ gate_suffix = ".mlp.gate_proj.impl.weight"
50
+ for gate_key in list(converted):
51
+ if not gate_key.endswith(gate_suffix):
52
+ continue
53
+ prefix = gate_key[: -len("gate_proj.impl.weight")]
54
+ up_key = f"{prefix}up_proj.impl.weight"
55
+ fused_key = f"{prefix}gate_up_proj.impl.weight"
56
+ if fused_key in converted or up_key not in converted:
57
+ continue
58
+ converted[fused_key] = torch.cat([converted[gate_key], converted[up_key]], dim=0)
59
+ conversions.append(f"{prefix}gate_up_proj.impl.weight")
60
+
61
+ gate_bias_key = f"{prefix}gate_proj.impl.bias"
62
+ up_bias_key = f"{prefix}up_proj.impl.bias"
63
+ fused_bias_key = f"{prefix}gate_up_proj.impl.bias"
64
+ if fused_bias_key not in converted and gate_bias_key in converted and up_bias_key in converted:
65
+ converted[fused_bias_key] = torch.cat([converted[gate_bias_key], converted[up_bias_key]], dim=0)
66
+
67
+ if consume_legacy:
68
+ for key in (gate_key, up_key, gate_bias_key, up_bias_key):
69
+ converted.pop(key, None)
70
+
71
+ norm_suffixes = (".ffn_norm.weight", ".ffn_norm.impl.weight")
72
+ for norm_key in list(converted):
73
+ matching_suffix = next((suffix for suffix in norm_suffixes if norm_key.endswith(suffix)), None)
74
+ if matching_suffix is None:
75
+ continue
76
+ prefix = norm_key[: -len(matching_suffix.lstrip("."))]
77
+ fused_norm_key = f"{prefix}mlp.impl.layer_norm_weight"
78
+ if fused_norm_key not in converted:
79
+ converted[fused_norm_key] = converted[norm_key]
80
+ conversions.append(f"{prefix}mlp.impl.layer_norm_weight")
81
+
82
+ fused_gate_suffix = ".mlp.gate_up_proj.impl.weight"
83
+ for gate_up_key in list(converted):
84
+ if not gate_up_key.endswith(fused_gate_suffix):
85
+ continue
86
+ prefix = gate_up_key[: -len("mlp.gate_up_proj.impl.weight")]
87
+ fused_key = f"{prefix}mlp.impl.fc1_weight"
88
+ if fused_key not in converted:
89
+ converted[fused_key] = converted[gate_up_key]
90
+ conversions.append(f"{prefix}mlp.impl.fc1_weight")
91
+
92
+ gate_up_bias_key = f"{prefix}mlp.gate_up_proj.impl.bias"
93
+ fused_bias_key = f"{prefix}mlp.impl.fc1_bias"
94
+ if fused_bias_key not in converted and gate_up_bias_key in converted:
95
+ converted[fused_bias_key] = converted[gate_up_bias_key]
96
+
97
+ down_suffix = ".mlp.down_proj.impl.weight"
98
+ for down_key in list(converted):
99
+ if not down_key.endswith(down_suffix):
100
+ continue
101
+ prefix = down_key[: -len("mlp.down_proj.impl.weight")]
102
+ fused_key = f"{prefix}mlp.impl.fc2_weight"
103
+ if fused_key not in converted:
104
+ converted[fused_key] = converted[down_key]
105
+ conversions.append(f"{prefix}mlp.impl.fc2_weight")
106
+
107
+ down_bias_key = f"{prefix}mlp.down_proj.impl.bias"
108
+ fused_bias_key = f"{prefix}mlp.impl.fc2_bias"
109
+ if fused_bias_key not in converted and down_bias_key in converted:
110
+ converted[fused_bias_key] = converted[down_bias_key]
111
+
112
+ return converted, conversions
113
+
114
+
115
+ def filter_state_dict_for_model(
116
+ model: torch.nn.Module,
117
+ state_dict: Mapping[str, torch.Tensor],
118
+ ) -> tuple[dict[str, torch.Tensor], list[str]]:
119
+ converted, conversions = fuse_legacy_metis_state_dict(state_dict)
120
+ allowed = set(model.state_dict().keys())
121
+ return {name: tensor for name, tensor in converted.items() if name in allowed}, conversions
metis_mamba/config.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict, dataclass
4
+ from typing import Any
5
+
6
+
7
+ @dataclass
8
+ class MetisMambaConfig:
9
+ name: str = "Metis-1.4"
10
+ model_type: str = "metis_mor_transformer"
11
+ architecture: str = "metis_mor_decoder"
12
+ vocab_size: int = 16384
13
+ block_size: int = 1024
14
+ d_model: int = 1536
15
+ n_layer: int = 19
16
+ n_heads: int = 24
17
+ n_kv_heads: int = 8
18
+ head_dim: int = 64
19
+ intermediate_size: int = 4096
20
+ hidden_act: str = "swiglu"
21
+ tie_embeddings: bool = True
22
+ rms_norm: bool = True
23
+ residual_in_fp32: bool = False
24
+ fused_add_norm: bool = False
25
+ pad_vocab_size_multiple: int = 16
26
+ initializer_range: float = 0.02
27
+ torch_dtype: str = "bfloat16"
28
+ attention_bias: bool = False
29
+ mlp_bias: bool = False
30
+ attention_dropout: float = 0.0
31
+ rope_theta: float = 10000.0
32
+ attention_backend: str = "auto"
33
+ training_mode: str = "dynamic_token_mor"
34
+ mor_enabled: bool = True
35
+ mor_train_router: bool = True
36
+ mor_runtime_mode: str = "dynamic_token"
37
+ attention_mask_mode: str = "auto"
38
+ disable_depth_stack: bool = False
39
+ disable_token_packing: bool = False
40
+ disable_token_scatter: bool = False
41
+ debug_attention_backend: bool = False
42
+ debug_perf_counters: bool = False
43
+ cuda_graphs: bool = False
44
+ cuda_graph_scope: str = "none"
45
+ fp8_dpa: bool = False
46
+ fp8_mha: bool = False
47
+ te_fused_mlp: bool = False
48
+ lm_loss_impl: str = "standard"
49
+ mor_max_depth: int = 3
50
+ mor_router_hidden_dim: int = 256
51
+ mor_router_temperature: float = 1.0
52
+ mor_router_aux_loss_coef: float = 0.01
53
+ mor_target_avg_depth: float = 1.5
54
+ mor_depth2_capacity_sequences: int = 0
55
+ mor_depth3_capacity_sequences: int = 0
56
+ mor_block_size: int = 128
57
+ mor_depth2_capacity_blocks: int = 0
58
+ mor_depth3_capacity_blocks: int = 0
59
+ block_mor_attention_mode: str = "local_block_refinement"
60
+ fp8_pad_multiple: int = 16
61
+
62
+ @classmethod
63
+ def from_dict(cls, payload: dict[str, Any]) -> "MetisMambaConfig":
64
+ allowed = {field.name for field in cls.__dataclass_fields__.values()}
65
+ cooked = {key: value for key, value in payload.items() if key in allowed}
66
+ return cls(**cooked)
67
+
68
+ def validate(self) -> None:
69
+ if self.block_size <= 0:
70
+ raise ValueError("block_size must be positive.")
71
+ if self.d_model <= 0 or self.n_layer <= 0:
72
+ raise ValueError("d_model and n_layer must be positive.")
73
+ if self.n_heads <= 0 or self.n_kv_heads <= 0:
74
+ raise ValueError("n_heads and n_kv_heads must be positive.")
75
+ if self.n_heads % self.n_kv_heads != 0:
76
+ raise ValueError("n_heads must be divisible by n_kv_heads.")
77
+ if self.head_dim <= 0:
78
+ raise ValueError("head_dim must be positive.")
79
+ if self.n_heads * self.head_dim != self.d_model:
80
+ raise ValueError("n_heads * head_dim must equal d_model.")
81
+ if self.intermediate_size <= 0:
82
+ raise ValueError("intermediate_size must be positive.")
83
+ if self.hidden_act not in {"swiglu"}:
84
+ raise ValueError("hidden_act must currently be swiglu.")
85
+ if self.pad_vocab_size_multiple <= 0:
86
+ raise ValueError("pad_vocab_size_multiple must be positive.")
87
+ if self.attention_backend not in {"auto", "flash_attention_3", "sdpa", "eager"}:
88
+ raise ValueError("attention_backend must be one of: auto, flash_attention_3, sdpa, eager.")
89
+ if self.training_mode not in {
90
+ "dynamic_token_mor",
91
+ "static_dense_pretrain",
92
+ "static_sequence_mor",
93
+ "static_block_mor",
94
+ }:
95
+ raise ValueError(
96
+ "training_mode must be one of: dynamic_token_mor, static_dense_pretrain, "
97
+ "static_sequence_mor, static_block_mor."
98
+ )
99
+ if self.mor_runtime_mode not in {"dynamic_token", "disabled", "static_sequence", "static_block"}:
100
+ raise ValueError("mor_runtime_mode must be one of: dynamic_token, disabled, static_sequence, static_block.")
101
+ if self.attention_mask_mode not in {"auto", "causal_none"}:
102
+ raise ValueError("attention_mask_mode must be one of: auto, causal_none.")
103
+ if self.cuda_graph_scope not in {"none", "step", "microbatch"}:
104
+ raise ValueError("cuda_graph_scope must be one of: none, step, microbatch.")
105
+ if self.lm_loss_impl not in {"standard", "liger_fused_linear_ce"}:
106
+ raise ValueError("lm_loss_impl must be one of: standard, liger_fused_linear_ce.")
107
+ if self.rope_theta <= 0:
108
+ raise ValueError("rope_theta must be positive.")
109
+ if self.mor_max_depth <= 0:
110
+ raise ValueError("mor_max_depth must be positive.")
111
+ if self.mor_router_hidden_dim <= 0:
112
+ raise ValueError("mor_router_hidden_dim must be positive.")
113
+ if self.mor_router_temperature <= 0:
114
+ raise ValueError("mor_router_temperature must be positive.")
115
+ if self.mor_router_aux_loss_coef < 0:
116
+ raise ValueError("mor_router_aux_loss_coef cannot be negative.")
117
+ if self.mor_target_avg_depth < 1.0 or self.mor_target_avg_depth > float(self.mor_max_depth):
118
+ raise ValueError("mor_target_avg_depth must be within [1, mor_max_depth].")
119
+ if self.mor_depth2_capacity_sequences < 0 or self.mor_depth3_capacity_sequences < 0:
120
+ raise ValueError("sequence MoR capacities cannot be negative.")
121
+ if self.mor_block_size <= 0:
122
+ raise ValueError("mor_block_size must be positive.")
123
+ if self.mor_depth2_capacity_blocks < 0 or self.mor_depth3_capacity_blocks < 0:
124
+ raise ValueError("block MoR capacities cannot be negative.")
125
+ if self.block_mor_attention_mode not in {"local_block_refinement"}:
126
+ raise ValueError("block_mor_attention_mode must currently be local_block_refinement.")
127
+ if self.fp8_pad_multiple <= 0:
128
+ raise ValueError("fp8_pad_multiple must be positive.")
129
+
130
+ @property
131
+ def padded_vocab_size(self) -> int:
132
+ if self.vocab_size % self.pad_vocab_size_multiple == 0:
133
+ return self.vocab_size
134
+ return self.vocab_size + (
135
+ self.pad_vocab_size_multiple - (self.vocab_size % self.pad_vocab_size_multiple)
136
+ )
137
+
138
+ @property
139
+ def uses_mor(self) -> bool:
140
+ return self.mor_enabled and self.training_mode != "static_dense_pretrain"
141
+
142
+ @property
143
+ def uses_dynamic_token_mor(self) -> bool:
144
+ return self.uses_mor and self.training_mode == "dynamic_token_mor"
145
+
146
+ @property
147
+ def uses_static_sequence_mor(self) -> bool:
148
+ return self.uses_mor and self.training_mode == "static_sequence_mor"
149
+
150
+ @property
151
+ def uses_static_block_mor(self) -> bool:
152
+ return self.uses_mor and self.training_mode == "static_block_mor"
153
+
154
+ @property
155
+ def effective_layer_count(self) -> int:
156
+ if not self.uses_mor:
157
+ return self.n_layer
158
+ return self.n_layer * self.mor_max_depth
159
+
160
+ @property
161
+ def target_effective_layer_count(self) -> float:
162
+ if not self.uses_mor:
163
+ return float(self.n_layer)
164
+ return self.n_layer * self.mor_target_avg_depth
165
+
166
+ @property
167
+ def attn_cfg(self) -> dict[str, Any]:
168
+ return {
169
+ "causal": True,
170
+ "head_dim": self.head_dim,
171
+ "num_heads": self.n_heads,
172
+ "num_heads_kv": self.n_kv_heads,
173
+ "attention_bias": self.attention_bias,
174
+ "dropout": self.attention_dropout,
175
+ "rope_theta": self.rope_theta,
176
+ "backend": self.attention_backend,
177
+ }
178
+
179
+ def estimate_params(self) -> int:
180
+ embed_params = self.padded_vocab_size * self.d_model
181
+ q_proj = self.d_model * (self.n_heads * self.head_dim)
182
+ kv_proj = self.d_model * (2 * self.n_kv_heads * self.head_dim)
183
+ out_proj = (self.n_heads * self.head_dim) * self.d_model
184
+ attn_block = q_proj + kv_proj + out_proj
185
+ mlp_block = 3 * self.d_model * self.intermediate_size
186
+ norm_modules = (2 * self.n_layer) + 1 + (1 if self.uses_mor else 0)
187
+ norm_params = norm_modules * self.d_model
188
+ router_params = 0
189
+ if self.uses_mor:
190
+ router_params += (
191
+ self.d_model * self.mor_router_hidden_dim
192
+ + self.mor_router_hidden_dim
193
+ + self.mor_router_hidden_dim * self.mor_max_depth
194
+ + self.mor_max_depth
195
+ )
196
+ total = embed_params + (self.n_layer * (attn_block + mlp_block)) + norm_params + router_params
197
+ return int(total)
198
+
199
+ def to_dict(self) -> dict[str, Any]:
200
+ payload = asdict(self)
201
+ payload["estimated_params"] = self.estimate_params()
202
+ payload["attn_cfg"] = self.attn_cfg
203
+ payload["effective_layer_count"] = self.effective_layer_count
204
+ payload["target_effective_layer_count"] = self.target_effective_layer_count
205
+ return payload
metis_mamba/fp8.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from contextlib import nullcontext
4
+ from functools import lru_cache
5
+ from inspect import signature
6
+
7
+ from torch import nn
8
+
9
+
10
+ @lru_cache(maxsize=1)
11
+ def _load_transformer_engine():
12
+ import transformer_engine.pytorch as te # type: ignore
13
+ from transformer_engine.common.recipe import DelayedScaling, Format # type: ignore
14
+
15
+ return te, DelayedScaling, Format
16
+
17
+
18
+ def transformer_engine_is_available() -> bool:
19
+ try:
20
+ _load_transformer_engine()
21
+ except Exception:
22
+ return False
23
+ return True
24
+
25
+
26
+ def build_fp8_recipe(
27
+ *,
28
+ format_name: str = "HYBRID",
29
+ margin: int = 0,
30
+ amax_history_len: int = 16,
31
+ amax_compute_algo: str = "max",
32
+ fp8_dpa: bool = False,
33
+ fp8_mha: bool = False,
34
+ ):
35
+ _te, DelayedScaling, Format = _load_transformer_engine()
36
+ fp8_format = getattr(Format, format_name)
37
+ kwargs = {
38
+ "margin": margin,
39
+ "fp8_format": fp8_format,
40
+ "amax_history_len": amax_history_len,
41
+ "amax_compute_algo": amax_compute_algo,
42
+ }
43
+ try:
44
+ recipe_params = signature(DelayedScaling).parameters
45
+ except (TypeError, ValueError):
46
+ recipe_params = {}
47
+ if fp8_dpa and "fp8_dpa" in recipe_params:
48
+ kwargs["fp8_dpa"] = True
49
+ if fp8_mha and "fp8_mha" in recipe_params:
50
+ kwargs["fp8_mha"] = True
51
+ return DelayedScaling(**kwargs)
52
+
53
+
54
+ def fp8_autocast_context(
55
+ *,
56
+ enabled: bool,
57
+ recipe=None,
58
+ fp8_group=None,
59
+ ):
60
+ if not enabled:
61
+ return nullcontext()
62
+ te, _DelayedScaling, _Format = _load_transformer_engine()
63
+ return te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=fp8_group)
64
+
65
+
66
+ def build_linear(
67
+ *,
68
+ in_features: int,
69
+ out_features: int,
70
+ bias: bool,
71
+ use_fp8: bool,
72
+ ):
73
+ if use_fp8 and supports_fp8_linear(in_features, out_features):
74
+ te, _DelayedScaling, _Format = _load_transformer_engine()
75
+ return te.Linear(in_features, out_features, bias=bias)
76
+ return nn.Linear(in_features, out_features, bias=bias)
77
+
78
+
79
+ def build_rmsnorm(
80
+ *,
81
+ hidden_size: int,
82
+ eps: float,
83
+ use_fp8: bool,
84
+ ):
85
+ if use_fp8:
86
+ te, _DelayedScaling, _Format = _load_transformer_engine()
87
+ return te.RMSNorm(hidden_size, eps=eps)
88
+ return None
89
+
90
+
91
+ def build_layernorm_mlp(
92
+ *,
93
+ hidden_size: int,
94
+ ffn_hidden_size: int,
95
+ eps: float,
96
+ bias: bool,
97
+ use_fp8: bool,
98
+ ):
99
+ if use_fp8:
100
+ te, _DelayedScaling, _Format = _load_transformer_engine()
101
+ return te.LayerNormMLP(
102
+ hidden_size,
103
+ ffn_hidden_size,
104
+ eps=eps,
105
+ bias=bias,
106
+ normalization="RMSNorm",
107
+ activation="swiglu",
108
+ device="cpu",
109
+ )
110
+ return None
111
+
112
+
113
+ def is_transformer_engine_module(module: nn.Module) -> bool:
114
+ return module.__class__.__module__.startswith("transformer_engine.")
115
+
116
+
117
+ def supports_fp8_linear(in_features: int, out_features: int) -> bool:
118
+ return (in_features % 16 == 0) and (out_features % 16 == 0)
metis_mamba/hybrid_runtime.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from safetensors.torch import load_file as load_safetensors
11
+ from torch import nn
12
+
13
+
14
+ DEFAULT_ATTN_LAYER_IDX = [3, 7, 11, 15, 19, 23, 27]
15
+
16
+
17
+ @dataclass
18
+ class MetisHybridConfig:
19
+ name: str = "Metis-1.3"
20
+ model_type: str = "metis_mamba2_hybrid"
21
+ architecture: str = "mamba2_hybrid_decoder"
22
+ vocab_size: int = 8192
23
+ block_size: int = 4096
24
+ d_model: int = 1152
25
+ n_layer: int = 28
26
+ n_heads: int = 18
27
+ n_kv_heads: int = 6
28
+ head_dim: int = 64
29
+ attn_layer_idx: list[int] = field(default_factory=lambda: list(DEFAULT_ATTN_LAYER_IDX))
30
+ attn_d_conv: int = 4
31
+ attn_rotary_emb_dim: int = 0
32
+ ssm_layer: str = "Mamba2"
33
+ ssm_d_state: int = 64
34
+ ssm_d_conv: int = 4
35
+ ssm_expand: int = 2
36
+ tie_embeddings: bool = True
37
+ rms_norm: bool = True
38
+ residual_in_fp32: bool = False
39
+ fused_add_norm: bool = False
40
+ pad_vocab_size_multiple: int = 16
41
+ initializer_range: float = 0.02
42
+ torch_dtype: str = "bfloat16"
43
+
44
+ @classmethod
45
+ def from_dict(cls, payload: dict[str, Any]) -> "MetisHybridConfig":
46
+ allowed = {field.name for field in cls.__dataclass_fields__.values()}
47
+ return cls(**{key: value for key, value in payload.items() if key in allowed})
48
+
49
+ @property
50
+ def padded_vocab_size(self) -> int:
51
+ if self.vocab_size % self.pad_vocab_size_multiple == 0:
52
+ return self.vocab_size
53
+ return self.vocab_size + self.pad_vocab_size_multiple - (self.vocab_size % self.pad_vocab_size_multiple)
54
+
55
+
56
+ @dataclass
57
+ class HybridCausalLMOutput:
58
+ logits: torch.Tensor
59
+
60
+
61
+ class HybridRMSNorm(nn.Module):
62
+ def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
63
+ super().__init__()
64
+ self.weight = nn.Parameter(torch.ones(hidden_size))
65
+ self.eps = eps
66
+
67
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
68
+ dtype = x.dtype
69
+ x_float = x.float()
70
+ rstd = torch.rsqrt(x_float.square().mean(dim=-1, keepdim=True) + self.eps)
71
+ return (x_float * rstd * self.weight.float()).to(dtype)
72
+
73
+
74
+ class HybridGatedRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size: int, *, eps: float = 1e-5, group_size: int | None = None) -> None:
76
+ super().__init__()
77
+ self.weight = nn.Parameter(torch.ones(hidden_size))
78
+ self.eps = eps
79
+ self.group_size = group_size
80
+
81
+ def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
82
+ dtype = x.dtype
83
+ x_float = x.float() * F.silu(gate.float())
84
+ if self.group_size and self.group_size < x_float.shape[-1]:
85
+ original_shape = x_float.shape
86
+ x_group = x_float.reshape(*original_shape[:-1], -1, self.group_size)
87
+ rstd = torch.rsqrt(x_group.square().mean(dim=-1, keepdim=True) + self.eps)
88
+ x_float = (x_group * rstd).reshape(original_shape)
89
+ else:
90
+ rstd = torch.rsqrt(x_float.square().mean(dim=-1, keepdim=True) + self.eps)
91
+ x_float = x_float * rstd
92
+ return (x_float * self.weight.float()).to(dtype)
93
+
94
+
95
+ class HybridMamba2(nn.Module):
96
+ def __init__(self, config: MetisHybridConfig, *, layer_idx: int, dtype: torch.dtype | None = None) -> None:
97
+ super().__init__()
98
+ self.layer_idx = layer_idx
99
+ self.d_model = config.d_model
100
+ self.d_state = config.ssm_d_state
101
+ self.d_conv = config.ssm_d_conv
102
+ self.expand = config.ssm_expand
103
+ self.d_inner = self.expand * self.d_model
104
+ self.headdim = config.head_dim
105
+ self.d_ssm = self.d_inner
106
+ self.ngroups = 1
107
+ self.nheads = self.d_ssm // self.headdim
108
+ self.activation = "silu"
109
+
110
+ factory_kwargs = {"dtype": dtype}
111
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
112
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
113
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=False, **factory_kwargs)
114
+ self.conv1d = nn.Conv1d(
115
+ conv_dim,
116
+ conv_dim,
117
+ kernel_size=self.d_conv,
118
+ padding=self.d_conv - 1,
119
+ groups=conv_dim,
120
+ bias=True,
121
+ **factory_kwargs,
122
+ )
123
+ self.dt_bias = nn.Parameter(torch.zeros(self.nheads, **factory_kwargs))
124
+ self.A_log = nn.Parameter(torch.zeros(self.nheads, **factory_kwargs))
125
+ self.D = nn.Parameter(torch.ones(self.nheads))
126
+ self.norm = HybridGatedRMSNorm(self.d_ssm, group_size=self.d_ssm // self.ngroups)
127
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False, **factory_kwargs)
128
+
129
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
130
+ batch_size, seq_len, _ = hidden_states.shape
131
+ dtype = hidden_states.dtype
132
+ zxbcdt = self.in_proj(hidden_states)
133
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
134
+ z0, x0, z, xbc, dt = torch.split(
135
+ zxbcdt,
136
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
137
+ dim=-1,
138
+ )
139
+
140
+ xbc = self.conv1d(xbc.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
141
+ xbc = F.silu(xbc)
142
+ x, b_vec, c_vec = torch.split(
143
+ xbc,
144
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
145
+ dim=-1,
146
+ )
147
+ x = x.reshape(batch_size, seq_len, self.nheads, self.headdim).float()
148
+ b_vec = b_vec.reshape(batch_size, seq_len, self.ngroups, self.d_state).float()
149
+ c_vec = c_vec.reshape(batch_size, seq_len, self.ngroups, self.d_state).float()
150
+ dt = F.softplus(dt.float() + self.dt_bias.float())
151
+ a = -torch.exp(self.A_log.float())
152
+ d = self.D.float()
153
+
154
+ state = torch.zeros(
155
+ batch_size,
156
+ self.nheads,
157
+ self.headdim,
158
+ self.d_state,
159
+ device=hidden_states.device,
160
+ dtype=torch.float32,
161
+ )
162
+ y_steps: list[torch.Tensor] = []
163
+ heads_per_group = self.nheads // self.ngroups
164
+ head_groups = torch.arange(self.nheads, device=hidden_states.device) // heads_per_group
165
+ for index in range(seq_len):
166
+ dt_t = dt[:, index]
167
+ x_t = x[:, index]
168
+ b_t = b_vec[:, index].index_select(1, head_groups)
169
+ c_t = c_vec[:, index].index_select(1, head_groups)
170
+ state = state * torch.exp(dt_t[:, :, None, None] * a[None, :, None, None])
171
+ state = state + dt_t[:, :, None, None] * x_t[:, :, :, None] * b_t[:, :, None, :]
172
+ y_t = torch.einsum("bhpn,bhn->bhp", state, c_t)
173
+ y_t = y_t + d[None, :, None] * x_t
174
+ y_steps.append(y_t.reshape(batch_size, self.d_ssm))
175
+
176
+ y = torch.stack(y_steps, dim=1).to(dtype)
177
+ y = self.norm(y, z)
178
+ if d_mlp > 0:
179
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
180
+ return self.out_proj(y)
181
+
182
+
183
+ class HybridMHA(nn.Module):
184
+ def __init__(self, config: MetisHybridConfig, *, layer_idx: int, dtype: torch.dtype | None = None) -> None:
185
+ super().__init__()
186
+ self.layer_idx = layer_idx
187
+ self.num_heads = config.n_heads
188
+ self.num_kv_heads = config.n_kv_heads
189
+ self.head_dim = config.head_dim
190
+ self.d_conv = config.attn_d_conv
191
+ self.softmax_scale = self.head_dim ** -0.5
192
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_kv_heads)
193
+ factory_kwargs = {"dtype": dtype}
194
+ self.in_proj = nn.Linear(config.d_model, qkv_dim, bias=False, **factory_kwargs)
195
+ self.conv1d = nn.Conv1d(
196
+ qkv_dim,
197
+ qkv_dim,
198
+ kernel_size=self.d_conv,
199
+ padding=self.d_conv - 1,
200
+ groups=qkv_dim,
201
+ bias=True,
202
+ **factory_kwargs,
203
+ )
204
+ self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False, **factory_kwargs)
205
+
206
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
207
+ batch_size, seq_len, _ = hidden_states.shape
208
+ qkv = self.in_proj(hidden_states)
209
+ qkv = self.conv1d(qkv.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
210
+ q, kv = torch.split(
211
+ qkv,
212
+ [self.num_heads * self.head_dim, 2 * self.num_kv_heads * self.head_dim],
213
+ dim=-1,
214
+ )
215
+ q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
216
+ kv = kv.view(batch_size, seq_len, 2, self.num_kv_heads, self.head_dim)
217
+ k = kv[:, :, 0].transpose(1, 2)
218
+ v = kv[:, :, 1].transpose(1, 2)
219
+ repeats = self.num_heads // self.num_kv_heads
220
+ k = torch.repeat_interleave(k, repeats=repeats, dim=1)
221
+ v = torch.repeat_interleave(v, repeats=repeats, dim=1)
222
+ context = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=self.softmax_scale)
223
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)
224
+ return self.out_proj(context)
225
+
226
+
227
+ class HybridBlock(nn.Module):
228
+ def __init__(self, config: MetisHybridConfig, *, layer_idx: int, dtype: torch.dtype | None = None) -> None:
229
+ super().__init__()
230
+ self.norm = HybridRMSNorm(config.d_model)
231
+ if layer_idx in set(config.attn_layer_idx):
232
+ self.mixer = HybridMHA(config, layer_idx=layer_idx, dtype=dtype)
233
+ else:
234
+ self.mixer = HybridMamba2(config, layer_idx=layer_idx, dtype=dtype)
235
+
236
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor | None) -> tuple[torch.Tensor, torch.Tensor]:
237
+ residual = hidden_states + residual if residual is not None else hidden_states
238
+ hidden_states = self.norm(residual)
239
+ hidden_states = self.mixer(hidden_states)
240
+ return hidden_states, residual
241
+
242
+
243
+ class HybridBackbone(nn.Module):
244
+ def __init__(self, config: MetisHybridConfig, *, dtype: torch.dtype | None = None) -> None:
245
+ super().__init__()
246
+ self.embedding = nn.Embedding(config.padded_vocab_size, config.d_model, dtype=dtype)
247
+ self.layers = nn.ModuleList([HybridBlock(config, layer_idx=index, dtype=dtype) for index in range(config.n_layer)])
248
+ self.norm_f = HybridRMSNorm(config.d_model)
249
+
250
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
251
+ hidden_states = self.embedding(input_ids)
252
+ residual = None
253
+ for layer in self.layers:
254
+ hidden_states, residual = layer(hidden_states, residual)
255
+ residual = hidden_states + residual if residual is not None else hidden_states
256
+ return self.norm_f(residual)
257
+
258
+
259
+ class MetisMamba2HybridLMHeadModel(nn.Module):
260
+ def __init__(self, config: MetisHybridConfig, *, dtype: torch.dtype | None = None) -> None:
261
+ super().__init__()
262
+ self.config = config
263
+ self.model_family = config.model_type
264
+ self.backbone = HybridBackbone(config, dtype=dtype)
265
+ self.lm_head = nn.Linear(config.d_model, config.padded_vocab_size, bias=False, dtype=dtype)
266
+ if config.tie_embeddings:
267
+ self.lm_head.weight = self.backbone.embedding.weight
268
+
269
+ def forward(self, input_ids: torch.Tensor) -> HybridCausalLMOutput:
270
+ hidden_states = self.backbone(input_ids)
271
+ return HybridCausalLMOutput(logits=self.lm_head(hidden_states))
272
+
273
+
274
+ def _load_config(model_dir: Path) -> MetisHybridConfig:
275
+ return MetisHybridConfig.from_dict(json.loads((model_dir / "config.json").read_text()))
276
+
277
+
278
+ def load_hybrid_exported_model(model_dir: str | Path, device: torch.device) -> MetisMamba2HybridLMHeadModel:
279
+ model_dir = Path(model_dir)
280
+ config = _load_config(model_dir)
281
+ model = MetisMamba2HybridLMHeadModel(config, dtype=torch.float32)
282
+ state_dict = load_safetensors(str(model_dir / "model.safetensors"), device="cpu")
283
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
284
+ allowed_missing = {"lm_head.weight"} if config.tie_embeddings else set()
285
+ if set(missing) - allowed_missing or unexpected:
286
+ raise RuntimeError(f"Unexpected Metis-1.3 load result: missing={missing}, unexpected={unexpected}")
287
+ model.to(device)
288
+ model.eval()
289
+ return model
290
+
metis_mamba/model.py ADDED
@@ -0,0 +1,1353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from contextlib import nullcontext
4
+ from dataclasses import dataclass
5
+ from functools import lru_cache
6
+ from inspect import signature
7
+ import os
8
+ from typing import Any
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from .config import MetisMambaConfig
15
+ from .fp8 import (
16
+ build_layernorm_mlp,
17
+ build_linear,
18
+ build_rmsnorm,
19
+ fp8_autocast_context,
20
+ is_transformer_engine_module,
21
+ )
22
+
23
+
24
+ @lru_cache(maxsize=1)
25
+ def _load_liger_fused_linear_ce():
26
+ from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # type: ignore
27
+
28
+ return LigerFusedLinearCrossEntropyLoss
29
+
30
+
31
+ def _nvtx_range(name: str):
32
+ if torch.cuda.is_available() and hasattr(torch.cuda, "nvtx"):
33
+ return torch.cuda.nvtx.range(name)
34
+ return nullcontext()
35
+
36
+
37
+ @lru_cache(maxsize=1)
38
+ def _load_flash_attention_3():
39
+ errors: list[str] = []
40
+ try:
41
+ from flash_attn_3.flash_attn_interface import flash_attn_func # type: ignore
42
+
43
+ return flash_attn_func
44
+ except ImportError as exc:
45
+ errors.append(str(exc))
46
+
47
+ try:
48
+ import flash_attn_interface # type: ignore
49
+
50
+ flash_attn_func = getattr(flash_attn_interface, "flash_attn_func", None)
51
+ if flash_attn_func is not None:
52
+ return flash_attn_func
53
+ errors.append("flash_attn_interface imported but does not expose flash_attn_func")
54
+ except ImportError as exc:
55
+ errors.append(str(exc))
56
+
57
+ try:
58
+ from hopper.flash_attn_interface import flash_attn_func # type: ignore
59
+
60
+ return flash_attn_func
61
+ except ImportError as exc:
62
+ errors.append(str(exc))
63
+
64
+ raise RuntimeError(
65
+ "FlashAttention-3 is not available. Install the official Dao-AILab hopper package "
66
+ "(cd flash-attention/hopper && python setup.py install). "
67
+ f"Import errors: {' | '.join(errors)}"
68
+ )
69
+
70
+
71
+ @lru_cache(maxsize=1)
72
+ def _flash_attention_3_accepts_dropout_p() -> bool:
73
+ try:
74
+ return "dropout_p" in signature(_load_flash_attention_3()).parameters
75
+ except (TypeError, ValueError):
76
+ return True
77
+
78
+
79
+ def _is_hopper_device(device: torch.device) -> bool:
80
+ if device.type != "cuda":
81
+ return False
82
+ major, _minor = torch.cuda.get_device_capability(device)
83
+ return major >= 9
84
+
85
+
86
+ @dataclass
87
+ class MetisCausalLMOutput:
88
+ logits: torch.Tensor | None
89
+ loss: torch.Tensor | None = None
90
+ lm_loss: torch.Tensor | None = None
91
+ hidden_states: list[torch.Tensor] | None = None
92
+ route_probs: torch.Tensor | None = None
93
+ chosen_depths: torch.Tensor | None = None
94
+ route_aux_loss: torch.Tensor | None = None
95
+ mean_depth: torch.Tensor | None = None
96
+ active_token_ratios: torch.Tensor | None = None
97
+
98
+
99
+ @dataclass
100
+ class MetisRewardOutput:
101
+ rewards: torch.Tensor
102
+ loss: torch.Tensor | None = None
103
+ route_aux_loss: torch.Tensor | None = None
104
+ mean_depth: torch.Tensor | None = None
105
+ active_token_ratios: torch.Tensor | None = None
106
+
107
+
108
+ class MetisRMSNorm(nn.Module):
109
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
110
+ super().__init__()
111
+ self.weight = nn.Parameter(torch.ones(hidden_size))
112
+ self.eps = eps
113
+
114
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
115
+ input_dtype = hidden_states.dtype
116
+ hidden_states = hidden_states.to(torch.float32)
117
+ variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
118
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
119
+ return self.weight * hidden_states.to(input_dtype)
120
+
121
+
122
+ class MetisLinear(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_features: int,
126
+ out_features: int,
127
+ *,
128
+ bias: bool,
129
+ use_fp8: bool,
130
+ ) -> None:
131
+ super().__init__()
132
+ self.impl = build_linear(
133
+ in_features=in_features,
134
+ out_features=out_features,
135
+ bias=bias,
136
+ use_fp8=use_fp8,
137
+ )
138
+ self.uses_transformer_engine = is_transformer_engine_module(self.impl)
139
+
140
+ @property
141
+ def weight(self):
142
+ return self.impl.weight
143
+
144
+ @weight.setter
145
+ def weight(self, value) -> None:
146
+ self.impl.weight = value
147
+
148
+ def forward(self, hidden_states: torch.Tensor, *, is_first_microbatch: bool | None = None) -> torch.Tensor:
149
+ if self.uses_transformer_engine:
150
+ return self.impl(hidden_states, is_first_microbatch=is_first_microbatch)
151
+ return self.impl(hidden_states)
152
+
153
+
154
+ def build_rms_norm_module(hidden_size: int, *, eps: float, use_fp8: bool) -> nn.Module:
155
+ te_norm = build_rmsnorm(hidden_size=hidden_size, eps=eps, use_fp8=use_fp8)
156
+ if te_norm is not None:
157
+ return te_norm
158
+ return MetisRMSNorm(hidden_size=hidden_size, eps=eps)
159
+
160
+
161
+ class MetisRotaryEmbedding(nn.Module):
162
+ def __init__(self, dim: int, *, base: float) -> None:
163
+ super().__init__()
164
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
165
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
166
+ self.dim = dim
167
+
168
+ def forward(self, position_ids: torch.Tensor, *, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
169
+ pos = position_ids.to(self.inv_freq.device, dtype=torch.float32)
170
+ freqs = torch.einsum("bl,d->bld", pos, self.inv_freq)
171
+ emb = torch.cat((freqs, freqs), dim=-1)
172
+ return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype)
173
+
174
+
175
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
176
+ first = x[..., ::2]
177
+ second = x[..., 1::2]
178
+ return torch.stack((-second, first), dim=-1).flatten(start_dim=-2)
179
+
180
+
181
+ def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
182
+ cos = cos.unsqueeze(1)
183
+ sin = sin.unsqueeze(1)
184
+ return (x * cos) + (rotate_half(x) * sin)
185
+
186
+
187
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
188
+ if n_rep == 1:
189
+ return hidden_states
190
+ batch_size, num_kv_heads, seq_len, head_dim = hidden_states.shape
191
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch_size, num_kv_heads, n_rep, seq_len, head_dim)
192
+ return hidden_states.reshape(batch_size, num_kv_heads * n_rep, seq_len, head_dim)
193
+
194
+
195
+ def build_causal_mask(
196
+ *,
197
+ batch_size: int,
198
+ seq_len: int,
199
+ device: torch.device,
200
+ dtype: torch.dtype,
201
+ attention_mask: torch.Tensor | None,
202
+ ) -> torch.Tensor:
203
+ causal = torch.full((seq_len, seq_len), torch.finfo(dtype).min, device=device, dtype=dtype)
204
+ causal = torch.triu(causal, diagonal=1).unsqueeze(0).unsqueeze(0)
205
+ if attention_mask is None:
206
+ return causal
207
+ if attention_mask.dim() != 2:
208
+ raise ValueError("attention_mask must have shape (batch, seq_len).")
209
+ expanded = causal.expand(batch_size, 1, seq_len, seq_len).clone()
210
+ key_padding = ~attention_mask.to(torch.bool)
211
+ return expanded.masked_fill(key_padding[:, None, None, :], torch.finfo(dtype).min)
212
+
213
+
214
+ def pack_active_tokens(
215
+ hidden_states: torch.Tensor,
216
+ position_ids: torch.Tensor,
217
+ active_mask: torch.Tensor,
218
+ *,
219
+ pad_multiple: int = 1,
220
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None:
221
+ lengths = active_mask.sum(dim=-1)
222
+ max_active = int(lengths.max().item())
223
+ if max_active == 0:
224
+ return None
225
+ if pad_multiple > 1 and (max_active % pad_multiple) != 0:
226
+ max_active = max_active + (pad_multiple - (max_active % pad_multiple))
227
+ batch_size, _seq_len, hidden_size = hidden_states.shape
228
+ packed_hidden_rows: list[torch.Tensor] = []
229
+ packed_position_rows: list[torch.Tensor] = []
230
+ packed_mask_rows: list[torch.Tensor] = []
231
+ packed_index_rows: list[torch.Tensor] = []
232
+ for batch_index in range(batch_size):
233
+ indices = torch.nonzero(active_mask[batch_index], as_tuple=False).squeeze(-1)
234
+ active_count = int(indices.numel())
235
+ selected_hidden = hidden_states[batch_index].index_select(0, indices) if active_count > 0 else hidden_states.new_zeros((0, hidden_size))
236
+ selected_positions = position_ids[batch_index].index_select(0, indices) if active_count > 0 else position_ids.new_zeros((0,))
237
+ selected_mask = torch.ones(active_count, dtype=torch.bool, device=active_mask.device)
238
+ if active_count < max_active:
239
+ pad_amount = max_active - active_count
240
+ selected_hidden = torch.cat(
241
+ [selected_hidden, hidden_states.new_zeros((pad_amount, hidden_size))],
242
+ dim=0,
243
+ )
244
+ selected_positions = torch.cat(
245
+ [selected_positions, position_ids.new_zeros((pad_amount,))],
246
+ dim=0,
247
+ )
248
+ indices = torch.cat(
249
+ [indices, position_ids.new_full((pad_amount,), -1)],
250
+ dim=0,
251
+ )
252
+ selected_mask = torch.cat(
253
+ [selected_mask, torch.zeros(pad_amount, dtype=torch.bool, device=active_mask.device)],
254
+ dim=0,
255
+ )
256
+ packed_hidden_rows.append(selected_hidden)
257
+ packed_position_rows.append(selected_positions)
258
+ packed_mask_rows.append(selected_mask)
259
+ packed_index_rows.append(indices)
260
+ packed_hidden = torch.stack(packed_hidden_rows, dim=0)
261
+ packed_positions = torch.stack(packed_position_rows, dim=0)
262
+ packed_mask = torch.stack(packed_mask_rows, dim=0)
263
+ packed_indices = torch.stack(packed_index_rows, dim=0)
264
+ return packed_hidden, packed_positions, packed_mask, packed_indices
265
+
266
+
267
+ def scatter_active_tokens(
268
+ full_hidden_states: torch.Tensor,
269
+ packed_hidden_states: torch.Tensor,
270
+ packed_mask: torch.Tensor,
271
+ packed_indices: torch.Tensor,
272
+ active_mask: torch.Tensor,
273
+ ) -> torch.Tensor:
274
+ scatter_indices = packed_indices.clamp_min(0).unsqueeze(-1).expand_as(packed_hidden_states)
275
+ scatter_src = packed_hidden_states * packed_mask.unsqueeze(-1).to(dtype=packed_hidden_states.dtype)
276
+ scattered = torch.zeros_like(full_hidden_states).scatter_add(1, scatter_indices, scatter_src)
277
+ return torch.where(active_mask.unsqueeze(-1), scattered, full_hidden_states)
278
+
279
+
280
+ class MetisSelfAttention(nn.Module):
281
+ def __init__(self, config: MetisMambaConfig, *, use_fp8: bool) -> None:
282
+ super().__init__()
283
+ self.hidden_size = config.d_model
284
+ self.num_heads = config.n_heads
285
+ self.num_kv_heads = config.n_kv_heads
286
+ self.head_dim = config.head_dim
287
+ self.q_dim = self.num_heads * self.head_dim
288
+ self.kv_dim = self.num_kv_heads * self.head_dim
289
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
290
+ self.dropout = float(config.attention_dropout)
291
+ self.attention_backend = config.attention_backend
292
+ self.softmax_scale = self.head_dim ** -0.5
293
+ self.debug_attention_backend = config.debug_attention_backend
294
+ self._debug_backend_printed = False
295
+ self.perf_counters: dict[str, int] | None = None
296
+
297
+ self.qkv_proj = MetisLinear(
298
+ self.hidden_size,
299
+ self.q_dim + (2 * self.kv_dim),
300
+ bias=config.attention_bias,
301
+ use_fp8=use_fp8,
302
+ )
303
+ self.o_proj = MetisLinear(
304
+ self.q_dim,
305
+ self.hidden_size,
306
+ bias=config.attention_bias,
307
+ use_fp8=use_fp8,
308
+ )
309
+ self.rotary_emb = MetisRotaryEmbedding(self.head_dim, base=config.rope_theta)
310
+
311
+ def _should_use_flash_attention_3(
312
+ self,
313
+ query_states: torch.Tensor,
314
+ key_states: torch.Tensor,
315
+ value_states: torch.Tensor,
316
+ attention_mask: torch.Tensor | None,
317
+ ) -> bool:
318
+ if self.attention_backend not in {"auto", "flash_attention_3"}:
319
+ return False
320
+ if attention_mask is not None:
321
+ return False
322
+ if query_states.device.type != "cuda":
323
+ return False
324
+ if query_states.dtype not in {torch.float16, torch.bfloat16}:
325
+ return False
326
+ if not (_is_hopper_device(query_states.device) and _is_hopper_device(key_states.device)):
327
+ return False
328
+ if query_states.shape[-1] > 256 or key_states.shape[-1] > 256:
329
+ return False
330
+ return value_states.dtype == query_states.dtype
331
+
332
+ def _bump_counter(self, name: str, amount: int = 1) -> None:
333
+ if self.perf_counters is not None:
334
+ self.perf_counters[name] = self.perf_counters.get(name, 0) + amount
335
+
336
+ def _debug_backend_once(self, backend: str, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> None:
337
+ if not self.debug_attention_backend or self._debug_backend_printed:
338
+ return
339
+ print(
340
+ f"MetisSelfAttention backend={backend} shape={tuple(hidden_states.shape)} "
341
+ f"mask={'none' if attention_mask is None else tuple(attention_mask.shape)}",
342
+ flush=True,
343
+ )
344
+ self._debug_backend_printed = True
345
+
346
+ def _flash_attention_3(
347
+ self,
348
+ query_states: torch.Tensor,
349
+ key_states: torch.Tensor,
350
+ value_states: torch.Tensor,
351
+ ) -> torch.Tensor:
352
+ flash_attn_func = _load_flash_attention_3()
353
+ q = query_states.transpose(1, 2).contiguous()
354
+ k = key_states.transpose(1, 2).contiguous()
355
+ v = value_states.transpose(1, 2).contiguous()
356
+ kwargs: dict[str, Any] = {
357
+ "softmax_scale": self.softmax_scale,
358
+ "causal": True,
359
+ }
360
+ if _flash_attention_3_accepts_dropout_p():
361
+ kwargs["dropout_p"] = self.dropout if self.training else 0.0
362
+ elif self.training and self.dropout:
363
+ raise RuntimeError("This FlashAttention-3 build does not expose dropout_p; set attention_dropout=0.0.")
364
+ attn_output = flash_attn_func(q, k, v, **kwargs)
365
+ return attn_output.transpose(1, 2)
366
+
367
+ def forward(
368
+ self,
369
+ hidden_states: torch.Tensor,
370
+ *,
371
+ attention_mask: torch.Tensor | None,
372
+ position_ids: torch.Tensor,
373
+ is_first_microbatch: bool | None = None,
374
+ ) -> torch.Tensor:
375
+ batch_size, seq_len, _ = hidden_states.shape
376
+ qkv_states = self.qkv_proj(hidden_states, is_first_microbatch=is_first_microbatch)
377
+ query_states, key_states, value_states = qkv_states.split((self.q_dim, self.kv_dim, self.kv_dim), dim=-1)
378
+ query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
379
+ key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
380
+ value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
381
+
382
+ cos, sin = self.rotary_emb(position_ids, dtype=hidden_states.dtype)
383
+ query_states = apply_rotary_pos_emb(query_states, cos, sin)
384
+ key_states = apply_rotary_pos_emb(key_states, cos, sin)
385
+ key_states = repeat_kv(key_states, self.num_kv_groups)
386
+ value_states = repeat_kv(value_states, self.num_kv_groups)
387
+
388
+ if attention_mask is not None:
389
+ self._bump_counter("attention_mask_passed_calls")
390
+
391
+ if self._should_use_flash_attention_3(query_states, key_states, value_states, attention_mask):
392
+ self._bump_counter("fa3_calls")
393
+ self._debug_backend_once("flash_attention_3", hidden_states, attention_mask)
394
+ attn_output = self._flash_attention_3(query_states, key_states, value_states)
395
+ elif hasattr(F, "scaled_dot_product_attention"):
396
+ self._bump_counter("sdpa_calls")
397
+ self._debug_backend_once("sdpa", hidden_states, attention_mask)
398
+ attn_mask = None
399
+ if attention_mask is not None:
400
+ attn_mask = build_causal_mask(
401
+ batch_size=batch_size,
402
+ seq_len=seq_len,
403
+ device=hidden_states.device,
404
+ dtype=query_states.dtype,
405
+ attention_mask=attention_mask,
406
+ )
407
+ attn_output = F.scaled_dot_product_attention(
408
+ query_states,
409
+ key_states,
410
+ value_states,
411
+ attn_mask=attn_mask,
412
+ dropout_p=self.dropout if self.training else 0.0,
413
+ is_causal=attention_mask is None,
414
+ scale=self.softmax_scale,
415
+ )
416
+ else:
417
+ self._bump_counter("eager_attention_calls")
418
+ self._debug_backend_once("eager", hidden_states, attention_mask)
419
+ scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
420
+ if attention_mask is not None:
421
+ scores = scores + build_causal_mask(
422
+ batch_size=batch_size,
423
+ seq_len=seq_len,
424
+ device=hidden_states.device,
425
+ dtype=scores.dtype,
426
+ attention_mask=attention_mask,
427
+ )
428
+ else:
429
+ scores = scores + torch.triu(
430
+ torch.full(
431
+ (seq_len, seq_len),
432
+ torch.finfo(scores.dtype).min,
433
+ device=scores.device,
434
+ dtype=scores.dtype,
435
+ ),
436
+ diagonal=1,
437
+ ).unsqueeze(0).unsqueeze(0)
438
+ weights = torch.softmax(scores.float(), dim=-1).to(dtype=query_states.dtype)
439
+ if self.dropout > 0.0 and self.training:
440
+ weights = F.dropout(weights, p=self.dropout, training=True)
441
+ attn_output = torch.matmul(weights, value_states)
442
+
443
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)
444
+ return self.o_proj(attn_output, is_first_microbatch=is_first_microbatch)
445
+
446
+
447
+ class MetisSwiGLU(nn.Module):
448
+ def __init__(self, config: MetisMambaConfig, *, use_fp8: bool) -> None:
449
+ super().__init__()
450
+ self.intermediate_size = config.intermediate_size
451
+ self.gate_up_proj = MetisLinear(
452
+ config.d_model,
453
+ 2 * config.intermediate_size,
454
+ bias=config.mlp_bias,
455
+ use_fp8=use_fp8,
456
+ )
457
+ self.down_proj = MetisLinear(
458
+ config.intermediate_size,
459
+ config.d_model,
460
+ bias=config.mlp_bias,
461
+ use_fp8=use_fp8,
462
+ )
463
+
464
+ def forward(self, hidden_states: torch.Tensor, *, is_first_microbatch: bool | None = None) -> torch.Tensor:
465
+ gate_up = self.gate_up_proj(hidden_states, is_first_microbatch=is_first_microbatch)
466
+ gate, up = gate_up.split(self.intermediate_size, dim=-1)
467
+ gated = F.silu(gate)
468
+ return self.down_proj(gated * up, is_first_microbatch=is_first_microbatch)
469
+
470
+
471
+ class MetisTELayerNormSwiGLU(nn.Module):
472
+ def __init__(self, config: MetisMambaConfig, *, use_fp8: bool) -> None:
473
+ super().__init__()
474
+ self.impl = build_layernorm_mlp(
475
+ hidden_size=config.d_model,
476
+ ffn_hidden_size=config.intermediate_size,
477
+ eps=1e-6,
478
+ bias=config.mlp_bias,
479
+ use_fp8=use_fp8,
480
+ )
481
+ if self.impl is None:
482
+ raise RuntimeError("TE fused MLP requires Transformer Engine FP8 modules.")
483
+
484
+ def forward(self, hidden_states: torch.Tensor, *, is_first_microbatch: bool | None = None) -> torch.Tensor:
485
+ return self.impl(hidden_states, is_first_microbatch=is_first_microbatch)
486
+
487
+
488
+ class MetisTransformerBlock(nn.Module):
489
+ def __init__(self, config: MetisMambaConfig, *, use_fp8: bool) -> None:
490
+ super().__init__()
491
+ self.attn_norm = build_rms_norm_module(config.d_model, eps=1e-6, use_fp8=use_fp8)
492
+ self.self_attn = MetisSelfAttention(config, use_fp8=use_fp8)
493
+ self.use_te_fused_mlp = bool(config.te_fused_mlp and use_fp8)
494
+ if self.use_te_fused_mlp:
495
+ self.mlp = MetisTELayerNormSwiGLU(config, use_fp8=use_fp8)
496
+ self.ffn_norm = None
497
+ else:
498
+ self.ffn_norm = build_rms_norm_module(config.d_model, eps=1e-6, use_fp8=use_fp8)
499
+ self.mlp = MetisSwiGLU(config, use_fp8=use_fp8)
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ *,
505
+ attention_mask: torch.Tensor | None,
506
+ position_ids: torch.Tensor,
507
+ is_first_microbatch: bool | None = None,
508
+ ) -> torch.Tensor:
509
+ hidden_states = hidden_states + self.self_attn(
510
+ self.attn_norm(hidden_states),
511
+ attention_mask=attention_mask,
512
+ position_ids=position_ids,
513
+ is_first_microbatch=is_first_microbatch,
514
+ )
515
+ if self.use_te_fused_mlp:
516
+ hidden_states = hidden_states + self.mlp(hidden_states, is_first_microbatch=is_first_microbatch)
517
+ else:
518
+ hidden_states = hidden_states + self.mlp(
519
+ self.ffn_norm(hidden_states),
520
+ is_first_microbatch=is_first_microbatch,
521
+ )
522
+ return hidden_states
523
+
524
+
525
+ class MetisMoRModel(nn.Module):
526
+ COUNTER_KEYS = (
527
+ "static_dense_forward_calls",
528
+ "dynamic_token_mor_forward_calls",
529
+ "static_sequence_mor_forward_calls",
530
+ "static_block_mor_forward_calls",
531
+ "router_calls",
532
+ "pack_active_tokens_calls",
533
+ "scatter_active_tokens_calls",
534
+ "attention_mask_passed_calls",
535
+ "fa3_calls",
536
+ "sdpa_calls",
537
+ "eager_attention_calls",
538
+ "te_attention_calls",
539
+ )
540
+
541
+ def __init__(
542
+ self,
543
+ config: MetisMambaConfig,
544
+ *,
545
+ use_fp8: bool = False,
546
+ fp8_recipe=None,
547
+ fp8_group=None,
548
+ ) -> None:
549
+ super().__init__()
550
+ self.config = config
551
+ self.use_fp8 = use_fp8
552
+ self.fp8_recipe = fp8_recipe
553
+ self.fp8_group = fp8_group
554
+ self.training_routing_mode = os.environ.get("METIS_MOR_TRAIN_ROUTING_MODE", "token_pack").strip().lower()
555
+ if self.training_routing_mode not in {"token_pack", "dense_active"}:
556
+ raise ValueError(
557
+ "METIS_MOR_TRAIN_ROUTING_MODE must be one of: token_pack, dense_active."
558
+ )
559
+ self.embed_tokens = nn.Embedding(config.padded_vocab_size, config.d_model)
560
+ self.layers = nn.ModuleList([MetisTransformerBlock(config, use_fp8=use_fp8) for _ in range(config.n_layer)])
561
+ self.final_norm = build_rms_norm_module(config.d_model, eps=1e-6, use_fp8=use_fp8)
562
+ self.router_norm: nn.Module | None = None
563
+ self.router_up: MetisLinear | None = None
564
+ self.router_out: nn.Linear | None = None
565
+ self.sequence_router_norm: nn.Module | None = None
566
+ self.sequence_router_up: MetisLinear | None = None
567
+ self.sequence_router_out: nn.Linear | None = None
568
+ self.block_router_norm: nn.Module | None = None
569
+ self.block_router_up: MetisLinear | None = None
570
+ self.block_router_out: nn.Linear | None = None
571
+
572
+ if config.uses_dynamic_token_mor:
573
+ self.router_norm = build_rms_norm_module(config.d_model, eps=1e-6, use_fp8=use_fp8)
574
+ self.router_up = MetisLinear(
575
+ config.d_model,
576
+ config.mor_router_hidden_dim,
577
+ bias=True,
578
+ use_fp8=use_fp8,
579
+ )
580
+ self.router_out = nn.Linear(config.mor_router_hidden_dim, config.mor_max_depth)
581
+ elif config.uses_static_sequence_mor:
582
+ self.sequence_router_norm = build_rms_norm_module(config.d_model, eps=1e-6, use_fp8=use_fp8)
583
+ self.sequence_router_up = MetisLinear(
584
+ config.d_model,
585
+ config.mor_router_hidden_dim,
586
+ bias=True,
587
+ use_fp8=use_fp8,
588
+ )
589
+ self.sequence_router_out = nn.Linear(config.mor_router_hidden_dim, config.mor_max_depth)
590
+ elif config.uses_static_block_mor:
591
+ self.block_router_norm = build_rms_norm_module(config.d_model, eps=1e-6, use_fp8=use_fp8)
592
+ self.block_router_up = MetisLinear(
593
+ config.d_model,
594
+ config.mor_router_hidden_dim,
595
+ bias=True,
596
+ use_fp8=use_fp8,
597
+ )
598
+ self.block_router_out = nn.Linear(config.mor_router_hidden_dim, config.mor_max_depth)
599
+
600
+ self.perf_counters: dict[str, int] = {}
601
+ self.reset_perf_counters()
602
+
603
+ def reset_perf_counters(self) -> None:
604
+ self.perf_counters = {key: 0 for key in self.COUNTER_KEYS}
605
+ for layer in self.layers:
606
+ layer.self_attn.perf_counters = self.perf_counters
607
+
608
+ def get_perf_counters(self) -> dict[str, int]:
609
+ return dict(self.perf_counters)
610
+
611
+ def _bump_counter(self, name: str, amount: int = 1) -> None:
612
+ self.perf_counters[name] = self.perf_counters.get(name, 0) + amount
613
+
614
+ def _build_position_ids(self, input_ids: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
615
+ if position_ids is not None:
616
+ return position_ids
617
+ batch_size, seq_len = input_ids.shape
618
+ return torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
619
+
620
+ def _static_attention_mask(self, attention_mask: torch.Tensor | None) -> torch.Tensor | None:
621
+ if self.config.attention_mask_mode == "causal_none":
622
+ return None
623
+ return attention_mask
624
+
625
+ def _zero_aux(self, hidden_states: torch.Tensor) -> torch.Tensor:
626
+ return hidden_states.new_zeros(())
627
+
628
+ def _run_shared_stack(
629
+ self,
630
+ hidden_states: torch.Tensor,
631
+ *,
632
+ attention_mask: torch.Tensor | None,
633
+ position_ids: torch.Tensor,
634
+ is_first_microbatch: bool | None = None,
635
+ ) -> torch.Tensor:
636
+ with fp8_autocast_context(
637
+ enabled=self.use_fp8,
638
+ recipe=self.fp8_recipe,
639
+ fp8_group=self.fp8_group,
640
+ ):
641
+ for layer in self.layers:
642
+ hidden_states = layer(
643
+ hidden_states,
644
+ attention_mask=attention_mask,
645
+ position_ids=position_ids,
646
+ is_first_microbatch=is_first_microbatch,
647
+ )
648
+ return hidden_states
649
+
650
+ def _run_router_mlp(
651
+ self,
652
+ norm: nn.Module,
653
+ up: MetisLinear,
654
+ out: nn.Linear,
655
+ hidden_states: torch.Tensor,
656
+ *,
657
+ is_first_microbatch: bool | None = None,
658
+ ) -> torch.Tensor:
659
+ original_leading_shape = hidden_states.shape[:-1]
660
+ hidden_size = hidden_states.shape[-1]
661
+ router_input = hidden_states.reshape(-1, hidden_size)
662
+ original_rows = router_input.shape[0]
663
+ pad_rows = 0
664
+ if self.use_fp8 and original_rows % 8 != 0:
665
+ pad_rows = 8 - (original_rows % 8)
666
+ router_input = F.pad(router_input, (0, 0, 0, pad_rows))
667
+ with fp8_autocast_context(
668
+ enabled=self.use_fp8,
669
+ recipe=self.fp8_recipe,
670
+ fp8_group=self.fp8_group,
671
+ ):
672
+ router_hidden = F.silu(
673
+ up(
674
+ norm(router_input),
675
+ is_first_microbatch=is_first_microbatch,
676
+ )
677
+ )
678
+ router_logits = out(router_hidden)
679
+ if pad_rows:
680
+ router_logits = router_logits[:original_rows]
681
+ return router_logits.reshape(*original_leading_shape, router_logits.shape[-1])
682
+
683
+ def _route_tokens(
684
+ self,
685
+ hidden_states: torch.Tensor,
686
+ *,
687
+ attention_mask: torch.Tensor | None,
688
+ is_first_microbatch: bool | None = None,
689
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
690
+ if self.router_norm is None or self.router_up is None or self.router_out is None:
691
+ raise RuntimeError("Dynamic token MoR was requested, but the token router is disabled in this config.")
692
+ self._bump_counter("router_calls")
693
+ router_logits = self._run_router_mlp(
694
+ self.router_norm,
695
+ self.router_up,
696
+ self.router_out,
697
+ hidden_states,
698
+ is_first_microbatch=is_first_microbatch,
699
+ )
700
+ soft_route_probs = torch.softmax(router_logits / self.config.mor_router_temperature, dim=-1)
701
+ if self.training:
702
+ hard_route_probs = F.gumbel_softmax(
703
+ router_logits,
704
+ tau=self.config.mor_router_temperature,
705
+ hard=True,
706
+ dim=-1,
707
+ )
708
+ else:
709
+ hard_route_probs = torch.zeros_like(soft_route_probs)
710
+ hard_route_probs.scatter_(
711
+ -1,
712
+ torch.argmax(soft_route_probs, dim=-1, keepdim=True),
713
+ 1.0,
714
+ )
715
+ chosen_depths = torch.argmax(hard_route_probs, dim=-1) + 1
716
+ if attention_mask is not None:
717
+ chosen_depths = chosen_depths * attention_mask.to(torch.long)
718
+ depth_values = torch.arange(
719
+ 1,
720
+ self.config.mor_max_depth + 1,
721
+ device=hidden_states.device,
722
+ dtype=soft_route_probs.dtype,
723
+ )
724
+ expected_depth = torch.sum(soft_route_probs * depth_values, dim=-1)
725
+ if attention_mask is not None:
726
+ valid_tokens = attention_mask.to(torch.bool)
727
+ mean_depth = expected_depth.masked_select(valid_tokens).mean()
728
+ else:
729
+ mean_depth = expected_depth.mean()
730
+ route_aux_loss = (mean_depth - self.config.mor_target_avg_depth).pow(2)
731
+ return soft_route_probs, hard_route_probs, chosen_depths, route_aux_loss
732
+
733
+ def _default_static_capacities(self, item_count: int, configured_depth2: int, configured_depth3: int) -> tuple[int, int]:
734
+ if configured_depth2 > 0 or configured_depth3 > 0:
735
+ depth2 = min(item_count, configured_depth2)
736
+ depth3 = min(depth2, configured_depth3)
737
+ return depth2, depth3
738
+ extra_items = max(0, int(round(item_count * max(self.config.mor_target_avg_depth - 1.0, 0.0))))
739
+ depth3 = min(item_count, int(round(extra_items * 0.375)))
740
+ depth2 = min(item_count, max(0, extra_items - depth3))
741
+ depth3 = min(depth2, depth3)
742
+ return depth2, depth3
743
+
744
+ def _topk_indices(self, scores: torch.Tensor, k: int) -> torch.Tensor:
745
+ if k <= 0:
746
+ return torch.empty(0, device=scores.device, dtype=torch.long)
747
+ return torch.topk(scores, k=min(k, scores.numel()), sorted=False).indices
748
+
749
+ def _static_route_aux_loss(self, hidden_states: torch.Tensor, mean_depth: torch.Tensor) -> torch.Tensor:
750
+ if not self.config.mor_train_router:
751
+ return self._zero_aux(hidden_states)
752
+ return (mean_depth - self.config.mor_target_avg_depth).pow(2)
753
+
754
+ def _forward_static_dense(
755
+ self,
756
+ input_ids: torch.Tensor,
757
+ *,
758
+ attention_mask: torch.Tensor | None,
759
+ position_ids: torch.Tensor | None,
760
+ is_first_microbatch: bool | None,
761
+ ) -> dict[str, Any]:
762
+ self._bump_counter("static_dense_forward_calls")
763
+ hidden_states = self.embed_tokens(input_ids)
764
+ position_ids = self._build_position_ids(input_ids, position_ids)
765
+ current_hidden = self._run_shared_stack(
766
+ hidden_states,
767
+ attention_mask=None,
768
+ position_ids=position_ids,
769
+ is_first_microbatch=is_first_microbatch,
770
+ )
771
+ final_hidden = self.final_norm(current_hidden)
772
+ one = hidden_states.new_ones(())
773
+ return {
774
+ "hidden_states": [final_hidden],
775
+ "final_hidden": final_hidden,
776
+ "route_probs": None,
777
+ "chosen_depths": None,
778
+ "route_aux_loss": self._zero_aux(hidden_states),
779
+ "mean_depth": one,
780
+ "active_token_ratios": hidden_states.new_ones(1),
781
+ }
782
+
783
+ def _forward_static_sequence_mor(
784
+ self,
785
+ input_ids: torch.Tensor,
786
+ *,
787
+ attention_mask: torch.Tensor | None,
788
+ position_ids: torch.Tensor | None,
789
+ is_first_microbatch: bool | None,
790
+ ) -> dict[str, Any]:
791
+ if self.sequence_router_norm is None or self.sequence_router_up is None or self.sequence_router_out is None:
792
+ raise RuntimeError("Static sequence MoR was requested, but the sequence router is disabled in this config.")
793
+ self._bump_counter("static_sequence_mor_forward_calls")
794
+ hidden_states = self.embed_tokens(input_ids)
795
+ batch_size, _seq_len, _hidden_size = hidden_states.shape
796
+ position_ids = self._build_position_ids(input_ids, position_ids)
797
+ stack_attention_mask = self._static_attention_mask(attention_mask)
798
+ current_hidden = self._run_shared_stack(
799
+ hidden_states,
800
+ attention_mask=stack_attention_mask,
801
+ position_ids=position_ids,
802
+ is_first_microbatch=is_first_microbatch,
803
+ )
804
+ step_hidden_states = [self.final_norm(current_hidden)]
805
+
806
+ router_repr = current_hidden.detach().mean(dim=1)
807
+ self._bump_counter("router_calls")
808
+ router_logits = self._run_router_mlp(
809
+ self.sequence_router_norm,
810
+ self.sequence_router_up,
811
+ self.sequence_router_out,
812
+ router_repr,
813
+ is_first_microbatch=is_first_microbatch,
814
+ )
815
+ route_probs = torch.softmax(router_logits / self.config.mor_router_temperature, dim=-1)
816
+ depth2_cap, depth3_cap = self._default_static_capacities(
817
+ batch_size,
818
+ self.config.mor_depth2_capacity_sequences,
819
+ self.config.mor_depth3_capacity_sequences,
820
+ )
821
+
822
+ chosen_depths = torch.ones(batch_size, device=input_ids.device, dtype=torch.long)
823
+ if self.config.mor_max_depth >= 2 and depth2_cap > 0:
824
+ depth2_scores = router_logits[:, 1:].amax(dim=-1)
825
+ depth2_indices = self._topk_indices(depth2_scores, depth2_cap)
826
+ selected_hidden = current_hidden.index_select(0, depth2_indices)
827
+ selected_position_ids = position_ids.index_select(0, depth2_indices)
828
+ updated_hidden = self._run_shared_stack(
829
+ selected_hidden,
830
+ attention_mask=None,
831
+ position_ids=selected_position_ids,
832
+ is_first_microbatch=is_first_microbatch,
833
+ )
834
+ current_hidden = current_hidden.index_copy(0, depth2_indices, updated_hidden)
835
+ chosen_depths = chosen_depths.index_fill(0, depth2_indices, 2)
836
+ step_hidden_states.append(self.final_norm(current_hidden))
837
+
838
+ if self.config.mor_max_depth >= 3 and depth3_cap > 0:
839
+ depth3_scores = router_logits.index_select(0, depth2_indices)[:, 2]
840
+ local_depth3 = self._topk_indices(depth3_scores, min(depth3_cap, depth2_indices.numel()))
841
+ depth3_indices = depth2_indices.index_select(0, local_depth3)
842
+ selected_hidden = current_hidden.index_select(0, depth3_indices)
843
+ selected_position_ids = position_ids.index_select(0, depth3_indices)
844
+ updated_hidden = self._run_shared_stack(
845
+ selected_hidden,
846
+ attention_mask=None,
847
+ position_ids=selected_position_ids,
848
+ is_first_microbatch=is_first_microbatch,
849
+ )
850
+ current_hidden = current_hidden.index_copy(0, depth3_indices, updated_hidden)
851
+ chosen_depths = chosen_depths.index_fill(0, depth3_indices, 3)
852
+ step_hidden_states.append(self.final_norm(current_hidden))
853
+
854
+ final_hidden = self.final_norm(current_hidden)
855
+ mean_depth = hidden_states.new_tensor(1.0 + (float(depth2_cap) / batch_size) + (float(depth3_cap) / batch_size))
856
+ active_ratios = hidden_states.new_tensor([1.0, float(depth2_cap) / batch_size, float(depth3_cap) / batch_size])
857
+ return {
858
+ "hidden_states": step_hidden_states,
859
+ "final_hidden": final_hidden,
860
+ "route_probs": route_probs,
861
+ "chosen_depths": chosen_depths,
862
+ "route_aux_loss": self._static_route_aux_loss(hidden_states, mean_depth),
863
+ "mean_depth": mean_depth,
864
+ "active_token_ratios": active_ratios,
865
+ }
866
+
867
+ def _forward_static_block_mor(
868
+ self,
869
+ input_ids: torch.Tensor,
870
+ *,
871
+ attention_mask: torch.Tensor | None,
872
+ position_ids: torch.Tensor | None,
873
+ is_first_microbatch: bool | None,
874
+ ) -> dict[str, Any]:
875
+ if self.block_router_norm is None or self.block_router_up is None or self.block_router_out is None:
876
+ raise RuntimeError("Static block MoR was requested, but the block router is disabled in this config.")
877
+ self._bump_counter("static_block_mor_forward_calls")
878
+ hidden_states = self.embed_tokens(input_ids)
879
+ batch_size, seq_len, hidden_size = hidden_states.shape
880
+ block_size = self.config.mor_block_size
881
+ if seq_len % block_size != 0:
882
+ raise ValueError(f"static_block_mor requires seq_len divisible by mor_block_size ({seq_len} % {block_size}).")
883
+ num_blocks = seq_len // block_size
884
+ total_blocks = batch_size * num_blocks
885
+ position_ids = self._build_position_ids(input_ids, position_ids)
886
+ stack_attention_mask = self._static_attention_mask(attention_mask)
887
+ current_hidden = self._run_shared_stack(
888
+ hidden_states,
889
+ attention_mask=stack_attention_mask,
890
+ position_ids=position_ids,
891
+ is_first_microbatch=is_first_microbatch,
892
+ )
893
+ step_hidden_states = [self.final_norm(current_hidden)]
894
+
895
+ blocks = current_hidden.detach().contiguous().view(batch_size, num_blocks, block_size, hidden_size)
896
+ router_repr = blocks.mean(dim=2).reshape(total_blocks, hidden_size)
897
+ self._bump_counter("router_calls")
898
+ router_logits = self._run_router_mlp(
899
+ self.block_router_norm,
900
+ self.block_router_up,
901
+ self.block_router_out,
902
+ router_repr,
903
+ is_first_microbatch=is_first_microbatch,
904
+ )
905
+ route_probs = torch.softmax(router_logits / self.config.mor_router_temperature, dim=-1).view(
906
+ batch_size,
907
+ num_blocks,
908
+ self.config.mor_max_depth,
909
+ )
910
+ depth2_cap, depth3_cap = self._default_static_capacities(
911
+ total_blocks,
912
+ self.config.mor_depth2_capacity_blocks,
913
+ self.config.mor_depth3_capacity_blocks,
914
+ )
915
+
916
+ flat_hidden = current_hidden.contiguous().view(batch_size, num_blocks, block_size, hidden_size).reshape(
917
+ total_blocks,
918
+ block_size,
919
+ hidden_size,
920
+ )
921
+ flat_position_ids = position_ids.contiguous().view(batch_size, num_blocks, block_size).reshape(total_blocks, block_size)
922
+ chosen_depths = torch.ones(total_blocks, device=input_ids.device, dtype=torch.long)
923
+
924
+ if self.config.mor_max_depth >= 2 and depth2_cap > 0:
925
+ depth2_scores = router_logits[:, 1:].amax(dim=-1)
926
+ depth2_indices = self._topk_indices(depth2_scores, depth2_cap)
927
+ selected_hidden = flat_hidden.index_select(0, depth2_indices)
928
+ selected_position_ids = flat_position_ids.index_select(0, depth2_indices)
929
+ updated_hidden = self._run_shared_stack(
930
+ selected_hidden,
931
+ attention_mask=None,
932
+ position_ids=selected_position_ids,
933
+ is_first_microbatch=is_first_microbatch,
934
+ )
935
+ flat_hidden = flat_hidden.index_copy(0, depth2_indices, updated_hidden)
936
+ chosen_depths = chosen_depths.index_fill(0, depth2_indices, 2)
937
+ current_hidden = flat_hidden.view(batch_size, num_blocks, block_size, hidden_size).reshape(
938
+ batch_size,
939
+ seq_len,
940
+ hidden_size,
941
+ )
942
+ step_hidden_states.append(self.final_norm(current_hidden))
943
+
944
+ if self.config.mor_max_depth >= 3 and depth3_cap > 0:
945
+ depth3_scores = router_logits.index_select(0, depth2_indices)[:, 2]
946
+ local_depth3 = self._topk_indices(depth3_scores, min(depth3_cap, depth2_indices.numel()))
947
+ depth3_indices = depth2_indices.index_select(0, local_depth3)
948
+ selected_hidden = flat_hidden.index_select(0, depth3_indices)
949
+ selected_position_ids = flat_position_ids.index_select(0, depth3_indices)
950
+ updated_hidden = self._run_shared_stack(
951
+ selected_hidden,
952
+ attention_mask=None,
953
+ position_ids=selected_position_ids,
954
+ is_first_microbatch=is_first_microbatch,
955
+ )
956
+ flat_hidden = flat_hidden.index_copy(0, depth3_indices, updated_hidden)
957
+ chosen_depths = chosen_depths.index_fill(0, depth3_indices, 3)
958
+ current_hidden = flat_hidden.view(batch_size, num_blocks, block_size, hidden_size).reshape(
959
+ batch_size,
960
+ seq_len,
961
+ hidden_size,
962
+ )
963
+ step_hidden_states.append(self.final_norm(current_hidden))
964
+
965
+ final_hidden = self.final_norm(
966
+ flat_hidden.view(batch_size, num_blocks, block_size, hidden_size).reshape(batch_size, seq_len, hidden_size)
967
+ )
968
+ mean_depth = hidden_states.new_tensor(
969
+ 1.0 + (float(depth2_cap) / total_blocks) + (float(depth3_cap) / total_blocks)
970
+ )
971
+ active_ratios = hidden_states.new_tensor(
972
+ [1.0, float(depth2_cap) / total_blocks, float(depth3_cap) / total_blocks]
973
+ )
974
+ return {
975
+ "hidden_states": step_hidden_states,
976
+ "final_hidden": final_hidden,
977
+ "route_probs": route_probs,
978
+ "chosen_depths": chosen_depths.view(batch_size, num_blocks),
979
+ "route_aux_loss": self._static_route_aux_loss(hidden_states, mean_depth),
980
+ "mean_depth": mean_depth,
981
+ "active_token_ratios": active_ratios,
982
+ }
983
+
984
+ def _forward_dynamic_token_mor(
985
+ self,
986
+ input_ids: torch.Tensor,
987
+ *,
988
+ attention_mask: torch.Tensor | None,
989
+ position_ids: torch.Tensor | None,
990
+ is_first_microbatch: bool | None,
991
+ ) -> dict[str, Any]:
992
+ self._bump_counter("dynamic_token_mor_forward_calls")
993
+ hidden_states = self.embed_tokens(input_ids)
994
+ position_ids = self._build_position_ids(input_ids, position_ids)
995
+ valid_tokens = attention_mask.to(torch.bool) if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
996
+ soft_route_probs, _hard_route_probs, chosen_depths, route_aux_loss = self._route_tokens(
997
+ hidden_states,
998
+ attention_mask=attention_mask,
999
+ is_first_microbatch=is_first_microbatch,
1000
+ )
1001
+
1002
+ step_hidden_states: list[torch.Tensor] = []
1003
+ active_ratios: list[torch.Tensor] = []
1004
+ current_hidden = hidden_states
1005
+ valid_count = valid_tokens.sum().clamp_min(1)
1006
+ use_dense_active = (
1007
+ self.training
1008
+ and (
1009
+ self.training_routing_mode == "dense_active"
1010
+ or self.config.disable_token_packing
1011
+ or self.config.disable_token_scatter
1012
+ )
1013
+ )
1014
+ for step_index in range(1, self.config.mor_max_depth + 1):
1015
+ active_mask = (chosen_depths >= step_index) & valid_tokens
1016
+ active_ratios.append(active_mask.sum().to(hidden_states.dtype) / valid_count.to(hidden_states.dtype))
1017
+ if use_dense_active:
1018
+ updated_hidden = self._run_shared_stack(
1019
+ current_hidden,
1020
+ attention_mask=attention_mask,
1021
+ position_ids=position_ids,
1022
+ is_first_microbatch=is_first_microbatch,
1023
+ )
1024
+ current_hidden = torch.where(active_mask.unsqueeze(-1), updated_hidden, current_hidden)
1025
+ elif bool(active_mask.all().item()):
1026
+ current_hidden = self._run_shared_stack(
1027
+ current_hidden,
1028
+ attention_mask=attention_mask,
1029
+ position_ids=position_ids,
1030
+ is_first_microbatch=is_first_microbatch,
1031
+ )
1032
+ else:
1033
+ self._bump_counter("pack_active_tokens_calls")
1034
+ packed = pack_active_tokens(
1035
+ current_hidden,
1036
+ position_ids,
1037
+ active_mask,
1038
+ pad_multiple=self.config.fp8_pad_multiple if self.use_fp8 else 1,
1039
+ )
1040
+ if packed is not None:
1041
+ packed_hidden, packed_positions, packed_mask, packed_indices = packed
1042
+ updated_hidden = self._run_shared_stack(
1043
+ packed_hidden,
1044
+ attention_mask=packed_mask,
1045
+ position_ids=packed_positions,
1046
+ is_first_microbatch=is_first_microbatch,
1047
+ )
1048
+ if self.config.disable_token_scatter:
1049
+ current_hidden = updated_hidden
1050
+ else:
1051
+ self._bump_counter("scatter_active_tokens_calls")
1052
+ current_hidden = scatter_active_tokens(
1053
+ current_hidden,
1054
+ updated_hidden,
1055
+ packed_mask,
1056
+ packed_indices,
1057
+ active_mask,
1058
+ )
1059
+ step_hidden_states.append(self.final_norm(current_hidden))
1060
+
1061
+ if self.config.disable_depth_stack:
1062
+ final_hidden = step_hidden_states[-1]
1063
+ else:
1064
+ stacked_hidden = torch.stack(step_hidden_states, dim=2)
1065
+ if self.training:
1066
+ final_hidden = torch.sum(
1067
+ stacked_hidden * soft_route_probs.unsqueeze(-1).to(stacked_hidden.dtype),
1068
+ dim=2,
1069
+ )
1070
+ else:
1071
+ gather_index = (chosen_depths.clamp_min(1) - 1).unsqueeze(-1).unsqueeze(-1).expand(
1072
+ -1,
1073
+ -1,
1074
+ 1,
1075
+ stacked_hidden.size(-1),
1076
+ )
1077
+ final_hidden = torch.gather(stacked_hidden, 2, gather_index).squeeze(2)
1078
+
1079
+ depth_values = torch.arange(
1080
+ 1,
1081
+ self.config.mor_max_depth + 1,
1082
+ device=hidden_states.device,
1083
+ dtype=soft_route_probs.dtype,
1084
+ )
1085
+ mean_depth = torch.sum(soft_route_probs * depth_values, dim=-1)
1086
+ if attention_mask is not None:
1087
+ mean_depth = mean_depth.masked_select(valid_tokens).mean()
1088
+ else:
1089
+ mean_depth = mean_depth.mean()
1090
+
1091
+ return {
1092
+ "hidden_states": step_hidden_states,
1093
+ "final_hidden": final_hidden,
1094
+ "route_probs": soft_route_probs,
1095
+ "chosen_depths": chosen_depths,
1096
+ "route_aux_loss": route_aux_loss,
1097
+ "mean_depth": mean_depth,
1098
+ "active_token_ratios": torch.stack(active_ratios),
1099
+ }
1100
+
1101
+ def forward(
1102
+ self,
1103
+ input_ids: torch.Tensor,
1104
+ *,
1105
+ attention_mask: torch.Tensor | None = None,
1106
+ position_ids: torch.Tensor | None = None,
1107
+ is_first_microbatch: bool | None = None,
1108
+ ) -> dict[str, Any]:
1109
+ if self.config.training_mode == "static_dense_pretrain" or not self.config.mor_enabled:
1110
+ return self._forward_static_dense(
1111
+ input_ids,
1112
+ attention_mask=attention_mask,
1113
+ position_ids=position_ids,
1114
+ is_first_microbatch=is_first_microbatch,
1115
+ )
1116
+ if self.config.training_mode == "static_sequence_mor":
1117
+ return self._forward_static_sequence_mor(
1118
+ input_ids,
1119
+ attention_mask=attention_mask,
1120
+ position_ids=position_ids,
1121
+ is_first_microbatch=is_first_microbatch,
1122
+ )
1123
+ if self.config.training_mode == "static_block_mor":
1124
+ return self._forward_static_block_mor(
1125
+ input_ids,
1126
+ attention_mask=attention_mask,
1127
+ position_ids=position_ids,
1128
+ is_first_microbatch=is_first_microbatch,
1129
+ )
1130
+ return self._forward_dynamic_token_mor(
1131
+ input_ids,
1132
+ attention_mask=attention_mask,
1133
+ position_ids=position_ids,
1134
+ is_first_microbatch=is_first_microbatch,
1135
+ )
1136
+
1137
+
1138
+ class MetisMoRLMHeadModel(nn.Module):
1139
+ def __init__(
1140
+ self,
1141
+ config: MetisMambaConfig,
1142
+ *,
1143
+ use_fp8: bool = False,
1144
+ fp8_recipe=None,
1145
+ fp8_group=None,
1146
+ ) -> None:
1147
+ super().__init__()
1148
+ self.config = config
1149
+ self.use_fp8 = use_fp8
1150
+ self.fp8_recipe = fp8_recipe
1151
+ self.fp8_group = fp8_group
1152
+ self.backbone = MetisMoRModel(
1153
+ config,
1154
+ use_fp8=use_fp8,
1155
+ fp8_recipe=fp8_recipe,
1156
+ fp8_group=fp8_group,
1157
+ )
1158
+ self.lm_head = MetisLinear(
1159
+ config.d_model,
1160
+ config.padded_vocab_size,
1161
+ bias=False,
1162
+ use_fp8=use_fp8,
1163
+ )
1164
+ self.fused_linear_ce = None
1165
+ if config.lm_loss_impl == "liger_fused_linear_ce":
1166
+ self.fused_linear_ce = _load_liger_fused_linear_ce()(ignore_index=-100)
1167
+ self.model_family = config.model_type
1168
+ self.apply(self._init_weights)
1169
+ if config.tie_embeddings:
1170
+ self.tie_weights()
1171
+
1172
+ def reset_perf_counters(self) -> None:
1173
+ self.backbone.reset_perf_counters()
1174
+
1175
+ def get_perf_counters(self) -> dict[str, int]:
1176
+ return self.backbone.get_perf_counters()
1177
+
1178
+ def tie_weights(self) -> None:
1179
+ self.lm_head.impl.weight = self.backbone.embed_tokens.weight
1180
+
1181
+ def _init_weights(self, module: nn.Module) -> None:
1182
+ if isinstance(module, MetisLinear):
1183
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1184
+ if getattr(module.impl, "bias", None) is not None:
1185
+ nn.init.zeros_(module.impl.bias)
1186
+ elif isinstance(module, nn.Linear):
1187
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1188
+ if module.bias is not None:
1189
+ nn.init.zeros_(module.bias)
1190
+ elif isinstance(module, nn.Embedding):
1191
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1192
+
1193
+ def forward(
1194
+ self,
1195
+ input_ids: torch.Tensor,
1196
+ *,
1197
+ labels: torch.Tensor | None = None,
1198
+ attention_mask: torch.Tensor | None = None,
1199
+ position_ids: torch.Tensor | None = None,
1200
+ is_first_microbatch: bool | None = None,
1201
+ return_logits: bool = True,
1202
+ **_: Any,
1203
+ ) -> MetisCausalLMOutput:
1204
+ outputs = self.backbone(
1205
+ input_ids,
1206
+ attention_mask=attention_mask,
1207
+ position_ids=position_ids,
1208
+ is_first_microbatch=is_first_microbatch,
1209
+ )
1210
+ final_hidden = outputs["final_hidden"]
1211
+ logits = None
1212
+ loss = None
1213
+ lm_loss = None
1214
+
1215
+ if labels is not None and self.fused_linear_ce is not None:
1216
+ shift_hidden = final_hidden[:, :-1, :].contiguous().view(-1, final_hidden.size(-1))
1217
+ shift_labels = labels[:, 1:].contiguous().view(-1)
1218
+ with _nvtx_range("fused_linear_ce"):
1219
+ lm_loss = self.fused_linear_ce(self.lm_head.weight, shift_hidden, shift_labels)
1220
+ else:
1221
+ with _nvtx_range("lm_head"):
1222
+ with fp8_autocast_context(
1223
+ enabled=self.use_fp8,
1224
+ recipe=self.fp8_recipe,
1225
+ fp8_group=self.fp8_group,
1226
+ ):
1227
+ logits = self.lm_head(
1228
+ final_hidden,
1229
+ is_first_microbatch=is_first_microbatch,
1230
+ )
1231
+ if labels is not None:
1232
+ with _nvtx_range("cross_entropy"):
1233
+ shift_logits = logits[:, :-1, :].contiguous()
1234
+ shift_labels = labels[:, 1:].contiguous()
1235
+ lm_loss = F.cross_entropy(
1236
+ shift_logits.view(-1, shift_logits.size(-1)),
1237
+ shift_labels.view(-1),
1238
+ ignore_index=-100,
1239
+ )
1240
+
1241
+ if labels is not None:
1242
+ if lm_loss is None:
1243
+ raise RuntimeError("Language-model loss was not computed despite labels being provided.")
1244
+ if return_logits and logits is None:
1245
+ with _nvtx_range("lm_head_for_logits"):
1246
+ with fp8_autocast_context(
1247
+ enabled=self.use_fp8,
1248
+ recipe=self.fp8_recipe,
1249
+ fp8_group=self.fp8_group,
1250
+ ):
1251
+ logits = self.lm_head(
1252
+ final_hidden,
1253
+ is_first_microbatch=is_first_microbatch,
1254
+ )
1255
+ elif not return_logits:
1256
+ logits = None
1257
+ route_aux_loss = outputs.get("route_aux_loss")
1258
+ aux_loss = (
1259
+ route_aux_loss * self.config.mor_router_aux_loss_coef
1260
+ if route_aux_loss is not None
1261
+ else lm_loss.new_zeros(())
1262
+ )
1263
+ loss = lm_loss + aux_loss
1264
+
1265
+ return MetisCausalLMOutput(
1266
+ logits=logits,
1267
+ loss=loss,
1268
+ lm_loss=lm_loss,
1269
+ hidden_states=outputs["hidden_states"],
1270
+ route_probs=outputs["route_probs"],
1271
+ chosen_depths=outputs["chosen_depths"],
1272
+ route_aux_loss=outputs["route_aux_loss"],
1273
+ mean_depth=outputs["mean_depth"],
1274
+ active_token_ratios=outputs["active_token_ratios"],
1275
+ )
1276
+
1277
+
1278
+ class MetisMoRRewardModel(nn.Module):
1279
+ def __init__(
1280
+ self,
1281
+ config: MetisMambaConfig,
1282
+ *,
1283
+ use_fp8: bool = False,
1284
+ fp8_recipe=None,
1285
+ fp8_group=None,
1286
+ ) -> None:
1287
+ super().__init__()
1288
+ self.config = config
1289
+ self.use_fp8 = use_fp8
1290
+ self.fp8_recipe = fp8_recipe
1291
+ self.fp8_group = fp8_group
1292
+ self.backbone = MetisMoRModel(
1293
+ config,
1294
+ use_fp8=use_fp8,
1295
+ fp8_recipe=fp8_recipe,
1296
+ fp8_group=fp8_group,
1297
+ )
1298
+ self.score_head = nn.Linear(config.d_model, 1, bias=False)
1299
+ self.model_family = f"{config.model_type}_reward_model"
1300
+ self.apply(self._init_weights)
1301
+
1302
+ def reset_perf_counters(self) -> None:
1303
+ self.backbone.reset_perf_counters()
1304
+
1305
+ def get_perf_counters(self) -> dict[str, int]:
1306
+ return self.backbone.get_perf_counters()
1307
+
1308
+ def _init_weights(self, module: nn.Module) -> None:
1309
+ if isinstance(module, MetisLinear):
1310
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1311
+ if getattr(module.impl, "bias", None) is not None:
1312
+ nn.init.zeros_(module.impl.bias)
1313
+ elif isinstance(module, nn.Linear):
1314
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1315
+ if module.bias is not None:
1316
+ nn.init.zeros_(module.bias)
1317
+ elif isinstance(module, nn.Embedding):
1318
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
1319
+
1320
+ def forward(
1321
+ self,
1322
+ input_ids: torch.Tensor,
1323
+ *,
1324
+ attention_mask: torch.Tensor | None = None,
1325
+ position_ids: torch.Tensor | None = None,
1326
+ is_first_microbatch: bool | None = None,
1327
+ **_: Any,
1328
+ ) -> MetisRewardOutput:
1329
+ outputs = self.backbone(
1330
+ input_ids,
1331
+ attention_mask=attention_mask,
1332
+ position_ids=position_ids,
1333
+ is_first_microbatch=is_first_microbatch,
1334
+ )
1335
+ hidden_states = outputs["final_hidden"]
1336
+ batch_size = hidden_states.size(0)
1337
+ if attention_mask is None:
1338
+ last_indices = torch.full(
1339
+ (batch_size,),
1340
+ hidden_states.size(1) - 1,
1341
+ device=hidden_states.device,
1342
+ dtype=torch.long,
1343
+ )
1344
+ else:
1345
+ last_indices = attention_mask.to(torch.long).sum(dim=-1).clamp_min(1) - 1
1346
+ pooled_hidden = hidden_states[torch.arange(batch_size, device=hidden_states.device), last_indices]
1347
+ rewards = self.score_head(pooled_hidden).squeeze(-1)
1348
+ return MetisRewardOutput(
1349
+ rewards=rewards,
1350
+ route_aux_loss=outputs["route_aux_loss"],
1351
+ mean_depth=outputs["mean_depth"],
1352
+ active_token_ratios=outputs["active_token_ratios"],
1353
+ )
metis_mamba/runtime.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import torch
9
+ from safetensors.torch import load_model as load_safetensors_model
10
+ from safetensors.torch import save_model as save_safetensors_model
11
+
12
+ from .config import MetisMambaConfig
13
+ from .checkpoint_compat import filter_state_dict_for_model
14
+ from .fp8 import build_fp8_recipe
15
+ from .hybrid_runtime import load_hybrid_exported_model
16
+ from .model import MetisMoRLMHeadModel, MetisMoRRewardModel
17
+
18
+
19
+ def parse_torch_dtype(dtype_name: str | None) -> torch.dtype:
20
+ mapping = {
21
+ None: torch.float32,
22
+ "fp32": torch.float32,
23
+ "float32": torch.float32,
24
+ "fp16": torch.float16,
25
+ "float16": torch.float16,
26
+ "bf16": torch.bfloat16,
27
+ "bfloat16": torch.bfloat16,
28
+ }
29
+ if dtype_name not in mapping:
30
+ raise ValueError(f"Unsupported torch dtype: {dtype_name}")
31
+ return mapping[dtype_name]
32
+
33
+
34
+ def build_model(
35
+ config: MetisMambaConfig,
36
+ *,
37
+ device: torch.device | str | None = None,
38
+ dtype: torch.dtype | None = None,
39
+ use_fp8: bool = False,
40
+ fp8_recipe=None,
41
+ fp8_group=None,
42
+ ):
43
+ config.validate()
44
+ if use_fp8 and fp8_recipe is None:
45
+ fp8_recipe = build_fp8_recipe()
46
+ model = MetisMoRLMHeadModel(
47
+ config,
48
+ use_fp8=use_fp8,
49
+ fp8_recipe=fp8_recipe,
50
+ fp8_group=fp8_group,
51
+ )
52
+ if device is not None or dtype is not None:
53
+ model = model.to(device=device, dtype=dtype)
54
+ model.config = config
55
+ model.model_family = config.model_type
56
+ return model
57
+
58
+
59
+ def build_reward_model(
60
+ config: MetisMambaConfig,
61
+ *,
62
+ device: torch.device | str | None = None,
63
+ dtype: torch.dtype | None = None,
64
+ use_fp8: bool = False,
65
+ fp8_recipe=None,
66
+ fp8_group=None,
67
+ ):
68
+ config.validate()
69
+ if use_fp8 and fp8_recipe is None:
70
+ fp8_recipe = build_fp8_recipe()
71
+ model = MetisMoRRewardModel(
72
+ config,
73
+ use_fp8=use_fp8,
74
+ fp8_recipe=fp8_recipe,
75
+ fp8_group=fp8_group,
76
+ )
77
+ if device is not None or dtype is not None:
78
+ model = model.to(device=device, dtype=dtype)
79
+ model.config = config
80
+ return model
81
+
82
+
83
+ def load_checkpoint_model(
84
+ checkpoint_path: str | Path,
85
+ device: torch.device,
86
+ ):
87
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
88
+ config = MetisMambaConfig.from_dict(checkpoint["model_config"])
89
+ model = build_model(config)
90
+ filtered_state, _conversions = filter_state_dict_for_model(model, checkpoint["model_state_dict"])
91
+ model.load_state_dict(filtered_state, strict=False)
92
+ model.to(device)
93
+ model.eval()
94
+ return model
95
+
96
+
97
+ def load_exported_model(
98
+ model_dir: str | Path,
99
+ device: torch.device,
100
+ ):
101
+ model_dir = Path(model_dir)
102
+ raw_config = json.loads((model_dir / "config.json").read_text())
103
+ if raw_config.get("model_type") == "metis_mamba2_hybrid":
104
+ return load_hybrid_exported_model(model_dir, device)
105
+ config = MetisMambaConfig.from_dict(raw_config)
106
+ model = build_model(config)
107
+ missing, unexpected = load_safetensors_model(model, str(model_dir / "model.safetensors"), device="cpu")
108
+ if missing or unexpected:
109
+ raise RuntimeError(
110
+ f"Unexpected export load result for {model_dir}: missing={missing}, unexpected={unexpected}"
111
+ )
112
+ model.to(device)
113
+ model.eval()
114
+ return model
115
+
116
+
117
+ def export_checkpoint_to_dir(
118
+ *,
119
+ checkpoint_path: str | Path,
120
+ output_dir: str | Path,
121
+ model_only_filename: str = "model.safetensors",
122
+ ) -> dict[str, Any]:
123
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
124
+ config = MetisMambaConfig.from_dict(checkpoint["model_config"])
125
+ output_dir = Path(output_dir)
126
+ output_dir.mkdir(parents=True, exist_ok=True)
127
+ export_dtype = parse_torch_dtype(config.torch_dtype)
128
+ model = build_model(config, device="cpu", dtype=export_dtype)
129
+ state_dict = {
130
+ name: tensor.detach().to(dtype=export_dtype).cpu()
131
+ for name, tensor in checkpoint["model_state_dict"].items()
132
+ }
133
+ filtered_state, _conversions = filter_state_dict_for_model(model, state_dict)
134
+ model.load_state_dict(filtered_state, strict=False)
135
+ save_safetensors_model(model, str(output_dir / model_only_filename))
136
+ return {
137
+ "config": config.to_dict(),
138
+ "model_path": str(output_dir / model_only_filename),
139
+ }
140
+
141
+
142
+ def encode_prompt(tokenizer, prompt: str, device: torch.device) -> torch.Tensor:
143
+ prompt_ids = tokenizer.encode(prompt, add_special_tokens=False).ids
144
+ bos_id = tokenizer.token_to_id("<bos>")
145
+ if bos_id is not None:
146
+ prompt_ids = [bos_id] + prompt_ids
147
+ return torch.tensor([prompt_ids], dtype=torch.long, device=device)
148
+
149
+
150
+ @torch.no_grad()
151
+ def generate_completion(
152
+ model,
153
+ tokenizer,
154
+ prompt: str,
155
+ device: torch.device,
156
+ *,
157
+ max_new_tokens: int = 120,
158
+ temperature: float = 0.8,
159
+ top_k: int | None = 50,
160
+ ) -> str:
161
+ model.eval()
162
+ input_ids = encode_prompt(tokenizer, prompt, device)
163
+ eos_token_id = tokenizer.token_to_id("<eos>")
164
+ for _ in range(max_new_tokens):
165
+ logits = model(input_ids).logits[:, -1, :].float()
166
+ if temperature <= 0:
167
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
168
+ else:
169
+ logits = logits / max(temperature, 1e-6)
170
+ if top_k is not None and top_k > 0:
171
+ values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
172
+ logits[logits < values[:, [-1]]] = -float("inf")
173
+ probs = torch.softmax(logits, dim=-1)
174
+ next_token = torch.multinomial(probs, num_samples=1)
175
+ input_ids = torch.cat([input_ids, next_token], dim=1)
176
+ if eos_token_id is not None and int(next_token.item()) == eos_token_id:
177
+ break
178
+ if input_ids.shape[1] > model.config.block_size:
179
+ input_ids = input_ids[:, -model.config.block_size :]
180
+
181
+ return tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)
182
+
183
+
184
+ def cosine_lr(step: int, *, max_steps: int, warmup_steps: int, min_lr_factor: float = 0.0) -> float:
185
+ if max_steps <= 0:
186
+ return 1.0
187
+ if warmup_steps > 0 and step < warmup_steps:
188
+ return max(step + 1, 1) / max(warmup_steps, 1)
189
+ progress = (step - warmup_steps) / max(max_steps - warmup_steps, 1)
190
+ progress = min(max(progress, 0.0), 1.0)
191
+ cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
192
+ return min_lr_factor + (1.0 - min_lr_factor) * cosine
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b58e109bf3531dea6bdf57f75000ca04cc6a8934d6290b2f112e4c97a4a118e
3
+ size 1007558534
runtime_requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ safetensors
3
+ tokenizers
4
+
5
+ # Optional but recommended on H100/Hopper for the fast path used during training:
6
+ flash-attn
7
+ transformer-engine[pytorch]
8
+ liger-kernel
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<bos>",
3
+ "eos_token": "<eos>",
4
+ "unk_token": "<unk>",
5
+ "pad_token": "<pad>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": true,
4
+ "bos_token": "<bos>",
5
+ "eos_token": "<eos>",
6
+ "unk_token": "<unk>",
7
+ "pad_token": "<pad>",
8
+ "clean_up_tokenization_spaces": false,
9
+ "model_max_length": 1024,
10
+ "tokenizer_class": "PreTrainedTokenizerFast",
11
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}System: {{ message['content'] }}\n{% elif message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}Assistant: {% endif %}"
12
+ }
training_summary.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "repo_id": "Lernex/Metis-1.4-base",
3
+ "name": "Metis-1.4 Base",
4
+ "release_stage": "base",
5
+ "model_type": "metis_mor_transformer",
6
+ "architecture": "metis_mor_decoder",
7
+ "estimated_params": 503772163,
8
+ "context_length": 1024,
9
+ "vocab_size": 16384,
10
+ "d_model": 1536,
11
+ "n_layer": 19,
12
+ "n_heads": 24,
13
+ "n_kv_heads": 8,
14
+ "head_dim": 64,
15
+ "hidden_act": "swiglu",
16
+ "mor_max_depth": 3,
17
+ "selected_continued_pretraining_mode": "static_sequence_mor",
18
+ "selected_checkpoint": {
19
+ "run": "metis14_static_sequence_continued_fused",
20
+ "step": 5000,
21
+ "train_loss": 3.6065,
22
+ "val_loss": 3.6026,
23
+ "perplexity": 36.69
24
+ },
25
+ "comparison_checkpoint": {
26
+ "run": "metis14_static_block_continued_probe",
27
+ "step": 5000,
28
+ "train_loss": 3.609,
29
+ "val_loss": 3.6575,
30
+ "perplexity": 38.76
31
+ },
32
+ "training_path": [
33
+ "static_dense_pretrain",
34
+ "static_sequence_mor_continued_pretrain"
35
+ ],
36
+ "precision": {
37
+ "training_compute": "FP8 with BF16 master/export weights",
38
+ "export_dtype": "bfloat16",
39
+ "attention_backend": "flash_attention_3",
40
+ "loss_impl": "liger_fused_linear_ce"
41
+ },
42
+ "artifact": {
43
+ "format": "safetensors",
44
+ "filename": "model.safetensors",
45
+ "sha256": "5b58e109bf3531dea6bdf57f75000ca04cc6a8934d6290b2f112e4c97a4a118e",
46
+ "s3_uri": "s3://lernex-metis-artifacts-151025633969-us-east-1/metis14/releases/base/model.safetensors"
47
+ }
48
+ }