SeaWolf-AI committed on
Commit
ca19627
ยท
verified ยท
1 Parent(s): 1fe984b

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +320 -0
  2. config.py +149 -0
  3. layers.py +449 -0
  4. model.py +228 -0
  5. oheng_moe.py +292 -0
  6. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
AETHER-Net 0.8B — Inference Test Space

Loads the private model and tests text generation.
HF Space: T4 GPU, requires the HF_TOKEN secret.

Deploy: FINAL-Bench/aether-net-test
"""
import os
import sys
import time
import json
import torch
import torch.nn.functional as F
import gradio as gr
from pathlib import Path
from huggingface_hub import hf_hub_download, snapshot_download

# ── Config ──
MODEL_REPO = "FINAL-Bench/AETHER-Net-0.8B"
DONOR_REPO = "Qwen/Qwen3.5-0.8B"  # For tokenizer (model repo ships no tokenizer files)
HF_TOKEN = os.getenv("HF_TOKEN")  # required for the private model repo
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Startup diagnostics — visible in the Space logs.
print(f"Device: {DEVICE}")
print(f"HF_TOKEN: {'set' if HF_TOKEN else 'NOT SET'}")

# ── Download model weights from private repo ──
# Done at import time so the first request does not pay the download cost.
print(f"Downloading AETHER-Net weights from {MODEL_REPO}...")

model_dir = None  # stays None on failure; load_model() falls back to defaults/random init
try:
    model_dir = snapshot_download(
        MODEL_REPO, token=HF_TOKEN,
        allow_patterns=["model.safetensors", "config.json"],
    )
    print(f" Model downloaded to: {model_dir}")
except Exception as e:
    # Best-effort: the app still starts so the UI can show the error state.
    print(f" Download failed: {e}")

# Source files (config.py, model.py, ...) are co-located in the same directory
APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, APP_DIR)

# ── Load model ──
# Module-level singletons populated lazily by load_model().
MODEL = None
TOKENIZER = None
49
+
50
+
51
def load_model():
    """Lazily load the donor tokenizer and the AETHER-Net model into globals.

    Populates the module-level MODEL and TOKENIZER singletons. Safe to call
    repeatedly — returns immediately once MODEL is set.

    Returns:
        bool: True when both tokenizer and model are ready, False on any failure.
    """
    global MODEL, TOKENIZER

    # Already loaded — nothing to do.
    if MODEL is not None:
        return True

    # Load tokenizer from donor repo (model repo has no tokenizer files)
    print("Loading tokenizer...")
    from transformers import AutoTokenizer
    try:
        TOKENIZER = AutoTokenizer.from_pretrained(
            DONOR_REPO, trust_remote_code=True, token=HF_TOKEN
        )
        print(f" Tokenizer loaded: vocab_size={TOKENIZER.vocab_size}")
    except Exception as e:
        print(f" Tokenizer failed: {e}")
        return False

    # Load AETHER-Net
    print("Loading AETHER-Net model...")
    try:
        # Local modules (APP_DIR was pushed onto sys.path at import time).
        from config import AetherNetConfig
        from model import AetherNetModel

        # Load config from the downloaded snapshot when available.
        config_path = Path(model_dir) / "config.json" if model_dir else None
        if config_path and config_path.exists():
            with open(config_path) as f:
                cfg_dict = json.load(f)
            # Filter to valid dataclass fields so unknown keys in the
            # checkpoint config don't crash the constructor.
            valid_fields = {k for k in AetherNetConfig.__dataclass_fields__}
            filtered = {k: v for k, v in cfg_dict.items() if k in valid_fields}
            config = AetherNetConfig(**filtered)
            print(f" Config loaded: hidden={config.hidden_size}, layers={config.num_layers}")
        else:
            # Fallback defaults sized for the 0.8B variant.
            print(" No config.json, using defaults")
            config = AetherNetConfig(
                hidden_size=1024, intermediate_size=3584,
                num_layers=25, num_attention_heads=16, num_kv_heads=2,
                head_dim=64, vocab_size=248320,
                max_position_embeddings=4096,
                expert_intermediate_size=716,
                overcome_gate_hidden=64,
                sliding_window_size=1024,
                gdn_state_size=64, mamba2_state_size=64,
                tie_word_embeddings=True,
            )

        model = AetherNetModel(config)

        # Load weights; strict=False tolerates partial/extra tensors.
        weights_path = Path(model_dir) / "model.safetensors" if model_dir else None
        if weights_path and weights_path.exists():
            from safetensors.torch import load_file
            state = load_file(str(weights_path), device="cpu")
            model.load_state_dict(state, strict=False)
            print(f" Weights loaded: {len(state)} tensors")
        else:
            # Degraded mode: UI still works but outputs are meaningless.
            print(" ⚠️ No weights found, using random init")

        model = model.to(DEVICE).eval()
        MODEL = model

        # Rough memory estimate assuming 2 bytes/param (BF16).
        params = sum(p.numel() for p in model.parameters())
        mem = params * 2 / 1e9  # BF16 estimate
        print(f" Model ready: {params:,} params (~{mem:.1f}GB)")
        return True

    except Exception as e:
        import traceback
        print(f" Model load failed: {e}")
        traceback.print_exc()
        return False
124
+
125
+
126
+ # โ”€โ”€ Generation โ”€โ”€
127
+ @torch.no_grad()
128
+ def generate(prompt, max_tokens=128, temperature=0.8, top_k=50, top_p=0.9):
129
+ """Generate text from prompt."""
130
+ if MODEL is None:
131
+ success = load_model()
132
+ if not success:
133
+ return "โŒ Model failed to load. Check logs."
134
+
135
+ # Tokenize
136
+ input_ids = TOKENIZER.encode(prompt, return_tensors="pt").to(DEVICE)
137
+ generated = input_ids.clone()
138
+
139
+ t0 = time.time()
140
+
141
+ for i in range(max_tokens):
142
+ # Truncate to max position
143
+ if generated.shape[1] > 4096:
144
+ generated = generated[:, -4096:]
145
+
146
+ outputs = MODEL(input_ids=generated)
147
+ logits = outputs["logits"][:, -1, :]
148
+
149
+ # Temperature
150
+ if temperature > 0:
151
+ logits = logits / temperature
152
+
153
+ # Top-k
154
+ if top_k > 0:
155
+ values, _ = torch.topk(logits, top_k)
156
+ min_val = values[:, -1].unsqueeze(-1)
157
+ logits = torch.where(logits < min_val, torch.full_like(logits, -float('inf')), logits)
158
+
159
+ # Top-p (nucleus)
160
+ if top_p < 1.0:
161
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
162
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
163
+ mask = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
164
+ sorted_logits[mask] = -float('inf')
165
+ logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
166
+
167
+ probs = F.softmax(logits, dim=-1)
168
+ next_token = torch.multinomial(probs, num_samples=1)
169
+ else:
170
+ next_token = logits.argmax(dim=-1, keepdim=True)
171
+
172
+ generated = torch.cat([generated, next_token], dim=-1)
173
+
174
+ # EOS check
175
+ if next_token.item() == TOKENIZER.eos_token_id:
176
+ break
177
+
178
+ elapsed = time.time() - t0
179
+ tokens_generated = generated.shape[1] - input_ids.shape[1]
180
+ tps = tokens_generated / elapsed if elapsed > 0 else 0
181
+
182
+ output_text = TOKENIZER.decode(generated[0], skip_special_tokens=True)
183
+ stats = f"\n\n---\n๐Ÿ“Š {tokens_generated} tokens | {tps:.1f} tok/s | {elapsed:.2f}s"
184
+
185
+ return output_text + stats
186
+
187
+
188
def get_model_info():
    """Return a Markdown report of the loaded model's architecture.

    Loads the model on first call. Reports parameter counts, the per-layer
    attention-type/element map, and the average learned generate-boost α
    per element.
    """
    if MODEL is None:
        load_model()

    if MODEL is None:
        return "Model not loaded"

    # Fix: use the actual layer count everywhere (the Oheng section below
    # previously hard-coded range(25)).
    num_layers = len(MODEL.layers)

    info = "## AETHER-Net 0.8B — Architecture Info\n\n"
    info += "| Item | Value |\n|---|---|\n"
    info += f"| Device | {DEVICE} |\n"
    info += f"| Parameters | {sum(p.numel() for p in MODEL.parameters()):,} |\n"
    info += f"| Layers | {num_layers} |\n"
    info += f"| Vocab | {MODEL.config.vocab_size:,} |\n"
    info += f"| Hidden | {MODEL.config.hidden_size} |\n"

    # Layer types
    from config import LAYER_TYPES, LAYER_TO_ELEMENT, ELEMENTS
    info += "\n### Layer Map\n\n"
    info += "| Layer | Type | Element |\n|---|---|---|\n"
    for i in range(num_layers):
        info += f"| {i} | {LAYER_TYPES[i].upper()} | {LAYER_TO_ELEMENT[i]} |\n"

    # Oheng status: average sigmoid(α) of each element's generate-boost.
    info += "\n### Oheng Status\n\n"
    for elem in ELEMENTS:
        layers = [i for i in range(num_layers) if LAYER_TO_ELEMENT[i] == elem]
        alphas = []
        for li in layers:
            gb = MODEL.layers[li].moe.generate_boost
            if gb is not None:
                a = torch.sigmoid(gb.alpha).detach()
                eidx = ELEMENTS.index(elem)
                if eidx < a.shape[0]:
                    alphas.append(a[eidx].item())
        avg = sum(alphas) / len(alphas) if alphas else 0
        info += f"- {elem}: α={avg:.4f}\n"

    return info
229
+
230
+
231
+ # โ”€โ”€ Gradio UI โ”€โ”€
232
+ TITLE = """
233
+ <div style="text-align:center; padding:15px 0;">
234
+ <h1>๐ŸŒŒ AETHER-Net 0.8B โ€” Inference Test</h1>
235
+ <p style="color:#666;">Cross-Architecture Knowledge Distillation from Qwen3.5-0.8B</p>
236
+ <p style="color:#999; font-size:0.9em;">5ร—5 Magic Square | Oheng MoE | 5 Attention Types</p>
237
+ </div>
238
+ """
239
+
240
+ with gr.Blocks(title="AETHER-Net Test") as app:
241
+ gr.HTML(TITLE)
242
+
243
+ with gr.Tabs():
244
+ with gr.Tab("๐Ÿ’ฌ Generate"):
245
+ gr.Markdown("ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด AETHER-Net์ด ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.")
246
+
247
+ with gr.Row():
248
+ with gr.Column(scale=3):
249
+ prompt = gr.Textbox(
250
+ label="Prompt",
251
+ placeholder="Enter your prompt here...",
252
+ lines=3,
253
+ value="The theory of relativity explains that"
254
+ )
255
+ with gr.Column(scale=1):
256
+ max_tokens = gr.Slider(16, 512, value=128, step=16, label="Max Tokens")
257
+ temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.1, label="Temperature")
258
+ top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K")
259
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
260
+
261
+ gen_btn = gr.Button("๐Ÿš€ Generate", variant="primary", size="lg")
262
+ output = gr.Textbox(label="Output", lines=12, interactive=False)
263
+
264
+ gen_btn.click(
265
+ fn=generate,
266
+ inputs=[prompt, max_tokens, temperature, top_k, top_p],
267
+ outputs=output,
268
+ )
269
+
270
+ gr.Markdown("### Quick Prompts")
271
+ examples = gr.Examples(
272
+ examples=[
273
+ ["The theory of relativity explains that"],
274
+ ["In Python, the most efficient way to sort a list is"],
275
+ ["The five elements of nature are"],
276
+ ["Artificial general intelligence requires"],
277
+ ["ํ•œ๊ตญ์˜ ์ˆ˜๋„๋Š”"],
278
+ ["def fibonacci(n):"],
279
+ ],
280
+ inputs=prompt,
281
+ )
282
+
283
+ with gr.Tab("๐Ÿ” Model Info"):
284
+ info_btn = gr.Button("Load Model Info", variant="primary")
285
+ info_output = gr.Markdown()
286
+ info_btn.click(fn=get_model_info, outputs=info_output)
287
+
288
+ with gr.Tab("โ„น๏ธ About"):
289
+ gr.Markdown("""
290
+ ## AETHER-Net 0.8B
291
+
292
+ **Cross-Architecture Knowledge Distillation from Qwen3.5-0.8B**
293
+
294
+ ### Method
295
+ - **Weight Transplant**: Qwen3.5-0.8B โ†’ AETHER-Net (5ร—5 Magic Square layout)
296
+ - **3-Stage MOHAWK Distillation**: KLD โ†’ Hidden Alignment โ†’ Oheng Regularization
297
+ - **Cost**: ~$0 (CPU-only, 100 steps demo)
298
+
299
+ ### Architecture
300
+ - 25 Layers: 5 attention types ร— 5 elements
301
+ - GDN, Full, Mamba2, Sliding Window, Cross Attention
302
+ - Oheng MoE: 25 experts, ์ƒ์ƒ(Generate) + ์ƒ๊ทน(Overcome)
303
+
304
+ ### Source
305
+ - Model: [FINAL-Bench/AETHER-Net-0.8B](https://huggingface.co/FINAL-Bench/AETHER-Net-0.8B) (private)
306
+ - Space: [FINAL-Bench/agi-model-gen](https://huggingface.co/spaces/FINAL-Bench/agi-model-gen)
307
+
308
+ ---
309
+ ยฉ 2026 VIDRAFT / Ginigen AI
310
+ """)
311
+
312
+
313
+ # โ”€โ”€ Preload model on startup โ”€โ”€
314
+ print("\n=== Pre-loading model ===")
315
+ load_model()
316
+ print("=== Ready ===\n")
317
+
318
+
319
+ if __name__ == "__main__":
320
+ app.launch()
config.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AETHER-Net Configuration
3
+ Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network
4
+
5
+ 5ร—5 Latin Orthogonal Magic Square Layout + Oheng(ไบ”่กŒ) MoE Routing
6
+ """
7
+ from dataclasses import dataclass, field
8
+ from typing import List, Tuple
9
+
10
+ # โ”€โ”€ 5ร—5 Latin Orthogonal Magic Square โ”€โ”€
11
+ # Each row (element group) and each column (phase) contains
12
+ # exactly one of each attention type โ†’ zero carry-over bias.
13
+ MAGIC_SQUARE = [
14
+ # Phase1 Phase2 Phase3 Phase4 Phase5
15
+ ["gdn", "full", "mamba2", "slide", "cross"], # ๆœจ Wood
16
+ ["slide", "gdn", "full", "cross", "mamba2"], # ็ซ Fire
17
+ ["full", "cross", "slide", "mamba2", "gdn"], # ๅœŸ Earth
18
+ ["mamba2", "slide", "cross", "gdn", "full"], # ้‡‘ Metal
19
+ ["cross", "mamba2", "gdn", "full", "slide"], # ๆฐด Water
20
+ ]
21
+
22
+ # Flatten to 25-layer sequence (row-major)
23
+ LAYER_TYPES = [t for row in MAGIC_SQUARE for t in row]
24
+
25
+ # โ”€โ”€ Oheng (ไบ”่กŒ) Element System โ”€โ”€
26
+ ELEMENTS = ["wood", "fire", "earth", "metal", "water"]
27
+
28
+ # ์ƒ์ƒ (Generate): ๆœจโ†’็ซโ†’ๅœŸโ†’้‡‘โ†’ๆฐดโ†’ๆœจ
29
+ GENERATE = {"wood": "fire", "fire": "earth", "earth": "metal", "metal": "water", "water": "wood"}
30
+ GENERATE_REVERSE = {v: k for k, v in GENERATE.items()}
31
+
32
+ # ์ƒ๊ทน (Overcome): ๆœจโŠฃๅœŸ, ๅœŸโŠฃๆฐด, ๆฐดโŠฃ็ซ, ็ซโŠฃ้‡‘, ้‡‘โŠฃๆœจ
33
+ OVERCOME = {"wood": "earth", "earth": "water", "water": "fire", "fire": "metal", "metal": "wood"}
34
+ OVERCOME_REVERSE = {v: k for k, v in OVERCOME.items()}
35
+
36
+ # Element โ†’ Layer indices (0-based)
37
+ ELEMENT_LAYERS = {
38
+ "wood": [0, 1, 2, 3, 4],
39
+ "fire": [5, 6, 7, 8, 9],
40
+ "earth": [10, 11, 12, 13, 14],
41
+ "metal": [15, 16, 17, 18, 19],
42
+ "water": [20, 21, 22, 23, 24],
43
+ }
44
+
45
+ # Element โ†’ Expert indices (0-based, 5 experts per element)
46
+ ELEMENT_EXPERTS = {
47
+ "wood": [0, 1, 2, 3, 4],
48
+ "fire": [5, 6, 7, 8, 9],
49
+ "earth": [10, 11, 12, 13, 14],
50
+ "metal": [15, 16, 17, 18, 19],
51
+ "water": [20, 21, 22, 23, 24],
52
+ }
53
+
54
+ # Layer index โ†’ element name
55
+ LAYER_TO_ELEMENT = {}
56
+ for elem, indices in ELEMENT_LAYERS.items():
57
+ for idx in indices:
58
+ LAYER_TO_ELEMENT[idx] = elem
59
+
60
+
61
@dataclass
class AetherNetConfig:
    """Configuration for AETHER-Net model.

    Defaults describe a large (~several-B parameter) variant; the 0.8B demo
    overrides most dimensions at construction time (see app.py).
    """

    # ── Model dimensions ──
    hidden_size: int = 4096
    intermediate_size: int = 11008  # FFN intermediate (SwiGLU)
    num_layers: int = 25
    num_attention_heads: int = 32
    num_kv_heads: int = 8  # GQA for Full Attention layers
    head_dim: int = 128  # hidden_size // num_attention_heads
    vocab_size: int = 151936  # Qwen tokenizer
    max_position_embeddings: int = 262144
    rope_theta: float = 10000000.0

    # ── Layer schedule (from magic square) ──
    # NOTE(review): default_factory returns the shared module-level list, so
    # mutating one instance's layer_types mutates all defaults — confirm this
    # is intended (use `lambda: list(LAYER_TYPES)` for per-instance copies).
    layer_types: List[str] = field(default_factory=lambda: LAYER_TYPES)

    # ── MoE Configuration ──
    num_experts: int = 25
    num_experts_per_group: int = 5
    num_element_groups: int = 5
    top_k: int = 2
    num_shared_experts: int = 1
    expert_intermediate_size: int = 2752  # intermediate_size // 4 (per expert)
    moe_jitter_eps: float = 0.01

    # ── Oheng (五行) routing ──
    use_generate_boost: bool = True
    use_overcome_gate: bool = True
    generate_alpha_init: float = 0.1  # learnable soft scalar
    overcome_gate_hidden: int = 256  # critic head hidden dim

    # ── Attention-specific ──
    sliding_window_size: int = 4096
    gdn_state_size: int = 128  # Gated DeltaNet state dimension
    mamba2_state_size: int = 128
    mamba2_conv_size: int = 4
    mamba2_expand: int = 2

    # ── Training / Inference ──
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False
    use_cache: bool = True
    torch_dtype: str = "bfloat16"

    # ── Donor transplant info (metadata) ──
    # NOTE(review): app.py uses "Qwen/Qwen3.5-0.8B" as the tokenizer donor
    # while this default names the 27B — presumably intentional metadata
    # about weight transplant provenance; confirm.
    primary_donor: str = "Qwen/Qwen3.5-27B"
    secondary_donor: str = "meta-llama/Llama-3.1-8B"

    def get_layer_type(self, layer_idx: int) -> str:
        """Return the attention type ("gdn"/"full"/...) of layer *layer_idx*."""
        return self.layer_types[layer_idx]

    def get_layer_element(self, layer_idx: int) -> str:
        """Return the Oheng element name owning layer *layer_idx*."""
        return LAYER_TO_ELEMENT[layer_idx]

    def get_element_expert_range(self, element: str) -> Tuple[int, int]:
        """Return the half-open [start, end) expert-index range for *element*.

        Assumes the element's expert indices are contiguous (they are, per
        ELEMENT_EXPERTS above).
        """
        indices = ELEMENT_EXPERTS[element]
        return (indices[0], indices[-1] + 1)

    def summary(self) -> str:
        """Return a human-readable multi-line architecture summary."""
        # Count how many layers use each attention type.
        type_counts = {}
        for t in self.layer_types:
            type_counts[t] = type_counts.get(t, 0) + 1
        # Back-of-envelope parameter estimate (billions); not exact.
        total_params_b = (
            self.num_experts * self.expert_intermediate_size * self.hidden_size * 3 * 2  # experts
            + self.num_layers * self.hidden_size * self.hidden_size * 4  # attention projections
            + self.vocab_size * self.hidden_size * 2  # embeddings
        ) / 1e9
        # Active fraction: routed top-k plus shared experts out of a group.
        active_params_b = total_params_b * (self.top_k + self.num_shared_experts) / self.num_experts_per_group
        lines = [
            "═" * 60,
            " AETHER-Net Architecture Summary",
            "═" * 60,
            f" Layers: {self.num_layers} (5×5 magic square)",
            f" Hidden dim: {self.hidden_size}",
            f" Attention mix: {type_counts}",
            f" MoE: {self.num_experts} experts / {self.num_element_groups} groups / top-{self.top_k}",
            f" Est. total: ~{total_params_b:.1f}B params",
            f" Est. active: ~{active_params_b:.1f}B params",
            f" Context: {self.max_position_embeddings:,} tokens",
            f" Oheng generate: {self.use_generate_boost} (α={self.generate_alpha_init})",
            f" Oheng overcome: {self.use_overcome_gate}",
            f" Primary donor: {self.primary_donor}",
            f" Secondary donor:{self.secondary_donor}",
            "═" * 60,
        ]
        return "\n".join(lines)
layers.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AETHER-Net Attention Layers
3
+ 5 types: GDN, Full, Mamba2, Sliding Window, Cross Attention
4
+
5
+ Each layer follows the same interface:
6
+ forward(hidden_states, attention_mask=None, position_ids=None, **kwargs) -> hidden_states
7
+ """
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Optional, Tuple
13
+
14
+
15
class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction, no bias).

    Statistics are computed in float32 for numerical stability and the
    result is cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable per-channel scale
        self.eps = eps  # added to the variance before rsqrt to avoid div-by-zero

    def forward(self, x):
        # Fix: the original rebound `x` to the (float32-promoted) normalized
        # product before calling `.to(x.dtype)`, making the cast a no-op —
        # half-precision inputs were silently returned as float32.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
25
+
26
+
27
def rotate_half(x):
    """Map the (a, b) halves of the last dim to (-b, a) — the RoPE rotation helper."""
    lo, hi = torch.chunk(x, 2, dim=-1)
    return torch.cat([-hi, lo], dim=-1)
30
+
31
+
32
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    Both tensors receive the same rotation: t → t·cos + rotate_half(t)·sin,
    with cos/sin broadcast over the head dimension.
    """
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
36
+
37
+
38
class RotaryEmbedding(nn.Module):
    """Rotary position embedding table.

    Precomputes inverse frequencies once; forward() emits [1, L, dim]
    cos/sin tables for the requested positions. The first argument of
    forward() is accepted only for call-site convention — the output
    depends solely on `position_ids`.
    """

    def __init__(self, dim: int, max_seq_len: int = 262144, theta: float = 10000000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        inv_freq = 1.0 / theta ** exponents
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, position_ids):
        # A 2-D [B, L] position tensor is assumed identical across the
        # batch, so only its first row is used.
        if position_ids.dim() == 2:
            position_ids = position_ids[0]
        angles = torch.outer(position_ids.float(), self.inv_freq.to(position_ids.device))
        table = torch.cat([angles, angles], dim=-1)
        return table.cos().unsqueeze(0), table.sin().unsqueeze(0)
51
+
52
+
53
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
54
+ # 1. FULL ATTENTION (Softmax, GQA, RoPE) โ€” O(nยฒ)
55
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
56
class FullAttention(nn.Module):
    """Standard grouped-query attention with RoPE and output gating.

    Kept for 5 layers — provides precise token-to-token reasoning.
    NOTE(review): the original docstring claimed these layers maintain a
    KV cache, but this forward recomputes full attention over the whole
    sequence each call — confirm against the training/serving code.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        # How many query heads share each KV head (GQA group size).
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Output gate (Qwen3.5 style gated attention)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Compute causal GQA over the full sequence.

        hidden_states: [B, L, hidden]; attention_mask (if given) is added
        to the attention scores, so it must be in additive (-inf) form.
        """
        B, L, _ = hidden_states.shape

        # Project and reshape to [B, heads, L, head_dim].
        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE — cos/sin broadcast over the head dimension.
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # GQA: expand KV heads to match the query head count.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # Scaled dot-product attention — materializes the full [L, L] score
        # matrix (O(L²) memory; no flash/SDPA kernel used here).
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask (rebuilt every call; upper triangle → -inf).
        causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
        attn = attn + causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attn = attn + attention_mask

        # Softmax in float32 for stability, then back to the working dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        # Output gating: sigmoid gate computed from the layer input.
        gate = torch.sigmoid(self.gate(hidden_states))
        out = out * gate

        return self.o_proj(out)
114
+
115
+
116
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
117
+ # 2. GATED DELTANET (GDN) โ€” O(n) linear time
118
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
119
class GatedDeltaNet(nn.Module):
    """Gated DeltaNet: Mamba-style gating + DeltaNet fast-weight update.

    Core linear attention mechanism — 10 layers (40% of model).

    Documented recurrence: M_t = α_t · M_{t-1} · (I - k_t q_tᵀ) + k_t v_tᵀ
    with SiLU output gating for gradient-flow stability.

    Weight transplant: Q,K,V projections map directly from Qwen3.5 GDN layers.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        # NOTE(review): gdn_state_size is stored but never used below — the
        # recurrent state is [head_dim, head_dim] per head; confirm.
        self.state_size = config.gdn_state_size

        # Input projections (transplantable from Qwen3.5 GDN)
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Decay gate (α): controls memory decay speed (one scalar per head)
        self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)

        # Update gate (β): controls state update strength
        self.beta_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)

        # Output gate (SiLU activation for gradient stability)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Short depthwise causal convolution for local context
        # (replaces positional encoding; padding=3 then trimmed to L).
        self.conv1d = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=4, padding=3, groups=config.hidden_size, bias=True
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Sequential O(L) scan over the sequence; returns [B, L, hidden]."""
        B, L, D = hidden_states.shape

        # Local context mixing via causal conv1d (right side trimmed to keep causality)
        conv_out = self.conv1d(hidden_states.transpose(1, 2))[..., :L].transpose(1, 2)

        # Q and K see the conv-mixed input; V sees the raw input.
        q = self.q_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim)

        # L2 normalize Q, K (replaces softmax normalization)
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        # Per-head decay and update gates in (0, 1).
        alpha = torch.sigmoid(self.decay_proj(hidden_states)).unsqueeze(-1)  # [B, L, H, 1]
        beta = torch.sigmoid(self.beta_proj(hidden_states)).unsqueeze(-1)

        # Recurrent scan with delta rule.
        # Python-level loop over L — correct but slow; a fused kernel would
        # be the production path.
        outputs = []
        state = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim,
                            device=hidden_states.device, dtype=hidden_states.dtype)

        for t in range(L):
            q_t = q[:, t]  # [B, H, D]
            k_t = k[:, t]
            v_t = v[:, t]
            a_t = alpha[:, t]  # [B, H, 1]
            b_t = beta[:, t]

            # Erase/write terms as outer products [B, H, D, D].
            erase = torch.einsum('bhd,bhe->bhde', k_t * b_t, q_t)
            write = torch.einsum('bhd,bhe->bhde', k_t * b_t, v_t)
            # NOTE(review): `state * erase` is ELEMENTWISE, which differs
            # from the documented matrix form M·(I - β k qᵀ) (that would be
            # a matmul). Presumably this is what the weights were trained
            # with — do not "fix" without retraining; confirm against the
            # training code.
            state = a_t.unsqueeze(-1) * (state - state * erase) + write

            # Read: o_t = qᵀ @ state
            o_t = torch.einsum('bhd,bhde->bhe', q_t, state)
            outputs.append(o_t)

        out = torch.stack(outputs, dim=1)  # [B, L, H, D]
        out = out.reshape(B, L, -1)

        # Output gating with SiLU
        gate = F.silu(self.gate(hidden_states))
        out = out * gate

        return self.o_proj(out)
209
+
210
+
211
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
212
+ # 3. MAMBA2 โ€” O(n) with SSM state-space duality
213
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
214
class Mamba2Block(nn.Module):
    """Mamba-2 block with Structured State Space Duality.

    5 layers — provides state compression for memory efficiency.

    Weight transplant: via MOHAWK SSD duality from Llama-3.1 Q,K,V → C,B,X.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        expand = config.mamba2_expand
        self.inner_size = config.hidden_size * expand
        self.state_size = config.mamba2_state_size
        self.conv_size = config.mamba2_conv_size
        self.num_heads = config.num_attention_heads

        # Input projection: x → (z, x_ssm) split
        self.in_proj = nn.Linear(config.hidden_size, self.inner_size * 2, bias=False)

        # Depthwise causal conv1d (left-padded, trimmed to L in forward)
        self.conv1d = nn.Conv1d(
            self.inner_size, self.inner_size,
            kernel_size=self.conv_size, padding=self.conv_size - 1,
            groups=self.inner_size, bias=True
        )

        # SSM parameters: per-head timestep, log-decay A, and skip scale D
        self.dt_proj = nn.Linear(self.inner_size, self.num_heads, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, self.num_heads + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.num_heads))

        # B, C projections (state-space input/output maps)
        # NOTE(review): head_dim_ssm is computed but never used — dead code.
        head_dim_ssm = self.inner_size // self.num_heads
        self.B_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)
        self.C_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)

        # Output
        self.out_proj = nn.Linear(self.inner_size, config.hidden_size, bias=False)
        self.norm = RMSNorm(self.inner_size)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Selective-scan forward; returns [B, L, hidden]."""
        B, L, _ = hidden_states.shape

        # Input split: z gates the output, x feeds the SSM path.
        zx = self.in_proj(hidden_states)
        z, x = zx.chunk(2, dim=-1)

        # Causal conv (trim right padding to keep causality), then SiLU
        x = self.conv1d(x.transpose(1, 2))[..., :L].transpose(1, 2)
        x = F.silu(x)

        # SSM parameters
        A = -torch.exp(self.A_log)  # [H], negative → decaying dynamics
        dt = F.softplus(self.dt_proj(x))  # [B, L, H], positive timestep

        B_state = self.B_proj(x).view(B, L, self.num_heads, self.state_size)
        C_state = self.C_proj(x).view(B, L, self.num_heads, self.state_size)

        # Discretize: A_bar = exp(dt * A), B_bar = dt * B
        dt_A = dt.unsqueeze(-1) * A.view(1, 1, -1, 1)  # [B, L, H, 1]
        A_bar = torch.exp(dt_A)
        B_bar = dt.unsqueeze(-1) * B_state  # [B, L, H, N]

        # Selective scan (sequential for correctness; replace with FLA parallel kernel)
        head_dim = self.inner_size // self.num_heads
        x_heads = x.view(B, L, self.num_heads, head_dim)

        outputs = []
        state = torch.zeros(B, self.num_heads, self.state_size, device=x.device, dtype=x.dtype)

        for t in range(L):
            # NOTE(review): only the FIRST channel of each head
            # (x_heads[..., :1]) drives the state update, and the readout
            # below is a single scalar per head — this collapses the head
            # dimension to rank 1. Presumably intentional for this demo
            # architecture; confirm against the training code.
            state = A_bar[:, t] * state + B_bar[:, t] * x_heads[:, t, :, :1].expand_as(B_bar[:, t])
            y_t = torch.sum(state * C_state[:, t], dim=-1)  # [B, H]
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # [B, L, H]

        # Skip connection with D (per-head scale on the mean of the head input)
        y = y + self.D.view(1, 1, -1) * x.view(B, L, self.num_heads, head_dim).mean(-1)

        # Broadcast the per-head scalar back over head_dim, then gate with z
        y = y.unsqueeze(-1).expand(-1, -1, -1, head_dim).reshape(B, L, self.inner_size)
        y = self.norm(y)
        y = y * F.silu(z)

        return self.out_proj(y)
300
+
301
+
302
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
303
+ # 4. SLIDING WINDOW ATTENTION โ€” O(n * w)
304
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
305
class SlidingWindowAttention(nn.Module):
    """Local attention restricted to a fixed-size causal window, O(n * w).

    Each query attends to itself plus the ``window_size - 1`` most recent
    positions.  5 layers of these complement GDN's global view with
    fine-grained local context.  The attended context is modulated by a
    learned per-channel sigmoid gate before the output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.sliding_window_size
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        # Submodule creation order kept stable so random init is reproducible.
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        bsz, seq_len, _ = hidden_states.shape

        queries = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos.unsqueeze(1), sin.unsqueeze(1))

        # Grouped-query attention: replicate KV heads up to the query head count.
        if self.num_kv_groups > 1:
            keys = keys.repeat_interleave(self.num_kv_groups, dim=1)
            values = values.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Position j is visible from query i iff 0 <= i - j < window_size,
        # i.e. causal AND inside the sliding window (self position included).
        pos = torch.arange(seq_len, device=scores.device)
        rel = pos.view(-1, 1) - pos.view(1, -1)  # rel[i, j] = i - j
        banned = (rel < 0) | (rel >= self.window_size)
        scores = scores.masked_fill(banned.view(1, 1, seq_len, seq_len), float('-inf'))

        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        context = torch.matmul(probs, values)
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, -1)

        # Learned per-channel sigmoid gate on the attended context.
        context = context * torch.sigmoid(self.gate(hidden_states))

        return self.o_proj(context)
357
+
358
+
359
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
360
+ # 5. CROSS ATTENTION โ€” for multimodal / tool bridging
361
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
362
class CrossAttention(nn.Module):
    """Cross attention bridging AETHER-Net to external context streams
    (PROMETHEUS world model / HEPHAESTUS embodiment connection).

    When ``encoder_hidden_states`` is supplied, a learned modality gate
    blends cross-attention over the external context with causal
    self-attention; without it, the layer degrades to pure causal
    self-attention.  A final per-channel sigmoid gate modulates the output.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim

        # Self-attention path (default when no external context).
        # Creation order kept stable for reproducible initialization.
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Cross-attention path (used only when external context arrives).
        self.cross_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.cross_v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Modality gate: per-token scalar blend between cross and self paths.
        self.modality_gate = nn.Linear(config.hidden_size, 1, bias=True)
        nn.init.constant_(self.modality_gate.bias, -2.0)  # start mostly on the self path

        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                encoder_hidden_states=None, **kwargs):
        bsz, seq_len, _ = hidden_states.shape

        queries = self.q_proj(hidden_states).view(
            bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if encoder_hidden_states is not None:
            # Cross path: attend over the external context (no causal mask).
            keys_x = self.cross_k_proj(encoder_hidden_states).view(
                bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
            vals_x = self.cross_v_proj(encoder_hidden_states).view(
                bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

            scores_x = torch.matmul(queries, keys_x.transpose(-2, -1)) / math.sqrt(self.head_dim)
            probs_x = F.softmax(scores_x, dim=-1, dtype=torch.float32).to(queries.dtype)
            cross_out = torch.matmul(probs_x, vals_x).transpose(1, 2).contiguous().view(bsz, seq_len, -1)

            # Self path always runs alongside the cross path.
            self_out = self._causal_self_attention(hidden_states, queries, bsz, seq_len)

            # Per-token blend; negative bias init keeps this near the self path early on.
            mix = torch.sigmoid(self.modality_gate(hidden_states))
            out = mix * cross_out + (1 - mix) * self_out
        else:
            # Pure self-attention fallback.
            out = self._causal_self_attention(hidden_states, queries, bsz, seq_len)

        # Channel-wise sigmoid output gate.
        out = out * torch.sigmoid(self.gate(hidden_states))

        return self.o_proj(out)

    def _causal_self_attention(self, hidden_states, queries, bsz, seq_len):
        """Standard causal self-attention using precomputed query heads."""
        keys = self.k_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        vals = self.v_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=scores.device), diagonal=1)
        scores = scores + causal.unsqueeze(0).unsqueeze(0)
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        return torch.matmul(probs, vals).transpose(1, 2).contiguous().view(bsz, seq_len, -1)
432
+
433
+
434
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
435
+ # Factory
436
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
437
# Registry mapping layer-type keys (from the magic-square layout) to mixer classes.
ATTENTION_CLASSES = {
    "gdn": GatedDeltaNet,
    "full": FullAttention,
    "mamba2": Mamba2Block,
    "slide": SlidingWindowAttention,
    "cross": CrossAttention,
}


def build_attention(layer_type: str, config):
    """Instantiate the attention/mixer module registered under ``layer_type``.

    Raises:
        ValueError: if ``layer_type`` is not a registered key.
    """
    if layer_type not in ATTENTION_CLASSES:
        raise ValueError(f"Unknown attention type: {layer_type}. Choose from {list(ATTENTION_CLASSES.keys())}")
    return ATTENTION_CLASSES[layer_type](config)
model.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AETHER-Net: Main Model
3
+ Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network
4
+
5
+ 25-layer hybrid LLM with 5ร—5 Latin orthogonal magic square layout
6
+ and Oheng (ไบ”่กŒ) MoE routing.
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ from config import AetherNetConfig, ELEMENTS, LAYER_TO_ELEMENT, ELEMENT_LAYERS
13
+ from layers import RMSNorm, build_attention
14
+ from oheng_moe import OhengMoE
15
+
16
+
17
class AetherNetBlock(nn.Module):
    """Single AETHER-Net transformer block.

    Pre-norm layout:
        x -> RMSNorm -> Attention -> +residual -> RMSNorm -> OhengMoE -> +residual
    """

    def __init__(self, config: AetherNetConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Attention flavor and element group come from the magic-square layout.
        self.layer_type = config.get_layer_type(layer_idx)
        self.element = config.get_layer_element(layer_idx)

        # Pre-norms (creation order kept stable for reproducible init).
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # Mixer selected by layer type ("gdn" / "full" / "mamba2" / "slide" / "cross").
        self.attention = build_attention(self.layer_type, config)

        # Oheng-routed MoE feed-forward.
        self.moe = OhengMoE(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        element_states: Optional[Dict[str, torch.Tensor]] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention then MoE, each wrapped in a pre-norm residual."""
        attn_out = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = hidden_states + attn_out

        moe_out = self.moe(
            self.post_attention_layernorm(hidden_states),
            element_states=element_states,
        )
        return hidden_states + moe_out
66
+
67
+
68
class AetherNetModel(nn.Module):
    """AETHER-Net Language Model.

    Architecture:
        - Embedding -> 25 x AetherNetBlock -> RMSNorm -> LM Head
        - Blocks arranged in a 5x5 Latin orthogonal magic square
        - Oheng MoE with generate (sangsaeng) and overcome (sangkeuk) connections
        - Element states flow between element groups for structural self-verification
    """

    def __init__(self, config: AetherNetConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # config.num_layers transformer blocks (25 in the reference layout)
        self.layers = nn.ModuleList([
            AetherNetBlock(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        # Final norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # LM Head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: lm_head and embedding share one tensor object
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        # Initialize all Linear / Embedding weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Plain normal init with the configured std; biases zeroed.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the full forward pass.

        Args:
            input_ids: [B, L] token ids.  NOTE(review): declared Optional but
                used unconditionally below, so passing None raises — confirm intent.
            attention_mask: optional mask forwarded to every block.
            position_ids: optional [B, L] positions; defaults to 0..L-1.
            labels: optional [B, L] targets; enables shifted cross-entropy loss.
            encoder_hidden_states: optional external context for cross layers.

        Returns:
            dict with "loss" (None when no labels), "logits" [B, L, vocab_size],
            and the per-element "element_states" accumulated during the pass.
        """
        B, L = input_ids.shape

        # Position IDs
        if position_ids is None:
            position_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)

        # Embed
        hidden_states = self.embed_tokens(input_ids)

        # -- Element state tracking for Oheng connections --
        # Each element group accumulates its output for generate/overcome routing
        element_states: Dict[str, torch.Tensor] = {}
        element_layer_counts: Dict[str, int] = {e: 0 for e in ELEMENTS}

        # -- Forward through the layer stack --
        for i, layer in enumerate(self.layers):
            element = LAYER_TO_ELEMENT[i]

            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                element_states=element_states,
                encoder_hidden_states=encoder_hidden_states,
            )

            # Update element state.  detach() stops gradients from flowing
            # through the cross-element Oheng connections.
            element_layer_counts[element] += 1
            count = element_layer_counts[element]
            if element in element_states:
                # Cumulative (running) mean of this element's layer outputs.
                # NOTE(review): despite earlier naming, this is NOT an
                # exponential moving average — each layer gets weight 1/count.
                element_states[element] = (
                    element_states[element] * (count - 1) / count
                    + hidden_states.detach() / count
                )
            else:
                element_states[element] = hidden_states.detach()

        # Final norm
        hidden_states = self.norm(hidden_states)

        # LM Head
        logits = self.lm_head(hidden_states)

        # Next-token loss: shift logits/labels by one position
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,  # conventional padding/ignore label
            )

        return {
            "loss": loss,
            "logits": logits,
            "element_states": element_states,
        }

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        NOTE(review): with tie_word_embeddings the lm_head shares the embedding
        tensor, so "embedding" and "lm_head" report the same weights and
        "total" double-counts them — confirm whether that is intended.
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embed_tokens.parameters()),
            "lm_head": sum(p.numel() for p in self.lm_head.parameters()),
            "norm": sum(p.numel() for p in self.norm.parameters()),
        }

        attn_total = 0
        moe_total = 0
        generate_total = 0
        overcome_total = 0

        for layer in self.layers:
            # Attention side: the mixer itself plus both pre-norms.
            attn_total += sum(p.numel() for p in layer.attention.parameters())
            attn_total += sum(p.numel() for p in layer.input_layernorm.parameters())
            attn_total += sum(p.numel() for p in layer.post_attention_layernorm.parameters())

            # MoE side: routed experts, shared expert, and router.
            moe_total += sum(p.numel() for p in layer.moe.experts.parameters())
            moe_total += sum(p.numel() for p in layer.moe.shared_expert.parameters())
            moe_total += sum(p.numel() for p in layer.moe.router.parameters())

            # Optional Oheng connection modules.
            if layer.moe.generate_boost is not None:
                generate_total += sum(p.numel() for p in layer.moe.generate_boost.parameters())
            if layer.moe.overcome_gate is not None:
                overcome_total += sum(p.numel() for p in layer.moe.overcome_gate.parameters())

        counts["attention_layers"] = attn_total
        counts["moe_experts"] = moe_total
        counts["oheng_generate"] = generate_total
        counts["oheng_overcome"] = overcome_total
        counts["total"] = sum(counts.values())

        return counts

    def get_layer_map(self) -> List[Dict]:
        """Return human-readable layer map for diagnostics."""
        result = []
        for i, layer in enumerate(self.layers):
            result.append({
                "layer": i,
                "type": layer.layer_type,          # mixer key, e.g. "gdn" / "slide"
                "element": layer.element,          # oheng element name
                "element_idx": ELEMENTS.index(layer.element),
                "phase": i % 5,                    # column within the 5x5 layout
                "attn_class": layer.attention.__class__.__name__,
            })
        return result
oheng_moe.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Oheng (ไบ”่กŒ) Mixture-of-Experts Router
3
+
4
+ Core innovation: 25 experts organized in 5 element groups with:
5
+ - ์ƒ์ƒ (Generate) cycle: Woodโ†’Fireโ†’Earthโ†’Metalโ†’Waterโ†’Wood
6
+ Previous element's output provides residual boost to next element.
7
+ - ์ƒ๊ทน (Overcome) cycle: WoodโŠฃEarth, EarthโŠฃWater, WaterโŠฃFire, FireโŠฃMetal, MetalโŠฃWood
8
+ Opposing element provides critic gating to suppress hallucinations.
9
+ - Loss-Free Balancing via dynamic expert bias (DeepSeek-style)
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from typing import Dict, Optional, Tuple
15
+
16
+ from config import (
17
+ ELEMENTS, GENERATE, GENERATE_REVERSE, OVERCOME, OVERCOME_REVERSE,
18
+ ELEMENT_EXPERTS, LAYER_TO_ELEMENT,
19
+ )
20
+
21
+
22
class Expert(nn.Module):
    """Single routed SwiGLU feed-forward expert (split from the donor MLP).

    Computes ``down(silu(gate(x)) * up(x))``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Creation order (gate, up, down) kept stable for reproducible init.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
33
+
34
+
35
class SharedExpert(nn.Module):
    """Always-active SwiGLU expert applied to every token alongside the routed ones."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Same SwiGLU shape as a routed Expert; kept as a separate class so it
        # can be sized or specialized independently later.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        activation = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(activation)
46
+
47
+
48
class GenerateBoost(nn.Module):
    """Generate (sangsaeng) residual lift between element groups.

    Wood -> Fire -> Earth -> Metal -> Water -> Wood: the generating
    element's pooled state is projected and added to the current
    element's output, scaled by a learnable per-element sigmoid(alpha).
    The projection is zero-initialized, so the boost starts at exactly 0.
    """

    def __init__(self, hidden_size: int, num_elements: int = 5):
        super().__init__()
        # Per-element mixing logit (passed through sigmoid at use time).
        self.alpha = nn.Parameter(torch.full((num_elements,), 0.1))
        # Source -> target mapping; zeroed so training starts boost-free.
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.zeros_(self.proj.weight)

    def forward(self, hidden: torch.Tensor, source_state: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Add the gated, projected source state to ``hidden``.

        Args:
            hidden: current hidden states [B, L, D].
            source_state: generating element's pooled output [B, L, D], or None.
            element_idx: current element index (0=wood, 1=fire, ...).
        Returns:
            ``hidden`` unchanged when ``source_state`` is None, otherwise
            ``hidden + sigmoid(alpha[element_idx]) * proj(source_state)``.
        """
        if source_state is None:
            return hidden

        mix = torch.sigmoid(self.alpha[element_idx])
        return hidden + mix * self.proj(source_state)
80
+
81
+
82
class OvercomeGate(nn.Module):
    """Overcome (sangkeuk) critic gating between opposing element groups.

    Wood-|Earth, Earth-|Water, Water-|Fire, Fire-|Metal, Metal-|Wood

    A small per-element critic MLP reads the opposing element group's
    output and emits a sigmoid gate that damps suspect activations —
    the structural self-verification path intended to reduce hallucination.

    NOTE(review): the final critic layer is zero-initialized, so every
    gate starts at sigmoid(0) = 0.5 (uniform half-damping), not at an
    identity pass-through.
    """

    def __init__(self, hidden_size: int, critic_hidden: int = 256, num_elements: int = 5):
        super().__init__()
        # One two-layer SiLU critic per element.
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, critic_hidden, bias=False),
                nn.SiLU(),
                nn.Linear(critic_hidden, hidden_size, bias=False),
            )
            for _ in range(num_elements)
        ])
        # Zero the last projection so every gate begins at exactly 0.5.
        for critic in self.critics:
            nn.init.zeros_(critic[-1].weight)

    def forward(self, hidden: torch.Tensor, critic_source: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Damp ``hidden`` by the critic gate derived from ``critic_source``.

        Args:
            hidden: current hidden states [B, L, D].
            critic_source: opposing element's output [B, L, D], or None.
            element_idx: current element index.
        Returns:
            ``hidden`` unchanged when ``critic_source`` is None, otherwise
            ``hidden * sigmoid(critic(critic_source))``.
        """
        if critic_source is None:
            return hidden

        damping = torch.sigmoid(self.critics[element_idx](critic_source))
        return hidden * damping
122
+
123
+
124
class OhengRouter(nn.Module):
    """Top-K token router with DeepSeek-style Loss-Free Balancing.

    Expert selection uses bias-adjusted scores so chronically underloaded
    experts become more likely to be picked, while the mixing weights are
    computed from the raw (unbiased) scores so gradients stay clean.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.jitter_eps = config.moe_jitter_eps

        # Linear scoring head: hidden -> per-expert logit.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        # Balancing bias: updated out-of-band (no gradient), saved with the model.
        self.register_buffer(
            "expert_bias",
            torch.zeros(config.num_experts),
            persistent=True
        )

        # EMA of per-expert load; transient bookkeeping, not checkpointed.
        self.register_buffer(
            "expert_load_ema",
            torch.ones(config.num_experts) / config.num_experts,
            persistent=False
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route flattened tokens to experts.

        Args:
            hidden_states: [B*L, D] token states.
        Returns:
            expert_indices: [B*L, top_k] selected expert ids.
            expert_weights: [B*L, top_k] softmax weights from unbiased scores.
            router_logits: [B*L, num_experts] raw logits for auxiliary logging.
        """
        router_logits = self.gate(hidden_states)

        # Multiplicative jitter encourages exploration during training.
        if self.training and self.jitter_eps > 0:
            lo, hi = 1.0 - self.jitter_eps, 1.0 + self.jitter_eps
            router_logits = router_logits * torch.empty_like(router_logits).uniform_(lo, hi)

        # Select on biased scores (Loss-Free Balancing), weight on raw scores.
        _, expert_indices = torch.topk(
            router_logits + self.expert_bias.unsqueeze(0), self.top_k, dim=-1)
        selected = torch.gather(router_logits, 1, expert_indices)
        expert_weights = F.softmax(selected, dim=-1, dtype=torch.float32).to(hidden_states.dtype)

        # Bias adjustment happens outside the autograd graph, once per batch.
        if self.training:
            self._update_bias(expert_indices)

        return expert_indices, expert_weights, router_logits

    @torch.no_grad()
    def _update_bias(self, indices: torch.Tensor, momentum: float = 0.99, step: float = 0.001):
        """Nudge expert_bias toward uniform load based on this batch's routing."""
        counts = torch.bincount(indices.view(-1), minlength=self.num_experts).float()
        batch_load = counts / max(counts.sum().item(), 1.0)

        self.expert_load_ema.mul_(momentum).add_(batch_load, alpha=1 - momentum)

        # Raise bias for underloaded experts, lower it for overloaded ones.
        uniform = 1.0 / self.num_experts
        self.expert_bias.add_((uniform - self.expert_load_ema) * step)
199
+
200
+
201
class OhengMoE(nn.Module):
    """Complete Oheng MoE layer with Generate, Overcome, and expert computation.

    Architecture per layer:
        1. Router selects top-K experts
        2. Selected experts process tokens
        3. Shared expert processes all tokens
        4. Generate boost from previous element group
        5. Overcome gate from opposing element group
        6. Sum all outputs
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Element assignment comes from the static magic-square layout.
        self.element = LAYER_TO_ELEMENT[layer_idx]
        self.element_idx = ELEMENTS.index(self.element)
        self.hidden_size = config.hidden_size
        self.top_k = config.top_k

        # Routed experts (config.num_experts of them; 25 in the reference layout)
        self.experts = nn.ModuleList([
            Expert(config.hidden_size, config.expert_intermediate_size)
            for _ in range(config.num_experts)
        ])

        # Shared expert (always active)
        self.shared_expert = SharedExpert(config.hidden_size, config.expert_intermediate_size)

        # Router
        self.router = OhengRouter(config)

        # Generate (sangsaeng) boost: residual lift from the preceding element
        if config.use_generate_boost:
            self.generate_boost = GenerateBoost(config.hidden_size)
        else:
            self.generate_boost = None

        # Overcome (sangkeuk) gate: critic damping from the opposing element
        if config.use_overcome_gate:
            self.overcome_gate = OvercomeGate(config.hidden_size, config.overcome_gate_hidden)
        else:
            self.overcome_gate = None

    def forward(self, hidden_states: torch.Tensor,
                element_states: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, L, D]
            element_states: dict mapping element names to their latest output
        Returns:
            output: [B, L, D]
        """
        B, L, D = hidden_states.shape
        flat = hidden_states.view(-1, D)  # [B*L, D]

        # Route
        indices, weights, _ = self.router(flat)  # [B*L, K], [B*L, K]

        # Expert computation.
        # NOTE(review): this nested loop runs top_k * num_experts mask passes;
        # correct but slow — a single pass per expert would be cheaper.
        expert_out = torch.zeros_like(flat)
        for k in range(self.top_k):
            expert_idx = indices[:, k]  # [B*L]
            expert_w = weights[:, k].unsqueeze(-1)  # [B*L, 1]

            for e_id in range(len(self.experts)):
                mask = (expert_idx == e_id)
                if mask.any():
                    token_input = flat[mask]
                    token_output = self.experts[e_id](token_input)
                    expert_out[mask] += expert_w[mask] * token_output

        # Shared expert (always active)
        shared_out = self.shared_expert(flat)

        output = (expert_out + shared_out).view(B, L, D)

        # Apply Oheng connections if element states available
        if element_states is not None:
            # Generate (sangsaeng) boost from this element's generator.
            # The source may be absent early in the pass; the boost then no-ops.
            if self.generate_boost is not None:
                gen_source_elem = GENERATE_REVERSE.get(self.element)
                gen_source = element_states.get(gen_source_elem)
                output = self.generate_boost(output, gen_source, self.element_idx)

            # Overcome (sangkeuk) gate from this element's overcomer.
            if self.overcome_gate is not None:
                overcome_source_elem = OVERCOME_REVERSE.get(self.element)
                overcome_source = element_states.get(overcome_source_elem)
                output = self.overcome_gate(output, overcome_source, self.element_idx)

        return output
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.4.0
2
+ safetensors>=0.4.0
3
+ gradio>=5.0.0
4
+ transformers>=4.45.0
5
+ huggingface-hub>=0.25.0