Spaces:
Sleeping
Sleeping
"""
LaM-SLidE Autoencoder for discrete token reconstruction of variable-sized
entity sets (music score notes). Provides the main autoencoder model,
the NoteFeatureEmbedder, and config dataclasses.

Architecture:
    Input: (features_dict, entity_ids, mask)
        -> NoteFeatureEmbedder -> Encoder (cross-attn) -> fixed latent (B, L, D)
        -> Decoder (cross-attn) -> per-entity logits dict
"""
| from dataclasses import dataclass, field | |
| from functools import partial | |
| from typing import Dict, List, Optional | |
| import torch | |
| import torch.nn as nn | |
| from .encoder import Encoder | |
| from .decoder import Decoder | |
| from .entity_embeddings import EntityEmbeddingFactorized, EntityEmbeddingOrthogonal | |
| from .note_hgt import NoteHGT | |
@dataclass
class FeatureConfig:
    """Configuration for a single discrete feature.

    Declared as a dataclass so instances can be built with keyword
    arguments, e.g. ``FeatureConfig(name='grid_position', vocab_size=33)``,
    and so that the defaults below apply to omitted fields.
    """

    name: str  # Feature name (e.g., 'grid_position')
    vocab_size: int  # Number of discrete tokens
    embed_dim: int = 32  # Embedding dimension for this feature
    is_input: bool = True  # Use as input feature
    is_output: bool = True  # Reconstruct this feature
@dataclass
class AutoencoderConfig:
    """Configuration for the LaM-SLidE Autoencoder.

    Declared as a dataclass: ``features`` relies on ``field(default_factory=...)``
    and the config is constructed with keyword arguments. The derived values
    (``input_features``, ``output_features``, ``total_input_dim``,
    ``output_vocab_sizes``) are read-only properties because the rest of the
    module reads them as plain attributes (e.g. ``config.total_input_dim``).
    """

    # Feature configuration (multi-feature support)
    # Default: single grid_position feature for backwards compatibility
    features: List[FeatureConfig] = field(default_factory=lambda: [
        FeatureConfig(name='grid_position', vocab_size=33, embed_dim=32),
    ])

    # Entity identifier settings
    identifier_pool_size: int = 512  # Size of entity ID pool
    entity_embed_dim: int = 128  # Dimension of entity embeddings
    entity_embed_type: str = 'factorized'  # 'factorized' or 'orthogonal'

    # Latent space
    dim_latent: int = 128  # Latent space dimension
    num_latents: int = 32  # Number of latent vectors (bottleneck)

    # Attention configuration
    dim_head_cross: int = 32  # Dimension per head in cross-attention
    dim_head_latent: int = 32  # Dimension per head in self-attention
    num_head_cross: int = 4  # Number of cross-attention heads
    num_head_latent: int = 4  # Number of self-attention heads

    # Architecture depth
    num_block_cross_enc: int = 2  # Cross-attention blocks in encoder
    num_block_attn_enc: int = 2  # Self-attention blocks in encoder
    num_block_cross_dec: int = 2  # Cross-attention blocks in decoder
    num_block_attn_dec: int = 2  # Self-attention blocks in decoder

    # Regularization
    dropout_latent: float = 0.0  # Dropout on latent vectors
    qk_norm: bool = True  # Query-key normalization

    # Feature mixing MLP (applied after embedding concat, before HGT/encoder)
    feature_mlp_hidden_dim: int = 0  # 0 = disabled, >0 = hidden dim of feature mixing MLP

    # HGT (Heterogeneous Graph Transformer) for note-level message passing
    use_hgt: bool = False  # Whether to use HGT after feature embedding
    hgt_num_layers: int = 2  # Number of HGT layers
    hgt_num_heads: int = 4  # Number of attention heads in HGT
    hgt_dropout: float = 0.1  # Dropout in HGT layers

    @property
    def input_features(self) -> List[FeatureConfig]:
        """Features flagged ``is_input``, in their configured order."""
        return [f for f in self.features if f.is_input]

    @property
    def output_features(self) -> List[FeatureConfig]:
        """Features flagged ``is_output``, in their configured order."""
        return [f for f in self.features if f.is_output]

    @property
    def total_input_dim(self) -> int:
        """Sum of ``embed_dim`` over all input features."""
        return sum(f.embed_dim for f in self.input_features)

    @property
    def output_vocab_sizes(self) -> Dict[str, int]:
        """Mapping of output feature name -> vocab size."""
        return {f.name: f.vocab_size for f in self.output_features}
class NoteFeatureEmbedder(nn.Module):
    """Embeds multiple discrete note features into continuous space and manages
    the shared entity identifier embeddings used by both encoder and decoder.

    ``input_dim`` and ``entity_dim`` are read-only properties; the autoencoder
    reads ``embedder.input_dim`` as a plain value when sizing the encoder.
    """

    def __init__(self, config: AutoencoderConfig):
        """Build the embedding tables and optional sub-modules from ``config``.

        Args:
            config: Autoencoder configuration; its input features, entity
                embedding settings, HGT settings and feature-MLP hidden size
                are read here.
        """
        super().__init__()
        self.config = config

        # Feature embeddings: each discrete feature -> continuous vector
        self.feature_embeddings = nn.ModuleDict()
        for feat in config.input_features:
            self.feature_embeddings[feat.name] = nn.Embedding(
                num_embeddings=feat.vocab_size,
                embedding_dim=feat.embed_dim,
            )

        # Entity embedding: shared between encoder and decoder for traceability.
        # Factorized variant uses sqrt(pool_size) base+offset tables.
        # NOTE: any value other than 'factorized' selects the orthogonal variant.
        # The keyword is spelled ``n_entiy_embeddings`` (sic) because that is the
        # name the imported embedding classes are called with in this module.
        if config.entity_embed_type == 'factorized':
            self.entity_embedding = EntityEmbeddingFactorized(
                n_entiy_embeddings=config.identifier_pool_size,
                embedding_dim=config.entity_embed_dim,
                requires_grad=True,
                combine='concat',  # base || offset -> full embedding
                max_norm=1.0,
            )
        else:
            self.entity_embedding = EntityEmbeddingOrthogonal(
                n_entiy_embeddings=config.identifier_pool_size,
                embedding_dim=config.entity_embed_dim,
                requires_grad=True,
                max_norm=1.0,
            )

        # Optional HGT for note-level message passing after feature embedding
        self.use_hgt = config.use_hgt
        if config.use_hgt:
            self.hgt = NoteHGT(
                note_dim=config.total_input_dim,  # Same dim as feature embeddings
                num_layers=config.hgt_num_layers,
                num_heads=config.hgt_num_heads,
                dropout=config.hgt_dropout,
            )
        else:
            self.hgt = None

        # Optional feature mixing MLP applied after embedding concat, before
        # HGT and entity concat. Pre-norm residual (LayerNorm -> MLP + skip);
        # the skip connection itself is added in ``embed_features``.
        if config.feature_mlp_hidden_dim > 0:
            act = partial(nn.GELU, approximate="tanh")
            self.feature_mlp = nn.Sequential(
                nn.LayerNorm(config.total_input_dim),
                nn.Linear(config.total_input_dim, config.feature_mlp_hidden_dim),
                act(),
                nn.Linear(config.feature_mlp_hidden_dim, config.total_input_dim),
            )
        else:
            self.feature_mlp = None

    def embed_features(
        self,
        features: Dict[str, torch.Tensor],
        edge_dicts: Optional[List[Dict]] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Embed multiple discrete features and concatenate.
        Optionally applies the feature-mixing MLP and HGT message passing.

        Args:
            features: Dict of feature_name -> (batch, num_entities) tensors
            edge_dicts: Optional list of edge_dict per sample (required if use_hgt=True)
                Each edge_dict maps edge_type_tuple -> edge_index (2, E)
            mask: Optional (B, N) validity mask

        Returns:
            combined_embeds: (batch, num_entities, total_input_dim)

        Raises:
            KeyError: If a configured input feature is absent from ``features``.
            ValueError: If HGT is enabled but ``edge_dicts`` is None.
        """
        embeddings = []
        for feat in self.config.input_features:
            if feat.name not in features:
                raise KeyError(f"Missing input feature: {feat.name}")
            embeddings.append(self.feature_embeddings[feat.name](features[feat.name]))

        # Concatenate all feature embeddings
        combined = torch.cat(embeddings, dim=-1)  # (B, N, total_input_dim)

        # Feature mixing MLP (residual): learns cross-feature interactions
        if self.feature_mlp is not None:
            combined = combined + self.feature_mlp(combined)

        # Apply HGT if enabled
        if self.use_hgt and self.hgt is not None:
            if edge_dicts is None:
                raise ValueError("edge_dicts required when use_hgt=True")
            # Per-sample note counts come from the mask; without a mask every
            # position of every sample is counted.
            if mask is not None:
                num_notes_list = mask.sum(dim=1).tolist()
            else:
                num_notes_list = [combined.shape[1]] * combined.shape[0]
            combined = self.hgt.forward_batch(combined, edge_dicts, num_notes_list, mask=mask)

        return combined

    def embed_entities(self, entity_ids: torch.Tensor) -> torch.Tensor:
        """
        Embed entity identifiers to continuous vectors.

        Args:
            entity_ids: (batch, num_entities) entity identifier indices

        Returns:
            entity_embeds: (batch, num_entities, entity_embed_dim)
        """
        return self.entity_embedding(entity_ids)

    @property
    def input_dim(self) -> int:
        """Total dimension of concatenated feature embeddings."""
        return self.config.total_input_dim

    @property
    def entity_dim(self) -> int:
        """Dimension of entity embeddings (``entity_embedding.embedding_dim``)."""
        return self.entity_embedding.embedding_dim
class LaMSLiDEAutoencoder(nn.Module):
    """Autoencoder for multi-feature discrete token reconstruction.

    A variable-sized entity set is compressed into a fixed-size latent
    (B, L, D) and then expanded back into per-entity feature logits. Entity
    IDs act as a return address: the decoder queries the latent with them to
    recover the features belonging to each entity.
    """

    def __init__(self, config: AutoencoderConfig):
        """Assemble embedder, encoder and decoder from ``config``."""
        super().__init__()
        self.config = config

        # Owns every embedding table, including the entity embedding that
        # encoder and decoder both receive below.
        self.embedder = NoteFeatureEmbedder(config)

        # Settings passed identically to encoder and decoder.
        common = dict(
            dim_latent=config.dim_latent,
            dim_head_cross=config.dim_head_cross,
            dim_head_latent=config.dim_head_latent,
            num_head_cross=config.num_head_cross,
            num_head_latent=config.num_head_latent,
            qk_norm=config.qk_norm,
            entity_embedding=self.embedder.entity_embedding,  # same object for both
        )

        # Variable-size input -> fixed-size latent.
        self.encoder = Encoder(
            dim_input=self.embedder.input_dim,
            num_latents=config.num_latents,
            num_block_cross=config.num_block_cross_enc,
            num_block_attn=config.num_block_attn_enc,
            dropout_latent=config.dropout_latent,
            **common,
        )

        # Fixed-size latent -> variable-size per-entity logits.
        self.decoder = Decoder(
            outputs=config.output_vocab_sizes,  # feature_name -> vocab_size
            dim_query=config.dim_latent,
            num_block_cross=config.num_block_cross_dec,
            num_block_attn=config.num_block_attn_dec,
            **common,
        )

    def encode(
        self,
        features: Dict[str, torch.Tensor],
        entity_ids: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        edge_dicts: Optional[List[Dict]] = None,
    ) -> torch.Tensor:
        """Map a variable-sized entity set to the fixed-size latent (B, L, D).

        The input features are embedded by the embedder (which also applies
        HGT when enabled), and the result is handed to the encoder together
        with the entity IDs and the mask.

        Args:
            features: Dict of feature_name -> (B, N) discrete feature tensors
            entity_ids: (B, N) unique entity identifiers from pool
            mask: (B, N) boolean mask, True for valid entities
            edge_dicts: Optional list of edge_dict per sample (for HGT)

        Returns:
            latent: (B, L, D_latent) fixed-size latent representation
        """
        embedded = self.embedder.embed_features(
            features,
            edge_dicts=edge_dicts,
            mask=mask,
        )  # (B, N, total_dim)
        return self.encoder(embedded, entity_ids, mask=mask)  # (B, L, D_lat)

    def decode(
        self,
        latent: torch.Tensor,
        entity_ids: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Turn a latent back into per-entity feature logits.

        Args:
            latent: (B, L, D_latent) fixed-size latent representation
            entity_ids: (B, N) entity identifiers to decode for

        Returns:
            outputs: Dict of feature_name -> (B, N, vocab_size) logits
        """
        # Entity IDs serve as the decoder's queries into the latent.
        return self.decoder(latent, entity_ids)

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        entity_ids: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        edge_dicts: Optional[List[Dict]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run ``encode`` followed by ``decode`` with the same entity IDs.

        Args:
            features: Dict of feature_name -> (B, N) discrete feature tensors
            entity_ids: (B, N) entity identifiers
            mask: (B, N) validity mask for variable-sized batches
            edge_dicts: Optional list of edge_dict per sample (for HGT)

        Returns:
            outputs: Dict of feature_name -> (B, N, vocab_size) logits
        """
        bottleneck = self.encode(
            features,
            entity_ids,
            mask=mask,
            edge_dicts=edge_dicts,
        )
        return self.decode(bottleneck, entity_ids)

    def count_parameters(self) -> int:
        """Return the number of elements across parameters with ``requires_grad``."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total
def create_autoencoder_from_dict(config_dict: Dict) -> LaMSLiDEAutoencoder:
    """
    Create autoencoder from a dictionary config (e.g., from OmegaConf).

    Entries of ``config_dict['features']`` that are mappings are converted to
    ``FeatureConfig``; anything else (e.g. an existing ``FeatureConfig``) is
    passed through unchanged. The caller's ``config_dict`` is not mutated.

    Args:
        config_dict: Dictionary with model configuration

    Returns:
        Configured LaMSLiDEAutoencoder instance
    """
    # Local import keeps this function self-contained.
    from collections.abc import Mapping

    # Parse features if provided as a list of mappings. Checking against
    # Mapping (not just dict) also covers dict-like config nodes that are
    # not dict subclasses, such as OmegaConf's DictConfig.
    if 'features' in config_dict:
        features = [
            FeatureConfig(**f) if isinstance(f, Mapping) else f
            for f in config_dict['features']
        ]
        # Build a new dict so the input is left untouched.
        config_dict = {**config_dict, 'features': features}

    config = AutoencoderConfig(**config_dict)
    return LaMSLiDEAutoencoder(config)
def create_single_feature_config(
    feature_name: str = 'grid_position',
    vocab_size: int = 33,
    embed_dim: int = 32,
    **kwargs,
) -> AutoencoderConfig:
    """
    Build an AutoencoderConfig whose feature list holds exactly one feature
    (backwards-compatible single-feature setup).

    Args:
        feature_name: Name of the feature
        vocab_size: Number of discrete tokens
        embed_dim: Embedding dimension
        **kwargs: Additional AutoencoderConfig parameters, forwarded as-is

    Returns:
        AutoencoderConfig with single feature
    """
    # The single feature keeps FeatureConfig's own defaults for every field
    # not set here.
    single_feature_list = [
        FeatureConfig(name=feature_name, vocab_size=vocab_size, embed_dim=embed_dim),
    ]
    return AutoencoderConfig(features=single_feature_list, **kwargs)