oars344/attnres-phase1
Phase 1 checkpoint of the AttnRes language model - a decoder-only
transformer that replaces standard residual connections with Block
Attention Residuals (learned, input-dependent softmax attention over
preceding block representations). See Attention Residuals (Chen
et al., Moonshot AI).
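To make the mechanism concrete, here is a minimal, dependency-free sketch of softmax aggregation over preceding block outputs for a single token position. It assumes a single learned pseudo-query vector that scores each block representation with a dot product. The function name `block_attn_res`, the lack of any normalization, and the toy values are illustrative assumptions, not the checkpoint's actual implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def block_attn_res(block_outputs, pseudo_query):
    """Mix preceding block outputs using input-dependent softmax weights.

    block_outputs: list of hidden vectors, one per preceding block.
    pseudo_query:  learned vector of the same width (a hypothetical
                   stand-in for a projection weight of shape [1, hidden_size]).
    """
    # One scalar logit per block: the logits depend on the block
    # representations themselves, which is what makes the mix input-dependent.
    logits = [
        sum(q * h for q, h in zip(pseudo_query, block))
        for block in block_outputs
    ]
    weights = softmax(logits)
    width = len(pseudo_query)
    mixed = [
        sum(w * block[d] for w, block in zip(weights, block_outputs))
        for d in range(width)
    ]
    return mixed, weights


# Toy example: three preceding blocks, hidden width 2.
blocks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = [0.5, -0.5]
mixed, weights = block_attn_res(blocks, query)
print("weights:", [round(w, 3) for w in weights])
print("mixed:  ", [round(m, 3) for m in mixed])
```

With an all-zero pseudo-query every block receives the same weight, so the output in this sketch reduces to a plain average of the block representations.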
Architecture
hidden_size: 768
num_layers: 12
num_attention_heads: 12
num_key_value_heads: 4
vocab_size: 50304
use_attn_res: True
sublayers_per_block: 2
Total parameters: 114.2M
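The values above imply a grouped-query attention layout. As a quick worked check, the sketch below derives the per-head width and the key/value projection width, assuming the common convention head_dim = hidden_size / num_attention_heads (the card does not state head_dim explicitly), and estimates the bf16 weight footprint from the listed parameter count.

```python
# Values copied from the architecture list above.
hidden_size = 768
num_attention_heads = 12
num_key_value_heads = 4
total_params = 114.2e6

# Assumption: head_dim follows the usual hidden_size / num_heads convention.
head_dim = hidden_size // num_attention_heads

# Keys and values are projected to fewer heads than queries.
kv_width = num_key_value_heads * head_dim

# Number of query heads that share each key/value head.
queries_per_kv_head = num_attention_heads // num_key_value_heads

# bf16 stores each parameter in 2 bytes.
bf16_megabytes = total_params * 2 / 1e6

print(f"head_dim={head_dim}, kv_width={kv_width}, "
      f"queries_per_kv_head={queries_per_kv_head}, "
      f"bf16 weights ~{bf16_megabytes:.0f} MB")
```

Under that assumption each key/value head serves three query heads, and the weights alone occupy roughly 228 MB in bf16.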
Usage
The attnres model type is registered with transformers.AutoModelForCausalLM by our
AttnResLMForCausalLM wrapper (src/model/hf_wrapper.py). Our training /
inference scripts import that wrapper, so its
AutoConfig.register("attnres", ...) call has already run by the time they
load the checkpoint, and a plain from_pretrained(repo_id) call works
without trust_remote_code=True and without any modeling files uploaded to
the Hub. In a fresh Python session that hasn't imported our wrapper,
import src.model.hf_wrapper once before loading (or pass
trust_remote_code=True):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "oars344/attnres-phase1"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,  # 114M params -> ~228 MB in bf16
    device_map="auto",
)
model.eval()

prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32, do_sample=True, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))
Inspecting the raw AttnResLM
The HF wrapper stores the underlying AttnResLM under model.model.
The learned pseudo-query projections inside each block's
BlockAttnRes live at model.model.layers[i].{attn_res,mlp_res}.proj:
# Probing the learned residual aggregation weights:
projection = model.model.layers[0].attn_res.proj.weight # [1, hidden_size]
print(f"Layer-0 attn_res pseudo-query shape: {tuple(projection.shape)}")
# `mlp_res.proj` is the corresponding projection for the MLP-side residual.
Phase 2 fine-tuning with LoRA / QLoRA
Adapters can be trained on top of this base via
src/training/train_phase2.py. AttnRes-aware target modules include the
seven standard Llama-style linears plus the BlockAttnRes pseudo-query
projections (attn_res.proj, mlp_res.proj), so LoRA can re-route the
residual stream for downstream tasks:
from peft import LoraConfig, get_peft_model

attnres_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "attn_res.proj", "mlp_res.proj",
]
model = get_peft_model(
    model,
    LoraConfig(
        r=16, lora_alpha=32, task_type="CAUSAL_LM",
        target_modules=attnres_targets,
    ),
)
model.print_trainable_parameters()
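For a rough sense of adapter size, the sketch below counts the LoRA parameters that r=16 adds to the attention and pseudo-query projections. It is an estimate under stated assumptions, not output from the training script: head_dim is taken as 64 (hidden_size / num_attention_heads), the pseudo-query projections are taken as 768 -> 1 linears per the shape comment above, biases are ignored, and the MLP linears are left out because this card does not list the intermediate size. The helper `lora_params` is hypothetical.

```python
def lora_params(in_features, out_features, rank):
    """LoRA adds A (rank x in_features) and B (out_features x rank)."""
    return rank * (in_features + out_features)


hidden_size = 768
num_layers = 12
kv_width = 4 * 64  # num_key_value_heads * assumed head_dim
rank = 16

# (in_features, out_features) for the targets whose shapes follow from the card.
shapes = {
    "q_proj": (hidden_size, hidden_size),
    "k_proj": (hidden_size, kv_width),
    "v_proj": (hidden_size, kv_width),
    "o_proj": (hidden_size, hidden_size),
    "attn_res.proj": (hidden_size, 1),
    "mlp_res.proj": (hidden_size, 1),
}

per_module = {name: lora_params(i, o, rank) for name, (i, o) in shapes.items()}
per_layer = sum(per_module.values())
total = per_layer * num_layers

for name, count in per_module.items():
    print(f"{name:>14}: {count:,} LoRA params")
print(f"per layer: {per_layer:,}  across {num_layers} layers: {total:,}")
```

One thing this arithmetic makes visible: a rank-16 adapter on a 768 -> 1 projection has more parameters than the 768-weight projection it adapts. A smaller rank for those modules, or training them directly (for example via PEFT's modules_to_save), may be worth considering.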