| | """ |
| | Hybrid ASPP-Attention Architecture (Asterisk Model) |
| | Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms |
| | to enhance model expressiveness while maintaining efficiency. |
| | |
| | Architecture Design: |
| | - Hybrid layers: Standard attention + ASPP operator in parallel |
| | - Gate mechanism for dynamic fusion |
| | - Knowledge distillation from SmolLM2-135M base model |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel |
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaAttention, |
| | LlamaDecoderLayer, |
| | LlamaRMSNorm, |
| | LlamaMLP, |
| | ) |
| | from transformers import AutoConfig, AutoModelForCausalLM |
| | from typing import Optional, Tuple, List |
| |
|
| |
|
class AsteriskConfig(LlamaConfig):
    """Configuration for the Asterisk hybrid ASPP-Attention model.

    Extends ``LlamaConfig`` with the ASPP-branch and pi-flow hyperparameters
    and registers the custom model_type "asterisk".
    """

    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Which decoder layers get the hybrid ASPP+attention treatment
        # (None is interpreted downstream as "all layers").
        self.hybrid_layer_indices = hybrid_layer_indices

        # ASPP branch hyperparameters.
        self.aspp_hidden_dim = aspp_hidden_dim
        self.aspp_num_steps = aspp_num_steps
        self.aspp_dropout = aspp_dropout
        self.aspp_num_neighbors = aspp_num_neighbors

        # Optional pi-flow refinement hyperparameters.
        self.pi_flow = pi_flow
        self.pi_flow_steps = pi_flow_steps
        self.pi_flow_scale = pi_flow_scale
        self.pi_flow_use_gate = pi_flow_use_gate
|
| |
|
class ASPPOperator(nn.Module):
    """
    Asterisk Operator (ASPP) - Union-Find Graph Propagation.

    Uses a Union-Find (Disjoint Set Union) style structure for parent
    connections:

    - Each position i has a parent pointer parent[i].
    - Initial structure: parent[i] = max(0, i - 1) (linear chain; position 0
      is its own root).
    - Message passing: h_i^(t+1) = phi(h_i^(t), h_parent[i]^(t)).

    Each step costs O(n): the parent gather is a single index lookup and phi
    is applied position-wise.

    Args:
        hidden_size: Dimension of the input/output hidden states.
        aspp_hidden_dim: Internal ASPP width; ``None`` means ``hidden_size``
            (no projection layers are created in that case).
        num_steps: Maximum number of evolution steps K (default: 2). The
            effective count is modulated by the learnable ``k_logit``.
        dropout: Dropout rate for the projections and message network.
        num_neighbors: Accepted for API compatibility but ignored; the
            Union-Find chain always uses exactly one neighbor (the parent).
    """

    def __init__(
        self,
        hidden_size: int,
        aspp_hidden_dim: Optional[int] = None,
        num_steps: int = 2,
        dropout: float = 0.1,
        num_neighbors: int = 1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps
        # Fixed at 1: the structure only ever propagates from the parent,
        # regardless of the ``num_neighbors`` argument.
        self.num_neighbors = 1

        # Down/up projections are only needed when the internal width differs
        # from the model width.
        self.use_projection = (self.aspp_hidden_dim != hidden_size)
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            self.proj_dropout = nn.Dropout(dropout)

        # phi: maps concatenated (self, parent) features to an update message.
        self.message_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # Learnable step-count logit: k = max(1, int(sigmoid(k_logit) * K)).
        # NOTE(review): the int() truncation in forward() is not
        # differentiable, so this parameter receives no gradient — confirm
        # whether that is intended.
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Small residual scale keeps the ASPP update gentle at initialization.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
        """
        Compute the parent index of each position (linear-chain Union-Find).

        parent[0] = 0 (root points to itself); parent[i] = i - 1 otherwise.
        Could be extended with dynamic union operations based on semantic
        similarity, positional heuristics, or learned grouping.

        Returns:
            LongTensor of shape [seq_len] with values in [0, seq_len - 1].
        """
        # arange - 1 yields [-1, 0, ..., seq_len - 2]; clamping the single -1
        # up to 0 makes position 0 its own parent. All other values are
        # already in range, so clamp(min=0) alone is sufficient (equivalent
        # to the previous set-then-clamp-both-sides formulation).
        return torch.clamp(torch.arange(seq_len, device=device) - 1, min=0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run k learned steps of parent-directed message passing.

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            evolved_states: [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project into the internal ASPP width if configured.
        if self.use_projection:
            h_t = self.down_proj(hidden_states)
            h_t = self.proj_dropout(h_t)
        else:
            h_t = hidden_states

        # Learned effective number of steps in [1, num_steps].
        k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # The chain topology is static, so the parent indices are
        # loop-invariant: compute them once instead of once per step.
        parent_indices = self.compute_parent_indices(seq_len, h_t.device)

        for _ in range(k_steps):
            # Gather each position's parent features: [B, S, D].
            parent_features = h_t[:, parent_indices, :]

            # phi(self, parent) -> update message.
            message_input = torch.cat([h_t, parent_features], dim=-1)
            h_t_next = self.message_net(message_input)

            # Scaled residual update followed by normalization.
            h_t = h_t + self.residual_scale * h_t_next
            h_t = self.norm(h_t)

        # Project back to the model width if configured.
        if self.use_projection:
            h_t = self.up_proj(h_t)
            h_t = self.proj_dropout(h_t)

        return h_t
| |
|
| |
|
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Hybrid layer combining the ASPP operator and standard attention.
    Inherits from LlamaDecoderLayer to maintain compatibility.

    Architecture:
    1. Parallel branches on the normalized input:
       - ASPP operator for local structured reasoning
       - Standard Llama self-attention for global context
    2. Gated fusion of both branch outputs
    3. pi-flow refinement (optional, per-layer)
    4. Feed-forward network (inherited MLP sub-block)
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
    ):
        super().__init__(config, layer_idx)

        # Local-structure branch, run in parallel with self-attention.
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
            num_neighbors=aspp_num_neighbors,
        )

        # Token-wise gate deciding how much ASPP vs. attention output to keep.
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid(),
        )

        # Zero bias => sigmoid(0) = 0.5 at init: start with an even blend.
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

        # Optional pi-flow refinement applied after the fused residual.
        if getattr(config, 'pi_flow', False):
            self.pi_flow_aspp = ASPPOperator(
                hidden_size=config.hidden_size,
                aspp_hidden_dim=aspp_hidden_dim,
                num_steps=aspp_num_steps,
                dropout=aspp_dropout,
                num_neighbors=aspp_num_neighbors,
            )

            # Bugfix: capture the step count at construction time. The old
            # code tried to read it inside forward() from ``self.config``
            # (not guaranteed to exist on LlamaDecoderLayer) with a fallback
            # to ``kwargs.get('config')``, silently degrading to 1 step and
            # ignoring the configured value.
            self.pi_flow_steps = getattr(config, 'pi_flow_steps', 1)

            # Global learnable flow scale (initialized from config).
            self.pi_flow_scale = nn.Parameter(
                torch.tensor(getattr(config, 'pi_flow_scale', 0.2))
            )

            # Optional token-wise adaptive gate on the flow update.
            if getattr(config, 'pi_flow_use_gate', True):
                self.pi_flow_gate = nn.Sequential(
                    nn.Linear(config.hidden_size, config.hidden_size // 4),
                    nn.SiLU(),
                    nn.Dropout(aspp_dropout),
                    nn.Linear(config.hidden_size // 4, 1),
                    nn.Sigmoid(),
                )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Override LlamaDecoderLayer.forward to add the ASPP branch and pi-flow.
        Returns a single hidden-states tensor like LlamaDecoderLayer.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Branch 1: local structured propagation.
        aspp_output = self.aspp_operator(hidden_states)

        # Branch 2: standard self-attention. The second return value (the
        # attention weights) is discarded.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # Token-wise convex blend of the two branches: gate -> ASPP share.
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        gate = self.fusion_gate(fusion_input)
        fused_output = gate * aspp_output + (1 - gate) * attn_output

        hidden_states = residual + fused_output

        # Optional pi-flow refinement: small (possibly gated) velocity steps
        # using a dedicated ASPP operator.
        if hasattr(self, 'pi_flow_aspp'):
            for _ in range(self.pi_flow_steps):
                v = self.pi_flow_aspp(hidden_states)

                if hasattr(self, 'pi_flow_gate'):
                    alpha = self.pi_flow_scale * self.pi_flow_gate(hidden_states)
                else:
                    alpha = self.pi_flow_scale

                hidden_states = hidden_states + alpha * v

        # Standard post-attention MLP sub-block (inherited from Llama).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
| |
|
| |
|
class AsteriskLlamaModel(LlamaModel):
    """
    Asterisk-Llama model with the full hybrid ASPP-Attention architecture.

    By default every decoder layer is replaced with a
    HybridASPPAttentionLayer; pass ``hybrid_layer_indices`` to restrict the
    replacement to a subset of layers.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        # Default aligned with the rest of the file (config default, the
        # "fixed at 1" contract, and from_pretrained_base); previously 2.
        # No behavioral effect: ASPPOperator forces num_neighbors to 1.
        aspp_num_neighbors: int = 1,
    ):
        super().__init__(config)

        # None means "make every layer hybrid".
        if hybrid_layer_indices is None:
            hybrid_layer_indices = list(range(config.num_hidden_layers))

        self.hybrid_layer_indices = hybrid_layer_indices

        # Swap the selected standard decoder layers for hybrid ones;
        # out-of-range indices are silently skipped.
        for idx in hybrid_layer_indices:
            if idx < len(self.layers):
                self.layers[idx] = HybridASPPAttentionLayer(
                    config,
                    layer_idx=idx,
                    aspp_hidden_dim=aspp_hidden_dim,
                    aspp_num_steps=aspp_num_steps,
                    aspp_dropout=aspp_dropout,
                    aspp_num_neighbors=aspp_num_neighbors,
                )

        # Re-run weight initialization so the freshly constructed hybrid
        # layers are initialized consistently with the rest of the model.
        self.post_init()
|
| |
|
class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Asterisk Causal LM with the Hybrid ASPP-Attention architecture.

    Registered with the Auto* factories under model_type "asterisk"
    (class name: AsteriskForCausalLM).
    """

    config_class = AsteriskConfig

    def __init__(
        self,
        config: AsteriskConfig,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        # Default aligned with AsteriskConfig / from_pretrained_base
        # (previously 2); no behavioral effect since ASPPOperator forces 1.
        aspp_num_neighbors: int = 1,
    ):
        # For indices/dim: explicit constructor arguments win, with the
        # config as fallback (e.g. when loaded via from_pretrained).
        if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
            hybrid_layer_indices = config.hybrid_layer_indices
        if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'):
            aspp_hidden_dim = config.aspp_hidden_dim
        # NOTE: for steps/dropout/neighbors the config value, when present,
        # unconditionally overrides the constructor argument.
        if hasattr(config, 'aspp_num_steps'):
            aspp_num_steps = config.aspp_num_steps
        if hasattr(config, 'aspp_dropout'):
            aspp_dropout = config.aspp_dropout
        if hasattr(config, 'aspp_num_neighbors'):
            aspp_num_neighbors = config.aspp_num_neighbors

        super().__init__(config)

        # Replace the stock LlamaModel built by super().__init__ with the
        # hybrid Asterisk backbone.
        self.model = AsteriskLlamaModel(
            config,
            hybrid_layer_indices,
            aspp_hidden_dim,
            aspp_num_steps,
            aspp_dropout,
            aspp_num_neighbors,
        )

        # Persist the *resolved* index list (the backbone expands None to
        # "all layers") so the value round-trips through save/load; the old
        # code could store None here.
        self.config.hybrid_layer_indices = self.model.hybrid_layer_indices

        # Re-run post-init after swapping the backbone (weight init and, when
        # configured, re-tying of shared weights).
        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs,
    ):
        """
        Load a base model and convert it to the Asterisk architecture.

        Returns ``(asterisk_model, base_model)`` — the base model is returned
        as well (e.g. for use as a distillation teacher).

        Args:
            base_model_path: Path to base SmolLM2 model
            hybrid_layer_indices: Which layers to make hybrid (None for all)
            aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
            aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
            aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
            aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent)
            pi_flow: Enable pi-flow refinement step (default: False)
            pi_flow_steps: Number of flow refinement steps (default: 1)
            pi_flow_scale: Initial flow scale parameter (default: 0.2)
            pi_flow_use_gate: Use token-wise adaptive gating (default: True)
        """
        # Load the dense base model whose weights seed the hybrid model.
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        # Extend the base config with the Asterisk hyperparameters.
        asterisk_config = AsteriskConfig(
            **base_config.to_dict(),
            hybrid_layer_indices=hybrid_layer_indices,
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
            aspp_num_neighbors=aspp_num_neighbors,
            pi_flow=pi_flow,
            pi_flow_steps=pi_flow_steps,
            pi_flow_scale=pi_flow_scale,
            pi_flow_use_gate=pi_flow_use_gate,
        )

        asterisk_model = cls(
            asterisk_config,
            hybrid_layer_indices,
            aspp_hidden_dim,
            aspp_num_steps,
            aspp_dropout,
            aspp_num_neighbors,
        )

        # strict=False: the base weights load into the shared submodules,
        # while the new ASPP / gate / pi-flow parameters (absent from the
        # base state dict) keep their fresh initialization.
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
        print(f"  Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f"  ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
        if pi_flow:
            print(f"  π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")

        return asterisk_model, base_model
| |
|
| |
|
| | |
# Register with the Auto* factories so that AutoConfig / AutoModelForCausalLM
# can resolve model_type "asterisk" to the classes defined above.
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
| |
|
| |
|
def get_model_info(model):
    """Print parameter counts and, for Asterisk models, hybrid-layer info."""
    # Single pass over the parameters, accumulating both counts.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"  • Total parameters: {total_params:,}")
    print(f"  • Trainable parameters: {trainable_params:,}")
    print(f"  • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    # Extra detail only available on the Asterisk architecture.
    if isinstance(model, AsteriskForCausalLM):
        hybrid_indices = model.model.hybrid_layer_indices
        print(f"  • Hybrid layer indices: {hybrid_indices}")
        print(f"  • Number of hybrid layers: {len(hybrid_indices)}")
|