"""HuggingFace-compatible model classes for SwipeTransformer."""
from collections.abc import Callable
from dataclasses import dataclass
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_swipe import SwipeTransformerConfig


@dataclass
class SwipeTransformerOutput(ModelOutput):
"""
Output type for SwipeTransformerModel.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (character prediction).
char_logits (`torch.FloatTensor` of shape `(batch_size, char_length, vocab_size)`):
Prediction scores of the character prediction head (text segment only).
path_logits (`torch.FloatTensor` of shape `(batch_size, path_length, path_input_dim)`, *optional*):
Prediction scores of the path prediction head (path segment only, if enabled).
length_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Predicted length from the length head (if enabled).
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
SEP token embeddings for similarity/embedding tasks.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`.
When requested, this includes the input embeddings plus one entry per encoder layer.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of attention tensors (one for each layer) of shape
`(batch_size, num_heads, sequence_length, sequence_length)`.
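
    Example (a minimal sketch of the similarity use-case; assumes `model` is a loaded
    `SwipeTransformerModel` and `batch_a`/`batch_b` are dicts of prepared `input_ids`
    and `path_coords` tensors):

        >>> import torch.nn.functional as F
        >>> emb_a = model(**batch_a).pooler_output  # [batch, d_model]
        >>> emb_b = model(**batch_b).pooler_output
        >>> similarity = F.cosine_similarity(emb_a, emb_b, dim=-1)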
"""
loss: torch.FloatTensor | None = None
char_logits: torch.FloatTensor | None = None
path_logits: torch.FloatTensor | None = None
length_logits: torch.FloatTensor | None = None
last_hidden_state: torch.FloatTensor | None = None
pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


class SwipeTransformerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface
for downloading and loading pretrained models.
"""
config_class = SwipeTransformerConfig
base_model_prefix = "swipe_transformer"
supports_gradient_checkpointing = False
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
"""
HuggingFace-compatible SwipeTransformerModel.
    This model reuses the existing SwipeALot components (the mixed embedding and the
    prediction heads imported from the sibling modules) and wraps them in a
    HuggingFace-compatible interface.
Args:
config (SwipeTransformerConfig): Model configuration
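
    Example (a minimal sketch; assumes `SwipeTransformerConfig()` provides usable
    defaults, and the tensor shapes below are illustrative, not required values):

        >>> import torch
        >>> config = SwipeTransformerConfig()
        >>> model = SwipeTransformerModel(config).eval()
        >>> path_coords = torch.randn(2, 32, config.path_input_dim)  # [batch, path_len, path_input_dim]
        >>> input_ids = torch.randint(0, config.vocab_size, (2, 12))  # [batch, char_len]
        >>> outputs = model(input_ids=input_ids, path_coords=path_coords)
        >>> outputs.pooler_output.shape  # torch.Size([2, config.d_model])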
"""
def __init__(self, config: SwipeTransformerConfig):
super().__init__(config)
self.config = config
# Import existing components
from .embeddings import MixedEmbedding
from .heads import CharacterPredictionHead, LengthPredictionHead, PathPredictionHead
# Embeddings
self.embeddings = MixedEmbedding(
vocab_size=config.vocab_size,
max_path_len=config.max_path_len,
max_char_len=config.max_char_len,
d_model=config.d_model,
dropout=config.dropout,
path_input_dim=config.path_input_dim,
)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.d_model,
nhead=config.n_heads,
dim_feedforward=config.d_ff,
dropout=config.dropout,
activation="gelu",
batch_first=True,
norm_first=True, # Pre-LayerNorm
)
self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=config.n_layers,
enable_nested_tensor=False,
)
# Prediction heads
self.char_head = (
CharacterPredictionHead(
d_model=config.d_model,
vocab_size=config.vocab_size,
)
if config.predict_char
else None
)
if config.predict_path:
self.path_head = PathPredictionHead(
d_model=config.d_model, output_dim=config.path_input_dim
)
else:
self.path_head = None
# Length prediction head (predicts word length from path)
# Max length is max_char_len (including EOS)
self.length_head = (
LengthPredictionHead(d_model=config.d_model) if config.predict_length else None
)
# Initialize weights
        self.post_init()

def forward(
self,
        input_ids: torch.Tensor | None = None,
        path_coords: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | dict | None = None,
return_dict: bool | None = None,
output_hidden_states: bool | None = None,
output_attentions: bool | None = None,
**kwargs,
):
"""
Forward pass of the model.
Args:
input_ids (torch.Tensor): Character token IDs [batch, char_len]
path_coords (torch.Tensor): Path features [batch, path_len, path_input_dim]
Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
labels (torch.Tensor or dict, optional): Labels for loss calculation
Can be tensor [batch, char_len] or dict with keys like char_labels, path_labels
return_dict (bool, optional): Whether to return ModelOutput object
output_hidden_states (bool, optional): Whether to output hidden states
output_attentions (bool, optional): Whether to output attention weights
**kwargs: Additional arguments (for compatibility)
Returns:
SwipeTransformerOutput or tuple: Model outputs with:
- loss: Optional loss value
- char_logits: Character prediction logits [batch, char_len, vocab_size] (if enabled)
- path_logits: Path prediction logits [batch, path_len, path_input_dim] (if enabled)
- length_logits: Length regression output [batch] (if enabled)
- last_hidden_state: Hidden states [batch, seq_len, d_model]
- pooler_output: SEP token embedding [batch, d_model] for similarity/embedding tasks
- hidden_states: Tuple of per-layer hidden states (if output_hidden_states=True)
- attentions: Tuple of per-layer attention weights (if output_attentions=True)
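
        Example (sketch; assumes `input_ids`, `path_coords`, and `char_labels` are
        prepared as described above, with -100 marking positions to ignore):

            >>> outputs = model(
            ...     input_ids=input_ids,
            ...     path_coords=path_coords,
            ...     labels={"char_labels": char_labels},
            ... )
            >>> outputs.loss  # masked cross-entropy over char_logits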
"""
# Validate required inputs
if input_ids is None or path_coords is None:
raise ValueError("Both input_ids and path_coords are required")
# Extract labels if dict (used by custom trainers)
if isinstance(labels, dict):
char_labels = labels.get("char_labels")
# Can handle other label types in the future (path_labels, etc.)
else:
char_labels = labels
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
batch_size = path_coords.shape[0]
device = path_coords.device
# Create [CLS] and [SEP] tokens
cls_token = torch.full(
(batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
)
sep_token = torch.full(
(batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
)
# Get embeddings
embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
# Prepare attention mask for encoder
if attention_mask is not None:
# Convert attention mask: 1 = attend, 0 = ignore
# PyTorch expects: False = attend, True = ignore
src_key_padding_mask = attention_mask == 0
else:
src_key_padding_mask = None
# Encode while optionally capturing attentions and per-layer hidden states.
attentions: tuple[torch.Tensor, ...] | None = None
hidden_states_by_layer: list[torch.Tensor] | None = [] if output_hidden_states else None
hooks = []
        original_forwards: dict[int, Callable] = {}
attentions_buffer: list[torch.Tensor | None] | None = None
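        # PyTorch's `nn.TransformerEncoderLayer` calls self-attention with
        # need_weights=False, so attention weights are never returned by default.
        # When `output_attentions` is requested, each layer's MultiheadAttention.forward
        # is temporarily wrapped to force need_weights=True and average_attn_weights=False
        # (per-head weights); the original forwards are restored in the `finally` block.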
def make_patched_forward(original_forward):
def patched_forward(
query,
key,
value,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
average_attn_weights=False,
is_causal=False,
):
return original_forward(
query,
key,
value,
key_padding_mask=key_padding_mask,
need_weights=True,
attn_mask=attn_mask,
average_attn_weights=False,
is_causal=is_causal,
)
return patched_forward
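        # A forward hook on each attention module records the attention weights, which
        # MultiheadAttention returns as the second element of its output once the
        # patched forward sets need_weights=True.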
def make_hook(layer_idx: int):
def hook(_module: nn.Module, _input: tuple, output: tuple):
if (
attentions_buffer is not None
and isinstance(output, tuple)
and len(output) > 1
and output[1] is not None
):
attentions_buffer[layer_idx] = output[1]
return hook
if output_attentions:
attentions_buffer = [None] * len(self.encoder.layers)
for idx, layer in enumerate(self.encoder.layers):
attn_module = layer.self_attn
original_forwards[idx] = attn_module.forward
attn_module.forward = make_patched_forward(original_forwards[idx])
hooks.append(attn_module.register_forward_hook(make_hook(idx)))
try:
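            # Run the layers manually instead of calling self.encoder(...) so that
            # per-layer hidden states can be collected; the encoder was built without
            # a final norm, so iterating its layers matches its output.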
x = embeddings
for layer in self.encoder.layers:
x = layer(x, src_key_padding_mask=src_key_padding_mask)
if hidden_states_by_layer is not None:
hidden_states_by_layer.append(x)
hidden_states = x
if attentions_buffer is not None:
if any(a is None for a in attentions_buffer):
missing = [i for i, a in enumerate(attentions_buffer) if a is None]
raise RuntimeError(
f"Failed to capture attention weights for layers: {missing}."
)
attentions = tuple(attentions_buffer) # type: ignore[assignment]
finally:
for hook in hooks:
hook.remove()
for idx, layer in enumerate(self.encoder.layers):
if idx in original_forwards:
layer.self_attn.forward = original_forwards[idx]
path_len = path_coords.shape[1]
char_len = input_ids.shape[1]
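        # Packed sequence layout along the sequence dimension:
        #   position 0                      -> [CLS]
        #   positions 1 .. path_len         -> path frames
        #   position 1 + path_len           -> [SEP]
        #   positions 2 + path_len .. end   -> character tokens (char_len positions)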
# Character prediction (text segment only)
char_logits = None
if self.char_head is not None:
# Sequence is: [CLS] + path + [SEP] + chars
char_start = 1 + path_len + 1
char_hidden = hidden_states[:, char_start : char_start + char_len, :]
char_logits = self.char_head(char_hidden)
# Path prediction (path segment only, if enabled)
path_logits = None
if self.path_head is not None:
path_hidden = hidden_states[:, 1 : 1 + path_len, :]
path_logits = self.path_head(path_hidden)
# Length prediction from CLS token
cls_hidden = hidden_states[:, 0, :] # [batch, d_model] - CLS at position 0
length_logits = self.length_head(cls_hidden) if self.length_head is not None else None
# Extract SEP token embedding for pooler output (embeddings/similarity tasks)
# SEP is at position 1 + path_len
sep_position = 1 + path_len
pooler_output = hidden_states[:, sep_position, :] # [batch, d_model]
# Compute loss if labels provided (masked-only; -100 = ignore)
loss = None
if char_labels is not None and self.char_head is not None:
# Predict only the text segment
char_pred = char_logits # [B, char_len, V]
labels_flat = char_labels.reshape(-1)
mask = labels_flat != -100
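            # Selecting unmasked positions by hand (rather than passing ignore_index to
            # cross_entropy) lets the branch below return a plain 0.0 loss when every
            # label is -100, instead of the NaN a mean over zero elements would produce.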
if mask.any():
logits_flat = char_pred.reshape(-1, self.config.vocab_size)[mask]
labels_flat = labels_flat[mask]
loss = nn.functional.cross_entropy(logits_flat, labels_flat, reduction="mean")
else:
loss = torch.tensor(0.0, device=hidden_states.device)
if not return_dict:
hidden_tuple = None
if hidden_states_by_layer is not None:
hidden_tuple = (embeddings,) + tuple(hidden_states_by_layer)
output = (
char_logits,
path_logits,
length_logits,
hidden_states,
pooler_output,
hidden_tuple,
attentions,
)
return (loss,) + output if loss is not None else output
all_hidden_states = None
if hidden_states_by_layer is not None:
all_hidden_states = (embeddings,) + tuple(hidden_states_by_layer)
return SwipeTransformerOutput(
loss=loss,
char_logits=char_logits,
path_logits=path_logits,
length_logits=length_logits,
last_hidden_state=hidden_states,
pooler_output=pooler_output,
hidden_states=all_hidden_states,
attentions=attentions,
        )


#
# Legacy note:
# `SwipeModel` (embeddings-only) has been removed; use `SwipeTransformerModel` and read
# `outputs.pooler_output` for embeddings.