oars344/attnres-phase1

Phase 1 checkpoint of the AttnRes language model, a decoder-only transformer that replaces standard residual connections with Block Attention Residuals (learned, input-dependent softmax attention over preceding block representations). See Attention Residuals (Chen et al., Moonshot AI).

Architecture

hidden_size: 768
num_layers: 12
num_attention_heads: 12
num_key_value_heads: 4
vocab_size: 50304
use_attn_res: True
sublayers_per_block: 2

Total parameters: 114.2M
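
A few quantities follow from the config by plain arithmetic. The sketch below assumes the usual grouped-query attention convention (head_dim = hidden_size / num_attention_heads, with each key/value head shared by several query heads); the variable names are illustrative and not taken from the repo.

```python
# Derived quantities from the config above (a sketch; names are illustrative).
hidden_size = 768
num_attention_heads = 12
num_key_value_heads = 4
total_params = 114.2e6

# Assumption: standard GQA layout, head_dim = hidden_size / num_attention_heads.
head_dim = hidden_size // num_attention_heads
queries_per_kv_head = num_attention_heads // num_key_value_heads
kv_proj_width = num_key_value_heads * head_dim

# bf16 stores 2 bytes per parameter (weights only, no activations or KV cache).
bf16_megabytes = total_params * 2 / 1e6

print(f"head_dim={head_dim}, queries per KV head={queries_per_kv_head}")
print(f"k_proj/v_proj output width={kv_proj_width}")
print(f"bf16 weights ~ {bf16_megabytes:.0f} MB")
```

Under these assumptions each KV head serves three query heads, and the bf16 weight footprint comes out at roughly 228 MB.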

Usage

This repo loads through transformers.AutoModelForCausalLM via our AttnResLMForCausalLM wrapper (src/model/hf_wrapper.py). Importing the wrapper runs its AutoConfig.register("attnres", ...) call, and all of our training and inference scripts import it, so inside this codebase a plain from_pretrained(repo_id) call works without trust_remote_code=True. No modeling files are uploaded to the Hub. In a fresh Python session that has not imported the wrapper, import src.model.hf_wrapper once before loading; trust_remote_code=True only helps when the modeling code is shipped in the Hub repo, which is not the case here:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "oars344/attnres-phase1"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,    # 114M params -> ~228 MB in bf16
    device_map="auto",
)
model.eval()

prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32, do_sample=True, top_p=0.95)

print(tokenizer.decode(output[0], skip_special_tokens=True))
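
For the fresh-session case, one way to make the wrapper import explicit is a small guard that runs before from_pretrained. This is a minimal sketch: ensure_attnres_registered is a hypothetical helper, not part of the repo, and it assumes the repo root is on sys.path so that src.model.hf_wrapper is importable.

```python
import importlib


def ensure_attnres_registered(module_name: str = "src.model.hf_wrapper") -> bool:
    """Import the wrapper module so its Auto* registration side effects run.

    Hypothetical helper. Returns True if the import succeeded, False if the
    module is not importable (for example, the repo root is not on sys.path).
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


if not ensure_attnres_registered():
    print("AttnRes wrapper not importable; run from the repo root or add it to sys.path.")
```

Calling the guard at the top of a notebook or script keeps the failure mode readable: a clear message instead of an unrecognized-model-type error from from_pretrained.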

Inspecting the raw AttnResLM

The HF wrapper stores the underlying AttnResLM under model.model. The learned pseudo-query projections inside each block's BlockAttnRes live at model.model.layers[i].{attn_res,mlp_res}.proj:

# Probing the learned residual aggregation weights:
projection = model.model.layers[0].attn_res.proj.weight  # [1, hidden_size]
print(f"Layer-0 attn_res pseudo-query shape: {tuple(projection.shape)}")
# `mlp_res.proj` is the corresponding projection for the MLP-side residual.
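
To make the role of the [1, hidden_size] projection concrete, here is a dependency-free sketch of softmax aggregation over preceding block representations for a single token position. It is an illustration under assumptions, not the repo's BlockAttnRes: the function name aggregate_blocks is hypothetical, and any normalization or scaling the real module applies before scoring is omitted.

```python
import math


def aggregate_blocks(block_reprs, pseudo_query):
    """Mix preceding block representations for one token position.

    block_reprs: list of N vectors (lists of D floats), one per preceding block.
    pseudo_query: list of D floats, playing the role of the [1, D] proj weight.
    Returns (mixed_vector, weights).
    """
    # One scalar logit per block: dot product with the pseudo-query.
    logits = [sum(q * x for q, x in zip(pseudo_query, rep)) for rep in block_reprs]

    # Numerically stable softmax over the N blocks.
    peak = max(logits)
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    weights = [e / total for e in exps]

    # The weighted sum of block representations stands in for the plain additive residual.
    dim = len(pseudo_query)
    mixed = [sum(w * rep[j] for w, rep in zip(weights, block_reprs)) for j in range(dim)]
    return mixed, weights


blocks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
mixed, weights = aggregate_blocks(blocks, pseudo_query=[0.0, 0.0])
print(weights)  # a zero pseudo-query scores every block equally
```

Because the logits are computed from the block representations themselves, the mixing weights change with the input, which is what "input-dependent" refers to in the description at the top of this card.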

Phase 2 fine-tuning with LoRA / QLoRA

Adapters can be trained on top of this base via src/training/train_phase2.py. AttnRes-aware target modules include the seven standard Llama-style linears plus the BlockAttnRes pseudo-query projections (attn_res.proj, mlp_res.proj), so LoRA can re-route the residual stream for downstream tasks:

from peft import LoraConfig, get_peft_model

attnres_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "attn_res.proj", "mlp_res.proj",
]
model = get_peft_model(
    model,
    LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM",
               target_modules=attnres_targets),
)
model.print_trainable_parameters()
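
As a rough guide to what print_trainable_parameters() will report, the adapter size per targeted linear can be worked out by hand. The sketch below assumes standard LoRA (A is [r, in_features], B is [out_features, r], scaling lora_alpha / r), that q_proj maps hidden_size to hidden_size, and that k_proj has output width 256 (4 KV heads times an assumed head_dim of 64). The helper name is illustrative, and the MLP projections are left out because the intermediate size is not listed above.

```python
def lora_added_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is [r, in], B is [out, r]."""
    return r * (in_features + out_features)


r, lora_alpha = 16, 32
hidden_size, kv_width = 768, 256  # kv_width assumes 4 KV heads x head_dim 64

print("scaling (lora_alpha / r):", lora_alpha / r)
print("q_proj adapter:", lora_added_params(hidden_size, hidden_size, r))
print("k_proj adapter:", lora_added_params(hidden_size, kv_width, r))
print("attn_res.proj adapter:", lora_added_params(hidden_size, 1, r))
```

Under these assumptions, the adapter on a [1, hidden_size] pseudo-query projection (12,304 parameters at r=16) is larger than the 768 weights it adapts. Since a 1 x 768 update has rank at most 1, a smaller rank for those two targets may be worth trying.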