Gated_Linear_Attention2
This repository stores milestone checkpoints for gdn2_kla_1.3B from the
gdn2_kla_1.3B_fineweb_edu_10bt run.
What This Model Is
gdn2_kla_1.3B is a recurrent-only linear attention experiment. It starts from
Gated DeltaNet-2 and folds a Kaczmarz-style key-norm-normalized update step into
the separate erase and write gates:
$$
\lambda_t = \frac{\eta_t}{\|k_t\|_2^2 + \epsilon}
$$

$$
S_t = \left(I - k_t(\lambda_t b_t \odot k_t)^\top\right) D_t S_{t-1} + k_t(\lambda_t w_t \odot v_t)^\top
$$
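A single step of this recurrence can be sketched in plain Python to make the shapes concrete. This is only an illustration, not the runtime's kernel: the function name `kaczmarz_gated_step` is hypothetical, and it assumes a `d_k x d_v` state, a per-key-channel erase gate `b_t`, a per-value-channel write gate `w_t`, a diagonal decay `D_t` passed as a vector `d`, and a scalar `eta_t`. None of these shapes are stated in this card.

```python
# Illustrative single-step state update for the recurrence above.
# Assumed shapes (not taken from the training code):
#   S: d_k x d_v state as a list of rows; k, b, d: length d_k; v, w: length d_v.

def kaczmarz_gated_step(S, k, v, b, w, d, eta, eps=1e-6):
    """Return S_t given S_{t-1} and the inputs of step t."""
    d_k, d_v = len(k), len(v)
    lam = eta / (sum(x * x for x in k) + eps)  # lambda_t

    # D_t S_{t-1}: a diagonal decay scales each row of the state.
    decayed = [[d[i] * S[i][j] for j in range(d_v)] for i in range(d_k)]

    # (lambda_t * b_t * k_t)^T D_t S_{t-1}: what the erase term reads back.
    erase_key = [lam * b[i] * k[i] for i in range(d_k)]
    read = [sum(erase_key[i] * decayed[i][j] for i in range(d_k))
            for j in range(d_v)]

    # lambda_t * w_t * v_t: the gated value to write.
    write = [lam * w[j] * v[j] for j in range(d_v)]

    # S_t = D_t S_{t-1} + k_t (write - read)^T, i.e. the erase and write
    # outer products of the formula merged into one rank-1 update.
    return [[decayed[i][j] + k[i] * (write[j] - read[j]) for j in range(d_v)]
            for i in range(d_k)]


# With every gate open (b = w = 1, eta = 1, no decay, eps = 0) the step is a
# Kaczmarz projection: reading the new state with k_t recovers v_t.
k, v = [3.0, 4.0], [1.0, 2.0]
S_prev = [[1.0, 0.0], [0.0, 1.0]]
S_next = kaczmarz_gated_step(S_prev, k, v, b=[1.0, 1.0], w=[1.0, 1.0],
                             d=[1.0, 1.0], eta=1.0, eps=0.0)
readout = [sum(k[i] * S_next[i][j] for i in range(2)) for j in range(2)]
print(readout)  # approximately [1.0, 2.0]
```

Dividing by the squared key norm is what makes the fully open step an exact projection regardless of the scale of `k_t`. The learned gates then interpolate away from that exact update.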
It is not a standard Transformers checkpoint and does not use softmax attention or sliding-window attention (SWA) layers.
Code
Standalone inference runtime: https://github.com/gyunggyung/Gated_Linear_Attention2
License
The model weights in this Hugging Face repository are released under Apache-2.0.
The standalone inference runtime linked above is also Apache-2.0. It does not
import lit_gpt, fla, or the NVIDIA GatedDeltaNet-2 Triton kernels. The
training code used during experimentation may contain NVIDIA GatedDeltaNet-2
derived components under Nvidia Source Code License-NC, but this Hugging Face
model repository is intended to be used with the standalone Apache-2.0 runtime.
Training Setup
- Base architecture: recurrent-only GDN-2, 1.3B scale
- Candidate: Kaczmarz-normalized GDN-2 gates
- Training data source: FineWeb-Edu sample/100BT local parquet
- Token budget for this run: 10B
- Current milestone: 10,000,000,000 tokens
- Sequence length: 4096 tokens
- Global batch tokens: 1,048,576
- Tokenizer: TinyLlama/TinyLlama_v1.1
- Data shuffle seed: 3407
- Data shuffle buffer: 100000
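The figures in this list imply a step count and a per-step sequence count. The arithmetic below is a rough sketch that assumes every sequence is packed to the full 4096 tokens and every optimizer step consumes exactly one global batch. The actual training loop may differ.

```python
import math

token_budget = 10_000_000_000    # token budget for this run
global_batch_tokens = 1_048_576  # global batch tokens
sequence_length = 4096           # sequence length in tokens

# Sequences per optimizer step, assuming full-length packed sequences.
sequences_per_step = global_batch_tokens // sequence_length

# Steps needed to consume the budget, assuming one global batch per step.
steps_for_budget = math.ceil(token_budget / global_batch_tokens)

print(sequences_per_step)  # 256
print(steps_for_budget)    # 9537
```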
Checkpoint Format
Each checkpoints/checkpoint-XXB/ folder contains:
- model-ckpt.pth: PyTorch model-only checkpoint
- training_metadata.json: run metadata and model config
- README.md: this model card snapshot
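The folder layout above can be turned into paths with a small helper. This is a sketch only: `milestone_files` and `read_metadata` are hypothetical names, and the two-digit zero-padded folder name is an assumption inferred from the `checkpoint-XXB` pattern. The keys inside training_metadata.json are not documented here, so the sketch does not assume any.

```python
import json
from pathlib import Path


def milestone_files(tokens_in_billions, root="checkpoints"):
    """Build the expected file paths for one milestone folder."""
    # Assumption: folder names are zero-padded to two digits, e.g. checkpoint-01B.
    folder = Path(root) / f"checkpoint-{tokens_in_billions:02d}B"
    return {
        "weights": folder / "model-ckpt.pth",
        "metadata": folder / "training_metadata.json",
        "card": folder / "README.md",
    }


def read_metadata(path):
    """Load the run metadata as a plain dict without assuming its keys."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


files = milestone_files(1)
print(files["weights"].as_posix())
```

After downloading a milestone folder, `read_metadata(files["metadata"])` returns the run metadata and model config as a dictionary.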
These checkpoints cannot be loaded with transformers.AutoModelForCausalLM.from_pretrained.
How To Use
This is a causal language model: given a text prefix, it predicts the next token and can continue the text autoregressively. It was pretrained on FineWeb-Edu and is not instruction-tuned, RLHF-tuned, or chat-aligned.
The checkpoint is a PyTorch .pth checkpoint, not a
transformers.AutoModelForCausalLM checkpoint. Use the standalone runtime below
to load it.
Install and clone:
git clone https://github.com/gyunggyung/Gated_Linear_Attention2
cd Gated_Linear_Attention2
pip install -e .
Minimal text-generation example:
import torch

from gated_linear_attention2 import GatedLinearAttention2ForCausalLM, load_tokenizer
from gated_linear_attention2.generation import generate

repo_id = "gyung/Gated_Linear_Attention2"
checkpoint_file = "checkpoints/checkpoint-01B/model-ckpt.pth"

# This example stops early when no GPU is available.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is recommended for this 1.3B checkpoint; CPU will be very slow.")

device = "cuda"
dtype = torch.bfloat16

# Download the checkpoint from the Hub and build the model on the GPU.
model = GatedLinearAttention2ForCausalLM.from_hf(
    repo_id=repo_id,
    checkpoint=checkpoint_file,
    device=device,
    dtype=dtype,
)
tokenizer = load_tokenizer(repo_id, subfolder="tokenizer")

prompt = "Artificial intelligence can help education by"
print(generate(model, tokenizer, prompt, max_new_tokens=80, temperature=0.8, top_k=50))
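The `temperature` and `top_k` arguments in the call above control how each new token is drawn. A common way to combine them is sketched below in plain Python. `sample_top_k` is a hypothetical helper written for this card. It is an assumption that the runtime's `generate` behaves similarly, so check the repository for the actual logic.

```python
import math
import random


def sample_top_k(logits, temperature=0.8, top_k=50, rng=random):
    """Sample one token id from raw logits (illustrative, not the runtime's code)."""
    # Keep only the top_k highest-scoring token ids.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = ranked[:top_k]

    # Temperature-scaled softmax weights over the kept ids.
    # Subtracting the maximum keeps exp() numerically stable.
    scaled = [logits[i] / temperature for i in kept]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]

    # random.choices normalizes the weights internally.
    return rng.choices(kept, weights=weights, k=1)[0]


toy_logits = [0.1, 2.5, -1.0, 0.7]
print(sample_top_k(toy_logits, top_k=1))  # 1, since top_k=1 reduces to argmax
```

Lower temperatures concentrate probability on the highest logits, and a smaller `top_k` removes the low-probability tail before sampling.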
For next-token scoring instead of generation, run one forward pass and inspect the final-position logits:
prompt = "The capital of France is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Score the prompt once and keep only the logits at the last position.
with torch.no_grad():
    logits = model(input_ids)[:, -1, :]

next_token_id = int(torch.argmax(logits, dim=-1)[0])
print(tokenizer.decode([next_token_id]))
The standalone runtime keeps a fixed-size recurrent state cache during generation, so, unlike a Transformer KV cache, decode memory does not grow with the number of generated tokens.
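The difference in scaling can be illustrated with a back-of-the-envelope estimate. The layer count, head count, and head dimensions below are hypothetical placeholders, not this model's configuration, and a real runtime may keep additional small buffers per layer.

```python
def kv_cache_bytes(tokens, layers, kv_heads, head_dim, bytes_per_value=2):
    # Keys and values are both stored for every token seen so far.
    return 2 * layers * tokens * kv_heads * head_dim * bytes_per_value


def recurrent_state_bytes(layers, heads, key_dim, value_dim, bytes_per_value=2):
    # One key_dim x value_dim state matrix per head, independent of length.
    return layers * heads * key_dim * value_dim * bytes_per_value


# Hypothetical shapes for illustration only (bf16, so 2 bytes per value).
LAYERS, HEADS, DIM = 24, 16, 128

for tokens in (1_024, 4_096, 16_384):
    kv = kv_cache_bytes(tokens, LAYERS, HEADS, DIM)
    rec = recurrent_state_bytes(LAYERS, HEADS, DIM, DIM)
    print(f"{tokens:>6} tokens: KV cache {kv:>13,} B, recurrent state {rec:>11,} B")
```

Under these assumptions the KV cache grows linearly with context length while the recurrent state stays the same size at every step.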
Evaluation Plan
Compare against the plain gdn2_1.3B baseline on the GDN-2 paper tasks:
- WikiText and LAMBADA perplexity
- LAMBADA and commonsense zero-shot accuracy
- RULER S-NIAH and MK-NIAH
- Real-world retrieval tasks: SWDE, SQuAD, FDA, TriviaQA, NQ, DROP
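For the perplexity entries, the usual definition is the exponential of the mean per-token negative log-likelihood. The sketch below shows that definition on toy inputs. The function name `perplexity` and the use of natural-log probabilities are assumptions, and the evaluation harness used for the paper tasks may normalize differently (for example per word rather than per token).

```python
import math


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) per token."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# A model that assigns probability 1/4 to every observed token has a
# perplexity of about 4, i.e. it is as uncertain as a uniform 4-way choice.
uniform_log_probs = [math.log(0.25)] * 10
print(perplexity(uniform_log_probs))
```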
The 10B run is an ablation, not a claim that it replaces the published 100B GDN-2 model.