Gated_Linear_Attention2

This repository stores milestone checkpoints for gdn2_kla_1.3B from the gdn2_kla_1.3B_fineweb_edu_10bt run.

What This Model Is

gdn2_kla_1.3B is a recurrent-only linear attention experiment. It starts from Gated DeltaNet-2 and folds a Kaczmarz-style update step, normalized by the squared key norm, into the separate erase and write gates:

\lambda_t = \frac{\eta_t}{\|k_t\|_2^2 + \epsilon}

S_t = \left(I - k_t(\lambda_t b_t \odot k_t)^\top\right) D_t S_{t-1} + k_t(\lambda_t w_t \odot v_t)^\top
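The update rule above can be sketched in plain Python for a single head and a single time step. This is an illustration only, not the runtime's kernel. The shapes are assumptions that the model card does not state: S is a d_k x d_v matrix, k_t and the erase gate b_t have length d_k, v_t and the write gate w_t have length d_v, D_t is diagonal with entries d, and \eta_t is a scalar. The names kla_step and read_out are hypothetical.

```python
def kla_step(S_prev, k, v, b, w, d, eta, eps=1e-6):
    """One state update S_{t-1} -> S_t following the formula above.

    Assumed shapes (not stated in the model card): S_prev is d_k x d_v,
    k, b, d have length d_k, v and w have length d_v, eta is a scalar,
    and D_t is a diagonal matrix whose diagonal is d.
    """
    d_k, d_v = len(k), len(v)
    # lambda_t = eta_t / (||k_t||^2 + eps)
    lam = eta / (sum(x * x for x in k) + eps)
    # Decay: D_t S_{t-1} scales row i of the state by d[i].
    S = [[d[i] * S_prev[i][j] for j in range(d_v)] for i in range(d_k)]
    # Erase: (I - k e^T) S with e = lam * b ⊙ k, computed as S - k (e^T S).
    e = [lam * b[i] * k[i] for i in range(d_k)]
    r = [sum(e[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]
    # Write: + k u^T with u = lam * w ⊙ v.
    u = [lam * w[j] * v[j] for j in range(d_v)]
    return [[S[i][j] - k[i] * r[j] + k[i] * u[j] for j in range(d_v)]
            for i in range(d_k)]


def read_out(S, q):
    """Return q^T S, the value the state associates with query q."""
    return [sum(q[i] * S[i][j] for i in range(len(S))) for j in range(len(S[0]))]


# Toy check with all gates open (b = w = 1), eta = 1 and eps = 0.
S0 = [[0.5, -1.0], [2.0, 0.25], [1.0, 1.0]]
k = [1.0, 2.0, -1.0]
v = [3.0, -4.0]
S1 = kla_step(S0, k, v, b=[1.0] * 3, w=[1.0] * 2, d=[1.0] * 3, eta=1.0, eps=0.0)
print(read_out(S1, k))  # approximately v, up to floating-point rounding
```

With fully open gates and a unit step size, reading the new state with k_t returns v_t regardless of the previous state. That is the Kaczmarz-style projection the normalization by \|k_t\|_2^2 provides.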

It is not a standard Transformers checkpoint and does not use softmax attention or SWA layers.

Code

Standalone inference runtime: https://github.com/gyunggyung/Gated_Linear_Attention2

License

The model weights in this Hugging Face repository are released under Apache-2.0.

The standalone inference runtime linked above is also Apache-2.0. It does not import lit_gpt, fla, or the NVIDIA GatedDeltaNet-2 Triton kernels. The training code used during experimentation may contain components derived from NVIDIA GatedDeltaNet-2 under the NVIDIA Source Code License-NC, but this Hugging Face model repository is intended to be used with the standalone Apache-2.0 runtime.

Training Setup

  • Base architecture: recurrent-only GDN-2, 1.3B scale
  • Candidate: Kaczmarz-normalized GDN-2 gates
  • Training data source: FineWeb-Edu, sample/100BT subset, read from local parquet files
  • Token budget for this run: 10B
  • Current milestone: 10,000,000,000 tokens
  • Sequence length: 4096 tokens
  • Global batch tokens: 1,048,576
  • Tokenizer: TinyLlama/TinyLlama_v1.1
  • Data shuffle seed: 3407
  • Data shuffle buffer: 100000
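The batch and budget figures above imply a step count that can be worked out directly. The sketch below assumes every optimizer step consumes exactly one global batch of full-length sequences; the card does not state this, and the variable names are illustrative.

```python
import math

SEQ_LEN = 4096                   # tokens per sequence (from the list above)
GLOBAL_BATCH_TOKENS = 1_048_576  # tokens per optimizer step (from the list above)
TOKEN_BUDGET = 10_000_000_000    # token budget for this run (from the list above)

# Assumption: each step processes whole sequences of exactly SEQ_LEN tokens.
sequences_per_step = GLOBAL_BATCH_TOKENS // SEQ_LEN

# Steps that fit entirely inside the budget, and steps needed to cover it.
full_steps = TOKEN_BUDGET // GLOBAL_BATCH_TOKENS
steps_to_cover_budget = math.ceil(TOKEN_BUDGET / GLOBAL_BATCH_TOKENS)

print(sequences_per_step, full_steps, steps_to_cover_budget)
```

Under these assumptions a global batch holds 256 sequences, and the 10B budget corresponds to roughly 9.5 thousand optimizer steps.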

Checkpoint Format

Each checkpoints/checkpoint-XXB/ folder contains:

  • model-ckpt.pth: PyTorch model-only checkpoint
  • training_metadata.json: run metadata and model config
  • README.md: this model card snapshot

These checkpoints cannot be loaded with transformers.AutoModelForCausalLM.from_pretrained.
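Before loading a downloaded checkpoint folder, it can help to confirm that it has the layout listed above. The helpers below are a hypothetical sketch using only the standard library. The function names are not part of the runtime, and the schema of training_metadata.json is not documented here, so the file is returned as a plain dictionary without interpreting its keys.

```python
import json
from pathlib import Path

# File names taken from the list above.
EXPECTED_FILES = ("model-ckpt.pth", "training_metadata.json", "README.md")


def missing_checkpoint_files(folder):
    """Return the expected file names that are absent from `folder`."""
    folder = Path(folder)
    return [name for name in EXPECTED_FILES if not (folder / name).is_file()]


def load_training_metadata(folder):
    """Parse training_metadata.json and return it as a dict (schema not assumed)."""
    with open(Path(folder) / "training_metadata.json", encoding="utf-8") as f:
        return json.load(f)


# Example: report what is missing from a local copy of one milestone folder.
print(missing_checkpoint_files("checkpoints/checkpoint-01B"))
```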

How To Use

This is a causal language model: given a text prefix, it predicts the next token and can continue the text autoregressively. It was pretrained on FineWeb-Edu and is not instruction-tuned, RLHF-tuned, or chat-aligned.

The checkpoint is a PyTorch .pth checkpoint, not a transformers.AutoModelForCausalLM checkpoint. Use the standalone runtime below to load it.

Clone and install:

git clone https://github.com/gyunggyung/Gated_Linear_Attention2
cd Gated_Linear_Attention2
pip install -e .

Minimal text-generation example:

import torch

from gated_linear_attention2 import GatedLinearAttention2ForCausalLM, load_tokenizer
from gated_linear_attention2.generation import generate

repo_id = "gyung/Gated_Linear_Attention2"
checkpoint_file = "checkpoints/checkpoint-01B/model-ckpt.pth"

if not torch.cuda.is_available():
    raise RuntimeError("This example assumes a CUDA device; CPU inference for this 1.3B checkpoint would be very slow.")

device = "cuda"
dtype = torch.bfloat16

model = GatedLinearAttention2ForCausalLM.from_hf(
    repo_id=repo_id,
    checkpoint=checkpoint_file,
    device=device,
    dtype=dtype,
)
tokenizer = load_tokenizer(repo_id, subfolder="tokenizer")

prompt = "Artificial intelligence can help education by"
print(generate(model, tokenizer, prompt, max_new_tokens=80, temperature=0.8, top_k=50))
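The temperature and top_k arguments in the call above conventionally mean temperature-scaled top-k sampling. The sketch below shows that standard technique on a plain list of logits. It is not the runtime's generate implementation, which may differ in detail, and sample_next_token is a hypothetical name.

```python
import math
import random


def sample_next_token(logits, temperature=0.8, top_k=50, rng=random):
    """Sample one token id using temperature-scaled top-k sampling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Keep only the indices of the top_k largest logits.
    k = min(top_k, len(logits))
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Divide by the temperature: values below 1 sharpen the distribution.
    scaled = [logits[i] / temperature for i in top]
    # Softmax weights, shifted by the max for numerical stability.
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights internally.
    return rng.choices(top, weights=weights, k=1)[0]


toy_logits = [0.1, 2.5, -1.0, 2.4]
print(sample_next_token(toy_logits, temperature=0.8, top_k=2))
```

With top_k=1 the function reduces to greedy decoding, since only the highest-scoring token remains eligible.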

For next-token scoring instead of generation, run one forward pass and inspect the final-position logits:

prompt = "The capital of France is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    logits = model(input_ids)[:, -1, :]
next_token_id = int(torch.argmax(logits, dim=-1)[0])
print(tokenizer.decode([next_token_id]))

The standalone runtime uses a recurrent state cache during generation, so decode memory stays constant as more tokens are generated, unlike a Transformer KV cache, which grows with sequence length.
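The difference in decode memory can be made concrete with a back-of-the-envelope calculation. The layer count, head count, and head dimensions below are hypothetical placeholders, not this model's configuration (see training_metadata.json for the real values). The formulas are the usual cache-size estimates and ignore any smaller auxiliary states a real implementation may keep.

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """Transformer KV cache: one key and one value vector per token, head and layer."""
    return 2 * num_layers * num_kv_heads * head_dim * num_tokens * bytes_per_elem


def recurrent_state_bytes(num_layers, num_heads, key_dim, value_dim, bytes_per_elem=2):
    """Recurrent cache: one key_dim x value_dim state matrix per head and layer."""
    return num_layers * num_heads * key_dim * value_dim * bytes_per_elem


# Hypothetical dimensions for illustration only; bf16 uses 2 bytes per element.
LAYERS, HEADS, DIM = 24, 16, 128

for n in (4096, 8192, 16384):
    kv = kv_cache_bytes(n, LAYERS, HEADS, DIM)
    rec = recurrent_state_bytes(LAYERS, HEADS, DIM, DIM)
    print(f"{n:>6} tokens: KV cache {kv / 2**20:.0f} MiB, recurrent state {rec / 2**20:.0f} MiB")
```

The KV cache estimate doubles each time the token count doubles, while the recurrent state estimate does not depend on the token count at all.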

Evaluation Plan

The plan is to compare against the plain gdn2_1.3B baseline on the GDN-2 paper tasks:

  • WikiText and LAMBADA perplexity
  • LAMBADA and commonsense zero-shot accuracy
  • RULER S-NIAH and MK-NIAH
  • Real-world retrieval tasks: SWDE, SQuAD, FDA, TriviaQA, NQ, DROP
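For the perplexity items in the list above, the standard definition is the exponential of the mean negative log-likelihood per token. The sketch below assumes natural-log token probabilities and per-token normalization. Evaluation harnesses sometimes normalize per word instead, so numbers are comparable only when the same convention is used for both models. The function name is illustrative.

```python
import math


def perplexity(token_log_probs):
    """Perplexity = exp(-mean log p(token | prefix)), using natural logs."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


# A model that assigns probability 1/4 to every observed token has a
# perplexity of about 4 (up to floating-point rounding).
print(perplexity([math.log(0.25)] * 8))
```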

The 10B run is an ablation, not a claim that it replaces the published 100B GDN-2 model.
