MemoryDecoder-GPT2-Small

Model Description

Memory Decoder is a pretrained, plug-and-play memory component designed for efficient domain adaptation of large language models. This checkpoint contains the GPT2-small Memory Decoder trained on WikiText-103, as described in our NeurIPS 2025 paper.

Overview

Memory Decoder bridges the gap between non-parametric retrieval methods and parametric fine-tuning approaches. By pre-training a compact transformer decoder to internalize retrieval patterns, it provides:

  • Plug-and-Play Integration: Works with any GPT2 model variant without modifying original parameters
  • Efficient Inference: No retrieval overhead; the base model and the Memory Decoder run as parallel forward passes
  • Domain Expertise: Captures long-tail knowledge like kNN-LM but with parametric efficiency
  • Preserved Capabilities: Original model remains unchanged

Quick Start

Step 1: Import Libraries and Initialize Models

from memDec import MemoryDecoder
import transformers
from transformers import AutoModelForCausalLM
from loguru import logger

# Define paths to your models
base_lm_path = "gpt2-xl"  # or any GPT2 variant
knn_generator_path = "Clover-Hill/MemoryDecoder-gpt2-small"

# Load tokenizer and models
tokenizer = transformers.AutoTokenizer.from_pretrained(base_lm_path)
base_lm = AutoModelForCausalLM.from_pretrained(base_lm_path)
knn_generator = AutoModelForCausalLM.from_pretrained(knn_generator_path)

Step 2: Prepare Models and Create Joint Model

# Set both models to evaluation mode (disables dropout)
base_lm.eval()
knn_generator.eval()

# Create the joint Memory Decoder model
joint = MemoryDecoder(base_lm, knn_generator, lmbda=0.55, knn_temp=1.0).to("cuda")
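
The joint model is configured only through lmbda and knn_temp. Below is a minimal sketch of per-token mixing. It assumes kNN-LM-style interpolation: lmbda weights the Memory Decoder's next-token distribution, and knn_temp divides only the Memory Decoder's logits. The helper names are hypothetical, and MemoryDecoder's actual internals may differ.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax; logits are divided by `temperature` first."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def joint_next_token_probs(base_logits: List[float], mem_logits: List[float],
                           lmbda: float = 0.55, knn_temp: float = 1.0) -> List[float]:
    """Assumed mixing rule: lmbda * p_mem + (1 - lmbda) * p_base.

    Both models must share one vocabulary (true for every GPT2 variant).
    Assumption: knn_temp rescales only the Memory Decoder's logits.
    """
    if len(base_logits) != len(mem_logits):
        raise ValueError("base model and Memory Decoder must share a vocabulary")
    p_base = softmax(base_logits)
    p_mem = softmax(mem_logits, temperature=knn_temp)
    return [lmbda * pm + (1.0 - lmbda) * pb for pm, pb in zip(p_mem, p_base)]


# Toy 3-token vocabulary: ["strategy", "role-playing", "puzzle"] (illustrative only)
base_logits = [2.0, 1.0, 0.0]  # base model favours "strategy"
mem_logits = [0.0, 3.0, 0.0]   # Memory Decoder favours "role-playing"
probs = joint_next_token_probs(base_logits, mem_logits)
best = max(range(len(probs)), key=probs.__getitem__)
print([round(p, 3) for p in probs], "-> greedy pick:", best)
```

In this toy case the base model's top token is index 0, but with lmbda=0.55 the mixture's top token becomes index 1. Setting lmbda=0.0 recovers the base model's distribution exactly, which matches the plug-and-play claim that the original model is left unchanged.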

Step 3: Generate Text and Compare Results

# Prepare input prompt
prompt = "As with previous Valkyira Chronicles games , Valkyria Chronicles III is"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Generate with Memory Decoder
out_ids = joint.generate(**inputs, max_new_tokens=20, do_sample=False)
logger.info(f"Memory Decoder output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")

# Generate with base model for comparison
out_ids = base_lm.generate(**inputs, max_new_tokens=20, do_sample=False)
logger.info(f"Base Model output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")

📊 Generation Results Comparison:

Model | Generated Continuation
--- | ---
Base Model | "...is a turn-based strategy game. The player takes control of a squad of Valkyria soldiers..."
+Memory Decoder | "...is a role-playing video game developed by Sega and published by Sega for the PlayStation 2."

Memory Decoder correctly identifies the genre of Valkyria Chronicles III (a role-playing game), while the base model incorrectly predicts a strategy game.
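
The greedy comparison (do_sample=False) can be illustrated with a toy decoder. This sketch assumes that joint generation performs greedy decoding over the interpolated distribution at each step. The two "models" are hypothetical functions that return fixed logits and are not GPT2 outputs. The point is that the memory term can change which token wins at each step, and that change then alters every later step.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical interface: token ids so far -> next-token logits over a shared vocabulary.
NextTokenLogits = Callable[[Sequence[int]], List[float]]


def _softmax(logits: List[float]) -> List[float]:
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_generate(base: NextTokenLogits, memory: NextTokenLogits,
                    prompt: Sequence[int], max_new_tokens: int,
                    lmbda: float = 0.55) -> List[int]:
    """Greedy decoding over the interpolated distribution; returns prompt + new ids."""
    ids = list(prompt)
    for _ in range(max_new_tokens):
        # Both models score the same prefix: two independent forward passes, no retrieval.
        p_base = _softmax(base(ids))
        p_mem = _softmax(memory(ids))
        joint = [lmbda * m + (1.0 - lmbda) * b for m, b in zip(p_mem, p_base)]
        ids.append(max(range(len(joint)), key=joint.__getitem__))  # argmax, no sampling
    return ids


# Toy stand-ins over a 4-token vocabulary (illustrative, not GPT2 outputs).
def toy_base(ids: Sequence[int]) -> List[float]:
    return [0.0, 2.0, 1.0, 0.0]  # always prefers token 1


def toy_memory(ids: Sequence[int]) -> List[float]:
    return [0.0, 0.0, 4.0, 0.0] if len(ids) < 3 else [3.0, 0.0, 0.0, 0.0]


print("base only  :", greedy_generate(toy_base, toy_memory, [3], 3, lmbda=0.0))
print("with memory:", greedy_generate(toy_base, toy_memory, [3], 3))
```

Like a decoder-only `generate` call, the sketch returns the prompt ids followed by the new ids. This is why the decoded outputs in Step 3 begin with the prompt text.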

Performance on WikiText-103

Model Configuration | Perplexity (lower is better) | Improvement (perplexity points)
--- | --- | ---
GPT2-small (baseline) | 24.89 | -
GPT2-small + MemoryDecoder | 13.36 | -11.53
GPT2-medium (baseline) | 18.29 | -
GPT2-medium + MemoryDecoder | 12.25 | -6.04
GPT2-large (baseline) | 15.80 | -
GPT2-large + MemoryDecoder | 11.53 | -4.27
GPT2-xl (baseline) | 14.39 | -
GPT2-xl + MemoryDecoder | 10.93 | -3.46
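
Perplexity is the exponential of the mean negative log-likelihood per target token. The Improvement column is the absolute change in perplexity relative to each baseline. The sketch below recomputes that column from the reported numbers and also gives the relative reduction. The `perplexity` helper is the generic textbook definition and is not taken from the Memory Decoder codebase.

```python
import math
from typing import Dict, List, Tuple


def perplexity(token_logprobs: List[float]) -> float:
    """exp of the mean negative natural-log probability assigned to each target token."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


# (baseline, + MemoryDecoder) perplexities as reported in the WikiText-103 table.
RESULTS: Dict[str, Tuple[float, float]] = {
    "GPT2-small": (24.89, 13.36),
    "GPT2-medium": (18.29, 12.25),
    "GPT2-large": (15.80, 11.53),
    "GPT2-xl": (14.39, 10.93),
}

for name, (base_ppl, joint_ppl) in RESULTS.items():
    delta = joint_ppl - base_ppl          # the "Improvement" column
    relative = -delta / base_ppl * 100.0  # percentage of baseline perplexity removed
    print(f"{name:<12} {delta:+.2f} ({relative:.1f}% lower)")

# Sanity check of the definition: p = 0.5 on every target token gives perplexity 2.
print(perplexity([math.log(0.5)] * 4))
```

In relative terms, the reduction shrinks from about 46% for GPT2-small to about 24% for GPT2-xl. Even so, the smallest gain still comes from a memory that is roughly a twelfth the size of the base model.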

Key Features

  • Universal Compatibility: Works with all GPT2 model sizes (small, medium, large, xl)
  • Parameter Efficient: A single 124M-parameter Memory Decoder improves base models of up to 1.5B parameters (GPT2-xl)
  • Domain Adaptation: Trained to capture WikiText-103 domain knowledge
  • Inference Speed: Lower overhead than retrieval-based methods such as kNN-LM, since no datastore search is needed at inference time
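
To put the parameter-efficiency claim in numbers, the sketch below divides the Memory Decoder's 124M parameters by the approximate sizes of the public GPT-2 checkpoints (about 124M, 355M, 774M, and 1558M). These sizes are the commonly cited rounded counts and are not measured from this repository.

```python
# Approximate parameter counts of the public GPT-2 checkpoints, in millions.
GPT2_PARAMS_M = {"small": 124, "medium": 355, "large": 774, "xl": 1558}
MEMORY_DECODER_M = 124  # this checkpoint uses the GPT2-small architecture


def overhead_pct(base_params_m: float, extra_params_m: float = MEMORY_DECODER_M) -> float:
    """Extra parameters as a percentage of the base model's own parameter count."""
    return 100.0 * extra_params_m / base_params_m


for size, n_params in GPT2_PARAMS_M.items():
    print(f"GPT2-{size:<6} +{MEMORY_DECODER_M}M params = {overhead_pct(n_params):5.1f}% extra")
```

The relative cost falls from doubling the parameter count for GPT2-small to roughly 8% extra for GPT2-xl.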

Training Details

  • Training Data: WikiText-103
  • Training Objective: Hybrid KL divergence and language modeling loss
  • Supervision Signal: kNN distributions computed with GPT2-xl; a fine-tuned version of GPT2-xl is recommended for this step
  • Hyperparameters:
    • Learning rate: 1e-3
    • Beta (loss balance): 0.5
    • Training epochs: 70
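
The supervision signal and the hybrid objective can be sketched as follows. The sketch makes three assumptions. First, it uses the kNN-LM formulation: a softmax over negative neighbor distances, with the weights summed per next token. Second, the objective is taken to be beta * KL(p_kNN || p_mem) + (1 - beta) * cross-entropy on the gold token, with beta playing the "loss balance" role listed above. Third, the function names and toy neighbors are hypothetical. The actual training code may differ in details such as the distance metric, the temperature, and the direction of the KL term.

```python
import math
from typing import List, Tuple


def knn_distribution(neighbors: List[Tuple[float, int]], vocab_size: int,
                     temperature: float = 1.0) -> List[float]:
    """kNN-LM-style target: softmax over negative distances, summed per next token.

    `neighbors` holds (distance, next_token_id) pairs retrieved from a datastore
    built with the supervising model (GPT2-xl in this card).
    """
    d_min = min(d for d, _ in neighbors)  # shift for numerical stability
    weights = [math.exp(-(d - d_min) / temperature) for d, _ in neighbors]
    total = sum(weights)
    probs = [0.0] * vocab_size
    for w, (_, tok) in zip(weights, neighbors):
        probs[tok] += w / total
    return probs


def hybrid_loss(p_knn: List[float], p_mem: List[float], target: int,
                beta: float = 0.5, eps: float = 1e-12) -> float:
    """Assumed objective: beta * KL(p_knn || p_mem) + (1 - beta) * cross-entropy."""
    kl = sum(pk * (math.log(pk) - math.log(max(pm, eps)))
             for pk, pm in zip(p_knn, p_mem) if pk > 0.0)
    lm_loss = -math.log(max(p_mem[target], eps))
    return beta * kl + (1.0 - beta) * lm_loss


# Toy example over a 3-token vocabulary (hypothetical neighbors and predictions).
neighbors = [(1.0, 2), (1.5, 2), (4.0, 0)]
p_knn = knn_distribution(neighbors, vocab_size=3)
p_mem = [0.1, 0.1, 0.8]  # Memory Decoder's current prediction
print([round(p, 3) for p in p_knn], round(hybrid_loss(p_knn, p_mem, target=2), 4))
```

The KL term pushes the decoder toward the retrieval distribution, including long-tail tokens that the gold label alone would not reward. The cross-entropy term keeps the decoder anchored to the actual next token.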

Citation

@article{cao2025memory,
  title={Memory decoder: A pretrained, plug-and-play memory for large language models},
  author={Cao, Jiaqi and Wang, Jiarui and Wei, Rubin and Guo, Qipeng and Chen, Kai and Zhou, Bowen and Lin, Zhouhan},
  journal={arXiv preprint arXiv:2508.09874},
  year={2025}
}

Contact

For questions and support: maximus.cao@outlook.com

Checkpoint Format

Safetensors, 124M params, tensor type F32