| """
|
| VortexModel: Main model class combining SSM, attention, science modules, and SciGate FFN.
|
| Implements two block types: SSM-only and attention+science+SciGate FFN.
|
| """
|
|
|
from typing import Optional, Tuple, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .ssm_layer import VortexSSM
from .attention_layer import VortexLocalAttention
from .scigate_ffn import SciGateFFN
from .science_modules import (
    EquationModule,
    NumericalReasoningModule,
    CitationModule,
    MolecularModule,
)
|
|
|
|
|
class VortexBlock(nn.Module):
    """
    A single Vortex layer. Two variants exist, selected by ``is_ssm_block``:

    1. SSM block: pre-norm VortexSSM with a residual connection.
    2. Attention block: pre-norm VortexLocalAttention with a residual connection,
       then each enabled science module added residually, then a pre-norm
       SciGateFFN with a residual connection.

    Both variants end with ``final_norm``.
    """

    def __init__(
        self,
        config: Dict,
        is_ssm_block: bool = True,
    ):
        """
        Initialize a Vortex block.

        Args:
            config: Model configuration. Always reads ``d_model``. SSM blocks also
                read ``d_state`` and ``d_conv``. Attention blocks read
                ``num_heads``, ``window_size``, ``ffn_expansion``, ``num_domains``
                and the optional flags ``use_flash_attention`` and
                ``enable_{equation,numerical,citation,molecular}_module``
                (each defaulting to True when absent).
            is_ssm_block: If True, build an SSM-only block; otherwise build an
                attention + science modules + SciGate FFN block.
        """
        super().__init__()
        self.config = config
        self.is_ssm_block = is_ssm_block
        self.d_model = config["d_model"]

        # NOTE: submodules are created in the same order as before so that
        # parameters() order (used by optimizer state dicts) is unchanged.
        if is_ssm_block:
            self.ssm = VortexSSM(
                d_model=config["d_model"],
                d_state=config["d_state"],
                d_conv=config["d_conv"],
            )
            self.norm = nn.LayerNorm(config["d_model"])
        else:
            self.attn = VortexLocalAttention(
                d_model=config["d_model"],
                num_heads=config["num_heads"],
                window_size=config["window_size"],
                use_flash_attention=config.get("use_flash_attention", True),
            )
            self.attn_norm = nn.LayerNorm(config["d_model"])

            # Science modules are optional; a disabled module stays None and is
            # skipped in forward().
            self.equation_module = None
            self.numerical_module = None
            self.citation_module = None
            self.molecular_module = None

            if config.get("enable_equation_module", True):
                self.equation_module = EquationModule(config["d_model"])

            if config.get("enable_numerical_module", True):
                self.numerical_module = NumericalReasoningModule(config["d_model"])

            if config.get("enable_citation_module", True):
                self.citation_module = CitationModule(config["d_model"])

            if config.get("enable_molecular_module", True):
                self.molecular_module = MolecularModule(config["d_model"])

            self.ffn = SciGateFFN(
                d_model=config["d_model"],
                expansion=config["ffn_expansion"],
                num_domains=config["num_domains"],
            )
            self.ffn_norm = nn.LayerNorm(config["d_model"])

        # Output normalization applied by both block variants.
        self.final_norm = nn.LayerNorm(config["d_model"])

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the block.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            domain_ids: Optional domain IDs, forwarded to SciGateFFN
            domain_tags: Optional domain tag masks, forwarded to SciGateFFN
            text: Optional original text, forwarded to each enabled science module
            attention_mask: Optional attention mask, forwarded to VortexLocalAttention

        Returns:
            Output tensor (batch, seq_len, d_model)

        SSM blocks ignore every argument except ``x``.
        """
        if self.is_ssm_block:
            x = self._forward_ssm(x)
        else:
            x = self._forward_attention(
                x,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                text=text,
                attention_mask=attention_mask,
            )
        return self.final_norm(x)

    def _forward_ssm(self, x: torch.Tensor) -> torch.Tensor:
        """Pre-norm SSM sub-layer with residual connection (before final_norm)."""
        return x + self.ssm(self.norm(x))

    def _forward_attention(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor],
        domain_tags: Optional[torch.Tensor],
        text: Optional[List[str]],
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention, science modules and FFN sub-layers (before final_norm)."""
        # Attention sub-layer (pre-norm + residual).
        normed = self.attn_norm(x)
        # NOTE(review): the global mask is computed on the LayerNorm output,
        # which largely equalizes per-token norms, and padding positions are not
        # excluded via attention_mask. Confirm this is the intended signal.
        global_mask = self._detect_global_tokens(normed)
        x = x + self.attn(normed, global_mask=global_mask, attention_mask=attention_mask)

        # Science modules: each enabled module's output is added residually.
        if self.equation_module is not None:
            x = x + self.equation_module(x, text=text)

        if self.numerical_module is not None:
            x = x + self.numerical_module(x, text=text)

        if self.citation_module is not None:
            # CitationModule returns a tuple; only its first element is used here.
            x_cited, _ = self.citation_module(x, text=text)
            x = x + x_cited

        if self.molecular_module is not None:
            x = x + self.molecular_module(x, text=text)

        # FFN sub-layer (pre-norm + residual).
        return x + self.ffn(self.ffn_norm(x), domain_ids=domain_ids, domain_tags=domain_tags)

    def _detect_global_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """
        Flag tokens whose L2 norm is above the 95th percentile of their sequence.

        The norm is taken over the last dim of ``x``. The percentile is taken
        over the next-to-last dim (one threshold per sequence).

        Args:
            x: Hidden states, e.g. (batch, seq_len, d_model).

        Returns:
            Boolean mask with the shape of ``x`` minus its last dim.
        """
        # torch.quantile only supports float32/float64, so upcast half/bfloat16
        # activations (and compute the norm itself at higher precision).
        if x.dtype not in (torch.float32, torch.float64):
            x = x.float()
        norms = torch.linalg.vector_norm(x, dim=-1)
        threshold = torch.quantile(norms, 0.95, dim=-1, keepdim=True)
        return norms > threshold
|
|
|
|
|
class VortexModel(nn.Module):
    """
    Main Vortex model: token embedding -> stack of VortexBlocks (SSM and
    attention variants interleaved) -> final LayerNorm -> LM head.

    All sizes come from the config dict (e.g. the 7B or 13B configs).
    """

    def __init__(
        self,
        config: Dict,
    ):
        """
        Initialize VortexModel.

        Args:
            config: Model configuration (from vortex_7b_config.py or vortex_13b_config.py).
                Reads ``vocab_size``, ``d_model``, ``num_layers`` and ``ssm_ratio``
                here; VortexBlock reads further keys.
        """
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config["vocab_size"], config["d_model"])

        self.blocks = nn.ModuleList()
        self._build_blocks()

        self.ln_f = nn.LayerNorm(config["d_model"])

        self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)

        self._initialize_weights()

    @staticmethod
    def _block_pattern(num_layers: int, ssm_ratio: float) -> List[int]:
        """
        Return the per-layer block types: 0 = SSM block, 1 = attention block.

        An ``ssm_ratio`` of 0.6 repeats [SSM, SSM, attention]. Any other ratio
        alternates [SSM, attention]. The repeating unit is truncated to
        ``num_layers``.

        NOTE: this only approximates ``ssm_ratio`` (0.6 yields ~2/3 SSM, every
        other value yields ~1/2). The layout is kept as-is so existing
        checkpoints still load.
        """
        # Tolerance compare so a ratio computed arithmetically still matches.
        unit = [0, 0, 1] if abs(ssm_ratio - 0.6) < 1e-9 else [0, 1]
        return [unit[i % len(unit)] for i in range(num_layers)]

    def _build_blocks(self):
        """Build the sequence of SSM and attention blocks into ``self.blocks``."""
        num_layers = self.config["num_layers"]
        pattern = self._block_pattern(num_layers, self.config["ssm_ratio"])

        for is_attn in pattern:
            self.blocks.append(
                VortexBlock(
                    config=self.config,
                    is_ssm_block=not is_attn,
                )
            )

        # Report the counts actually built, not int(num_layers * ssm_ratio),
        # which disagrees with the interleave pattern above.
        num_attn_blocks = sum(pattern)
        num_ssm_blocks = len(pattern) - num_attn_blocks
        print(f"Built {num_layers} layers: {num_ssm_blocks} SSM, {num_attn_blocks} Attention")

    def _initialize_weights(self):
        """
        Initialize weights.

        The embedding gets N(0, 0.02). Each block's SSM, attention and FFN
        submodules (whichever exist) run their own ``_initialize_weights``.
        ``lm_head``, ``ln_f`` and the science modules keep PyTorch's default
        initialization.
        """
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
        for block in self.blocks:
            if hasattr(block, 'ssm'):
                block.ssm._initialize_weights()
            if hasattr(block, 'attn'):
                block.attn._initialize_weights()
            if hasattr(block, 'ffn'):
                block.ffn._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through the model.

        Args:
            input_ids: Token IDs (batch, seq_len)
            domain_ids: Optional domain IDs, forwarded to every block
            domain_tags: Optional domain tag masks, forwarded to every block
            attention_mask: Optional attention mask (batch, seq_len), forwarded to every block
            text: Optional original text for science modules
            return_dict: If True (default), return a dict; otherwise return logits only

        Returns:
            If ``return_dict``: ``{"logits": (batch, seq_len, vocab_size),
            "last_hidden_state": output of the final LayerNorm}``.
            Otherwise: the logits tensor (batch, seq_len, vocab_size).
        """
        x = self.embed_tokens(input_ids)

        for block in self.blocks:
            x = block(
                x,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                text=text,
                attention_mask=attention_mask,
            )

        x = self.ln_f(x)

        logits = self.lm_head(x)

        if return_dict:
            return {"logits": logits, "last_hidden_state": x}
        return logits

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def estimate_memory_usage(
        self,
        batch_size: int,
        seq_len: int,
        use_gradient_checkpointing: bool = False,
    ) -> Dict[str, float]:
        """
        Roughly estimate training memory for a batch size and sequence length.

        Assumptions baked into the arithmetic:
        - parameters and gradients: 2 bytes each (16-bit);
        - activations: one (batch, seq_len, d_model) 16-bit tensor per layer;
          logits and intra-layer intermediates are not counted;
        - optimizer: two 16-bit states per parameter.

        ``use_gradient_checkpointing`` is accepted for API compatibility but
        currently does not affect the estimate.

        Returns:
            Dictionary with memory estimates in GB (1e9 bytes)
        """
        params = self.get_num_params()
        param_bytes = params * 2

        activations_per_layer = batch_size * seq_len * self.config["d_model"] * 2
        total_activations = activations_per_layer * self.config["num_layers"]

        gradients = param_bytes

        # Two optimizer states (e.g. Adam moments) at 2 bytes each.
        optimizer_states = params * 2 * 2

        total_memory = (param_bytes + total_activations + gradients + optimizer_states) / 1e9

        return {
            "parameters_gb": param_bytes / 1e9,
            "activations_gb": total_activations / 1e9,
            "gradients_gb": gradients / 1e9,
            "optimizer_states_gb": optimizer_states / 1e9,
            "total_gb": total_memory,
        }
|
|
|
|
|
def test_vortex_model():
    """
    Smoke-test VortexModel on a shrunken copy of the 7B config.

    Builds a 4-layer, d_model=512, vocab=1000 model. It runs a forward pass on
    random token ids, checks the logits shape for both output formats, and
    prints a memory estimate.
    """
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shallow copy so the shared module-level config is not mutated.
    config = VORTEX_7B_CONFIG.copy()

    config["d_model"] = 512
    config["num_layers"] = 4
    config["num_heads"] = 8
    config["vocab_size"] = 1000

    model = VortexModel(config)

    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
    expected_shape = (batch_size, seq_len, config["vocab_size"])

    # Only shapes are checked, so skip building the autograd graph.
    with torch.no_grad():
        output = model(input_ids)
        logits_only = model(input_ids, return_dict=False)
    logits = output["logits"]

    print(f"Model parameters: {model.get_num_params():,}")
    print(f"Input shape: {input_ids.shape}")
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == expected_shape
    # The return_dict=False path must return the bare logits tensor.
    assert isinstance(logits_only, torch.Tensor)
    assert logits_only.shape == expected_shape

    mem = model.estimate_memory_usage(batch_size, seq_len)
    print(f"Memory estimate for batch={batch_size}, seq_len={seq_len}:")
    for k, v in mem.items():
        print(f" {k}: {v:.2f} GB")

    print("VortexModel test passed!")
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script (not imported): run the smoke test.
    test_vortex_model()
|
|
|