---
tags:
  - energy-based-model
  - guided-decoding
  - constraint-satisfaction
  - jax
  - carnot
license: apache-2.0
---

Research Artifact: Not Production-Ready

Real-model validation is pending (Exp-111). Exp-110 results use a mock LLM with deterministic error injection. The constraint checker works correctly (p50 0.006 ms/check on CPU), but the guidance logic has not yet been validated on live models.

guided-decoding-adapter

Energy-guided decoding adapter for any HuggingFace causal LM.

Attaches Carnot's constraint energy pipeline to the token generation loop. At each token step, the text generated so far is checked for constraint violations; if any are found, alpha × violation_count is subtracted from every logit before sampling.

How It Works

prompt → encode → [forward pass → check constraints → penalise logits → sample] × N → text
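The loop above can be sketched in plain Python. This is not Carnot's implementation. `forward` and `count_violations` are hypothetical stand-ins for the model's forward pass and for AutoExtractor. Tokens are single characters, and the penalty follows the uniform-subtraction description.

```python
import math
import random
from typing import Callable, Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def guided_loop(
    prompt: str,
    vocab: Sequence[str],
    forward: Callable[[str], list[float]],   # hypothetical: text -> next-token logits
    count_violations: Callable[[str], int],  # hypothetical: generated text -> violation count
    n_tokens: int,
    alpha: float = 0.5,
    seed: int = 0,
) -> dict:
    rng = random.Random(seed)
    generated = ""
    penalties: list[float] = []
    energy = 0
    for _ in range(n_tokens):
        logits = forward(prompt + generated)       # forward pass on the full sequence
        energy = count_violations(generated)       # check constraints on the text so far
        penalty = alpha * energy
        logits = [x - penalty for x in logits]     # subtract alpha * violations from every logit
        probs = softmax(logits)
        generated += rng.choices(vocab, weights=probs, k=1)[0]  # sample one token
        penalties.append(penalty)
    return {
        "text": generated,
        "tokens_generated": n_tokens,
        "energy_checks": n_tokens,  # one check per step in this sketch
        "mean_penalty": sum(penalties) / n_tokens if n_tokens else 0.0,
        "final_energy": energy,
    }


# Toy usage: a context-free "model" and a rule that treats every "1" as a violation.
result = guided_loop(
    prompt="1+1=",
    vocab=["1", "2", "+", "="],
    forward=lambda text: [1.0, 2.0, 0.5, 0.5],
    count_violations=lambda gen: gen.count("1"),
    n_tokens=5,
)
print(result)
```

Because the whole sequence is re-encoded on every step, this sketch also mirrors the no-KV-cache behaviour listed under Limitations.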

The constraint checker (AutoExtractor) detects violations across four domains:

| Domain | Constraint types |
|---|---|
| Arithmetic | addition, multiplication, bounds |
| Code | type checks, return types, initialisation |
| Logic | implication, exclusion, disjunction, negation, universal |
| Natural language | NL consistency |
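As an illustration of the arithmetic domain, the sketch below turns "a + b = c" and "a x b = c" claims into a violation count. It is not AutoExtractor. The regex and `arithmetic_violations` are hypothetical. The only borrowed detail is the documented 5-character minimum-text guard.

```python
import re

# Hypothetical pattern for "a + b = c", "a * b = c" or "a x b = c" claims.
CLAIM = re.compile(r"(\d+)\s*([+*x])\s*(\d+)\s*=\s*(\d+)")


def arithmetic_violations(text: str, min_chars: int = 5) -> int:
    """Count arithmetic claims in text whose stated result is wrong."""
    if len(text) < min_chars:  # mirrors the documented min-text guard
        return 0
    violations = 0
    for a, op, b, claimed in CLAIM.findall(text):
        a, b, claimed = int(a), int(b), int(claimed)
        expected = a + b if op == "+" else a * b
        if claimed != expected:
            violations += 1
    return violations


print(arithmetic_violations("47 + 28 = 75"))                 # 0
print(arithmetic_violations("47 + 28 = 74, so 3 x 4 = 12"))  # 1
```

On partial text, an unfinished answer such as "47 + 28 = 7" is flagged before the model has written the last digit. This sketch does not attempt to handle that case.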

Energy is a plain violation count (not a calibrated probability). The penalty is applied uniformly across the vocabulary: token ranking is preserved while overall entropy increases, discouraging the model from continuing down a constraint-violating path.
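Softmax is invariant to adding the same constant to every logit. On its own, therefore, a uniform subtraction leaves the sampling distribution and its entropy unchanged. The ranking-preserving entropy increase described above is what a temperature-style rescaling of the logits produces. The sketch contrasts the two. The scaling form `logits / (1 + alpha * violations)` is a hypothetical illustration, not the adapter's documented formula.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [3.0, 1.0, 0.0, -1.0]
alpha, violations = 0.5, 4
penalty = alpha * violations  # 2.0

shifted = [x - penalty for x in logits]         # uniform subtraction, as described
scaled = [x / (1.0 + penalty) for x in logits]  # hypothetical temperature-style alternative

p_orig, p_shift, p_scale = softmax(logits), softmax(shifted), softmax(scaled)

print(max(abs(a - b) for a, b in zip(p_orig, p_shift)))  # 0.0: the shift cancels in the softmax
print(entropy(p_orig) < entropy(p_scale))                # True: scaling flattens the distribution
print(sorted(range(4), key=lambda i: -p_scale[i]))       # [0, 1, 2, 3]: ranking preserved
```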

Latency Profile

From Exp-102 (CPU, JAX_PLATFORMS=cpu, 1000-iteration benchmark):

| Measurement | Value |
|---|---|
| Constraint check p50 | 0.006 ms |
| Constraint check p99 | 0.034 ms |
| Extraction p50 | 0.276 ms |
| Per-token budget fraction | 0.04% of 20 ms/token |
| Verdict | Fits in real-time generation budget |
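The budget row is check latency divided by per-token generation time. The quick sketch below applies that to the table's figures and the stated 20 ms/token. The p50 alone works out to about 0.03%; the statistic behind the table's 0.04% is not stated. Even extraction p50 stays under 1.5% of the budget.

```python
TOKEN_BUDGET_MS = 20.0  # per-token generation time used in the table


def budget_fraction(cost_ms: float, token_ms: float = TOKEN_BUDGET_MS) -> float:
    """Share of the per-token time budget consumed, in percent."""
    return 100.0 * cost_ms / token_ms


for name, ms in [("check p50", 0.006), ("check p99", 0.034), ("extraction p50", 0.276)]:
    print(f"{name}: {budget_fraction(ms):.2f}% of {TOKEN_BUDGET_MS} ms")
```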

Usage

```python
from carnot.inference.guided_decoding import GuidedDecoder
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model (any HF causal LM)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-0.8B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-0.8B")
model.eval()

# Load adapter from local directory or HuggingFace Hub
decoder = GuidedDecoder.from_pretrained("Carnot-EBM/guided-decoding-adapter")

# Generate with constraint guidance
result = decoder.generate(model, tokenizer, "What is 47 + 28?")
print(result.text)
print(f"Energy checks: {result.energy_checks}, final energy: {result.final_energy}")
```

Override defaults

```python
decoder = GuidedDecoder.from_pretrained(
    "Carnot-EBM/guided-decoding-adapter",
    alpha=1.0,             # stronger guidance
    check_every_k=5,       # check every 5 tokens (faster, less precise)
    energy_threshold=0.5,  # only penalise when violations > 0.5
)
```
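The two override knobs can be read as a check schedule and a penalty gate. The sketch below is one hypothetical interpretation inferred from the parameter comments above. The schedule (steps 0, k, 2k, …), `should_check`, and `gated_penalty` are assumptions, not the adapter's code.

```python
def should_check(step: int, check_every_k: int) -> bool:
    """Hypothetical schedule: check on steps 0, k, 2k, ..."""
    return step % check_every_k == 0


def gated_penalty(energy: float, alpha: float, energy_threshold: float) -> float:
    """Penalise only when the violation count exceeds the threshold."""
    return alpha * energy if energy > energy_threshold else 0.0


n_tokens = 20
print(sum(should_check(s, 1) for s in range(n_tokens)))   # 20 checks
print(sum(should_check(s, 5) for s in range(n_tokens)))   # 4 checks
print(gated_penalty(0, alpha=1.0, energy_threshold=0.5))  # 0.0
print(gated_penalty(2, alpha=1.0, energy_threshold=0.5))  # 2.0
```

Energy is an integer violation count, so a threshold of 0.5 amounts to "penalise whenever there is at least one violation".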

Load from a local export directory

```python
decoder = GuidedDecoder.from_pretrained("./exports/guided-decoding-adapter")
```

Return Value

generate() returns a GuidedDecodingResult:

| Field | Type | Description |
|---|---|---|
| text | str | Generated text (prompt excluded) |
| tokens_generated | int | Number of tokens produced |
| energy_checks | int | Number of times the constraint check ran |
| mean_penalty | float | Average logit penalty applied |
| latency_seconds | float | Wall-clock time in seconds |
| final_energy | float | Violation count after the last check |
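For code that consumes results, the sketch below is a stand-in dataclass with the documented fields. `GuidedDecodingResultSketch` and its `tokens_per_second` helper are hypothetical and not part of carnot.

```python
from dataclasses import dataclass


@dataclass
class GuidedDecodingResultSketch:
    """Hypothetical mirror of the documented GuidedDecodingResult fields."""
    text: str               # generated text, prompt excluded
    tokens_generated: int   # number of tokens produced
    energy_checks: int      # times the constraint check ran
    mean_penalty: float     # average logit penalty applied
    latency_seconds: float  # wall-clock time
    final_energy: float     # violation count after the last check

    @property
    def tokens_per_second(self) -> float:
        """Throughput derived from the documented fields (hypothetical helper)."""
        if self.latency_seconds <= 0:
            return 0.0
        return self.tokens_generated / self.latency_seconds


result = GuidedDecodingResultSketch(
    text="75", tokens_generated=2, energy_checks=2,
    mean_penalty=0.0, latency_seconds=0.5, final_energy=0.0,
)
print(result.tokens_per_second)  # 4.0
```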

Constraint Weights

Default weights are stored in constraint_weights.safetensors. Load and inspect:

```python
from safetensors.numpy import load_file

weights = load_file("constraint_weights.safetensors")
print(weights["all_weights"])    # shape (12,) float32
print(weights["default_alpha"])  # [0.5]
```
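The domain table lists exactly twelve constraint types, which matches the (12,) shape of `all_weights`. Suppose, as an assumption not documented here, that each weight scales the violation count of one type. A weighted energy could then be computed as in the sketch below. The ordering in `CONSTRAINT_TYPES` and the function `weighted_energy` are hypothetical. With the real file, pass `weights["all_weights"].tolist()`.

```python
# Hypothetical ordering of the twelve constraint types from the domain table.
CONSTRAINT_TYPES = [
    "addition", "multiplication", "bounds",                              # arithmetic
    "type_check", "return_type", "initialisation",                       # code
    "implication", "exclusion", "disjunction", "negation", "universal",  # logic
    "nl_consistency",                                                    # natural language
]


def weighted_energy(violations: dict[str, int], weights: list[float]) -> float:
    """Sum of per-type violation counts, each scaled by its weight."""
    if len(weights) != len(CONSTRAINT_TYPES):
        raise ValueError(f"expected {len(CONSTRAINT_TYPES)} weights, got {len(weights)}")
    return sum(w * violations.get(name, 0) for name, w in zip(CONSTRAINT_TYPES, weights))


counts = {"addition": 1, "implication": 2}
print(weighted_energy(counts, [1.0] * 12))  # 3.0: unit weights reduce to the plain count
```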

Compatible Models

Target models for Exp-110 (run with a mock LLM; live validation against these models is pending in Exp-111):

  • Qwen/Qwen3.5-0.8B
  • google/gemma-4-E4B-it

Any HuggingFace AutoModelForCausalLM with .logits output should work. The adapter does not modify model weights.

Benchmark Results (Exp-138 & Exp-140)

Note (simulated inference): All benchmark numbers below were produced with a simulated (mock) LLM, not a real transformer model. The constraint checker and logit-penalty logic are real; the generation loop uses a deterministic stand-in. Live-model E2E validation is pending (Exp-111).

Accuracy (Exp-138, simulated inference)

| Dataset | n | Baseline | Guided | Guided + verify-repair | Delta, guided vs baseline |
|---|---|---|---|---|---|
| GSM8K (math) | 200 | 55.5% | 62.5% | 65.0% | +7.0 pp |
| HumanEval (code) | 50 | 100.0% | 100.0% | n/a | +0.0 pp |
| TruthfulQA | 100 | 55.0% | 56.0% | 61.0% | +1.0 pp |
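The reported n=200/50/100 corresponds to GSM8K/HumanEval/TruthfulQA in row order. It is the only assignment that makes every percentage a whole number of problems. The sketch below converts the percentages to problem counts and percentage-point deltas. It shows, for example, that the TruthfulQA guided gain is a single problem. `correct` and `ROWS` are illustrative names.

```python
ROWS = {
    # dataset: (n, baseline %, guided %, guided + verify-repair % or None)
    "GSM8K": (200, 55.5, 62.5, 65.0),
    "HumanEval": (50, 100.0, 100.0, None),
    "TruthfulQA": (100, 55.0, 56.0, 61.0),
}


def correct(n: int, pct: float) -> int:
    """Number of correctly solved problems implied by a percentage of n."""
    return round(n * pct / 100)


for name, (n, base, guided, repair) in ROWS.items():
    line = f"{name}: {correct(n, base)}/{n} -> {correct(n, guided)}/{n} ({guided - base:+.1f} pp)"
    if repair is not None:
        line += f", verify-repair {correct(n, repair)}/{n} ({repair - base:+.1f} pp)"
    print(line)
```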

Latency (Exp-138, n=485 samples, CPU)

| Metric | Value |
|---|---|
| Constraint-check p50 | 0.0719 ms |
| Constraint-check p99 | 0.1275 ms |

Latency β€” KAN Projection Mode (Exp-140, batch=1, CPU)

| Operation | p50 | p99 |
|---|---|---|
| Logit projection (energy gradient) | 0.077 ms | 0.271 ms |
| Total per-token (gradient + projection) | 0.405 ms | 0.924 ms |

Exp-140 pass criterion: total p50 < 5 ms. Result: PASSED (measured 0.4054 ms against the 5.0 ms threshold).
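The sketch below checks the criterion arithmetic and the resulting headroom. The stated criterion covers p50 only. Applying the same 5 ms threshold to the p99 figure, which is not part of that criterion, also passes. `passes` and `headroom` are illustrative helpers.

```python
THRESHOLD_MS = 5.0  # Exp-140 pass criterion on total per-token latency

# Exp-140 total per-token latency (gradient + projection)
measured = {"p50": 0.4054, "p99": 0.924}


def passes(ms: float) -> bool:
    return ms < THRESHOLD_MS


def headroom(ms: float) -> float:
    """How many times the measured latency fits inside the threshold."""
    return THRESHOLD_MS / ms


for stat, ms in measured.items():
    verdict = "PASS" if passes(ms) else "FAIL"
    print(f"{stat}: {ms} ms -> {verdict}, {headroom(ms):.1f}x headroom")
```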

Installation

```shell
pip install carnot
```

Requires Python 3.11+. See pypi.org/project/carnot for the full package including the verify-repair pipeline.

Limitations

  1. Simulated inference benchmark: Exp-138 and Exp-140 used a mock LLM. Numbers show constraint-checker and logit-penalty overhead, not end-to-end accuracy on real models. Treat accuracy deltas as directional, not final.
  2. No KV cache: every token step runs a full forward pass over the whole sequence, so cost grows with output length. Keep max_tokens < 256.
  3. Uniform penalty: Adjusts entropy across the whole vocabulary; does not steer towards specific correct tokens.
  4. Energy is a violation count, not a calibrated probability. High alpha combined with many violations gives a very flat distribution (the model may repeat or stall).
  5. Min-text guard: AutoExtractor skips texts shorter than 5 characters, so the first few tokens are not checked.
  6. Live-model E2E pending: Exp-111 validation against Qwen/Gemma not done yet.
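Without a KV cache, each step re-encodes the whole sequence. The number of token positions processed therefore grows quadratically with output length, which is why the max_tokens guidance matters. The rough count below uses a hypothetical 32-token prompt. It ignores the fact that per-position attention cost also grows with length. `positions_processed` is an illustrative helper.

```python
def positions_processed(prompt_len: int, new_tokens: int, kv_cache: bool) -> int:
    """Count token positions fed through the model over one generation."""
    if new_tokens <= 0:
        return 0
    if kv_cache:
        # Prompt once on the first step, then one new position per later step.
        return prompt_len + (new_tokens - 1)
    # No cache: step t re-encodes the prompt plus the t tokens generated so far.
    return sum(prompt_len + t for t in range(new_tokens))


for n in (64, 256, 1024):
    full = positions_processed(32, n, kv_cache=False)
    cached = positions_processed(32, n, kv_cache=True)
    print(f"{n:5d} tokens: {full:7d} vs {cached:5d} positions ({full / cached:.0f}x)")
```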

Spec

  • REQ-VERIFY-001: Constraint energy computed from partial text at each step.
  • SCENARIO-VERIFY-004: Energy penalises logits before sampling.

Citation

```bibtex
@misc{carnot2026guided,
  title  = {Carnot Guided Decoding Adapter},
  author = {Carnot-EBM},
  year   = {2026},
  url    = {https://github.com/Carnot-EBM/carnot-ebm}
}
```