---
license: mit
tags:
  - pytorch
  - addressed-state-attention
  - interpretable-ai
  - mechanistic-interpretability
language:
  - en
---

Addressed State Attention (ASA)

Addressed State Attention (ASA) is an interpretable, slot-based attention mechanism that achieves competitive language-modeling performance (see Performance below).

Quick Start

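```bash
# Install ASA directly from GitHub, plus the Hugging Face libraries used below
# (in a Jupyter/Colab cell, prefix each command with "!")
pip install git+https://github.com/DigitalDaimyo/AddressedStateAttention.git
pip install transformers huggingface_hub
```

```python
from asa import load_asm_checkpoint, generate
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

# Download the FineWeb 187M checkpoint from the Hugging Face Hub
ckpt_path = hf_hub_download(
    repo_id="DigitalDaimyo/AddressedStateAttention",
    filename="checkpoints/fineweb_187M_75k.pt",
)

# Load the model and its config from the checkpoint
model, cfg, ckpt = load_asm_checkpoint(
    ckpt_path,
    mode="analysis",
)

# Load the GPT-2 tokenizer used with this checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Generate a continuation of the prompt
print(generate(model, tokenizer, "Once upon a time"))
```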

Performance
FineWeb, 187M parameters: 3.73 validation loss / 41.6 perplexity (75k steps, batch size 32, sequence length 1024)
Architecture: 21 layers, 768-dimensional hidden size, 12 heads, 16 slots
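The reported figures can be cross-checked with a few lines of arithmetic. This sketch rests on three assumptions that the numbers above do not state explicitly: the loss is the mean per-token cross-entropy in nats, so perplexity = exp(loss); "batch 32" means 32 sequences per optimizer step with no gradient accumulation; and the 768-dimensional width is split evenly across the 12 heads.

```python
import math

# Training setup from the Performance section (batch interpretation is an assumption)
steps = 75_000
batch_size = 32   # sequences per optimizer step (assumed: no gradient accumulation)
seq_len = 1024    # tokens per sequence

# Total training tokens, assuming every position of every sequence is trained on
tokens_seen = steps * batch_size * seq_len
print(f"tokens seen: {tokens_seen:,}")

# Perplexity from validation loss, assuming mean cross-entropy in nats
val_loss = 3.73
ppl = math.exp(val_loss)
print(f"perplexity from loss: {ppl:.2f}")

# Loss implied by the reported perplexity, for comparison
implied_loss = math.log(41.6)
print(f"loss implied by PPL 41.6: {implied_loss:.4f}")

# Per-head width, assuming the model width is divided evenly across heads
d_model, n_heads = 768, 12
head_dim = d_model // n_heads
print(f"head dim: {head_dim}")
```

Under these assumptions the run covers about 2.46B tokens. exp(3.73) comes out near 41.7 rather than 41.6, which is consistent with the loss having been rounded to two decimals (ln 41.6 ≈ 3.728).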

Links
Code: https://github.com/DigitalDaimyo/AddressedStateAttention
Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/paper_drafts