Introspective Prisma-VL-8B Architecture

Overview

Prisma-VL-8B includes an introspective feedback mechanism that gives the model fine-grained, self-monitored awareness of the uncertainty in its own predictions.

Core Innovation

The model now tracks its own prediction uncertainty and uses this as a feedback signal for subsequent predictions. This creates a temporal awareness loop:

Token t-1: "What's next?" → Prediction + Uncertainty measurement
Token t:   [Previous uncertainty signal] + "What's next?" → Better calibrated prediction

Architecture Changes

1. Uncertainty Embeddings (PrismaVLModel)

Added to PrismaVLModel.__init__():

# 65,536-level uncertainty embedding table
self.n_bits = 16  # 16-bit quantization
self.n_uncertainty_levels = 65536  # 2^16

# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)

# Cache for uncertainty codes from previous step
self.prev_uncertainty_code = None  # [batch_size, seq_len] with values [0-65535]

Parameter cost: 65,536 × 4096 = 268,435,456 parameters (3.35% overhead)

2. Uncertainty Injection (PrismaVLModel.forward)

During forward pass, after creating input embeddings:

# Look up uncertainty embeddings from the previous step
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)

# Shift right: position i gets the uncertainty from position i-1
uncertainty_shifted = torch.nn.functional.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))

# Inject into the input embeddings
inputs_embeds = inputs_embeds + uncertainty_shifted

Now the model sees: [Token embedding] + [How uncertain was I last time?]
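
To see the right-shift concretely, here is a tiny standalone example (toy tensor sizes chosen for illustration, not the model code):

import torch
import torch.nn.functional as F

# Toy sizes: batch 1, sequence length 4, hidden size 3
uncertainty_embeds = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)

# Drop the last position, then pad one zero vector at the front of the sequence dim,
# so position i ends up holding the embedding that was at position i-1
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))

print(uncertainty_shifted[0, 0])  # zeros: position 0 has no previous-step uncertainty
print(uncertainty_shifted[0, 1])  # the embedding that was at position 0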

3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)

After computing logits, during training:

# Compute entropy (uncertainty) of the predictions
probs = logits.softmax(-1)
entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)

# Normalize to [0, 1] (maximum entropy is log(vocab_size))
entropy_norm = entropy / math.log(vocab_size)

# Quantize to 16 bits (0-65535)
uncertainty_code = (entropy_norm * 65535).long()

# Store for next step
self.model.prev_uncertainty_code = uncertainty_code

How It Works (Step by Step)

Inference/Generation:

  1. Token 0: No previous uncertainty → Use neutral (32768)
  2. Token 1: Predict → Measure confidence → Encode as 0-65535
  3. Token 2: Inject uncertainty signal from Token 1 → Predict (now calibrated)
  4. Token 3: Inject uncertainty from Token 2 → Predict
  5. ... and so on
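
A minimal sketch of this loop (illustrative only; encode_uncertainty is a hypothetical helper, and in the real model these steps happen inside the forward pass):

import math
import torch

NEUTRAL_CODE = 32768  # mid-range code used when there is no previous uncertainty

def encode_uncertainty(logits, vocab_size):
    """Map one step's logits to a 16-bit code (0 = confident, 65535 = maximally uncertain)."""
    probs = logits.softmax(-1)
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)
    entropy_norm = entropy / math.log(vocab_size)   # normalize to [0, 1]
    return (entropy_norm * 65535).long()

# Conceptual decoding loop:
#   code = NEUTRAL_CODE                                   # Token 0: no history yet
#   for each step:
#       logits = model_step(prev_uncertainty_code=code)   # signal injected into the embeddings
#       code = encode_uncertainty(logits, vocab_size)     # measured for the next step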

Training:

Model learns the uncertainty embeddings through backpropagation:

  • Embedding #0-16383: "I was very confident" → Model learns to stay confident
  • Embedding #16384-32767: "I had medium confidence" → Model learns moderate caution
  • Embedding #32768-49151: "I was uncertain" → Model learns to hedge
  • Embedding #49152-65535: "I was very uncertain" → Model learns to be conservative

Key Properties

1. Moderate Overhead

  • Parameters: 268M additional (3.35% of 8B base)
  • Memory: 2 bytes per token (uncertainty code)
  • Compute: Negligible (one embedding lookup per token)
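
These numbers follow directly from the size of the embedding table; a quick check (assuming the 4096-dimensional hidden state quoted above):

n_levels = 2 ** 16                    # 65,536 uncertainty levels
hidden_dim = 4096                     # hidden size of the 8B text backbone
extra_params = n_levels * hidden_dim  # 268,435,456 parameters, about 3.35% of the 8B base
bytes_per_token = 16 // 8             # one 16-bit code = 2 bytes per token
print(extra_params, bytes_per_token)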

2. Temporal Awareness

  • Model builds a "confidence history" across generation
  • Can detect when it's going into unfamiliar territory
  • Can recover calibration after uncertain predictions

3. Self-Calibration

  • No external signals needed
  • Model learns its own uncertainty language
  • Improves through standard supervised training

4. Architecture-Agnostic

  • Works with any transformer-based model
  • Doesn't modify attention, FFN, or other core components
  • Clean separation: uncertainty mechanism vs. base model

Usage

Standard Inference

import torch
from modeling import PrismaVLForConditionalGeneration
from transformers import AutoProcessor

# Load model (introspective mechanism is built-in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto"
)

processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)

# Use normally - uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
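
If the cache is populated as described in section 3 above, the codes from the most recent forward pass can be inspected directly (a sketch; the attribute name is taken from the architecture notes above):

# Inspect the cached 16-bit uncertainty codes from the most recent forward pass
codes = model.model.prev_uncertainty_code        # [batch, seq_len], values 0-65535, or None
if codes is not None:
    normalized = codes.float() / 65535.0         # back to normalized entropy in [0, 1]
    print("mean uncertainty:", normalized.mean().item())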

Training

# Train normally - uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

# The uncertainty embeddings will learn to represent
# "how to adjust predictions based on previous confidence"

Resetting Uncertainty (Between Sequences)

# Reset uncertainty cache between independent generations
model.model.reset_uncertainty()

# Generate
outputs = model.generate(...)
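
reset_uncertainty() itself is not shown in this document; assuming it only needs to clear the cached codes, a plausible implementation is a one-liner on PrismaVLModel:

def reset_uncertainty(self):
    """Clear the cached uncertainty codes so the next sequence starts from the neutral state."""
    self.prev_uncertainty_code = None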

What Gets Learned

The 65,536 uncertainty embedding vectors learn to encode:

  1. Confidence Continuation: "Last token was confident" → Maintain confidence (if appropriate)
  2. Uncertainty Propagation: "Last token was uncertain" → Be more conservative
  3. Domain Shifts: A run of low uncertainty followed by sudden high uncertainty → Domain boundary detected
  4. Recovery Patterns: High uncertainty followed by a gradual return to confidence → Model finding its footing

Benefits

  1. Better Calibration: Model knows when it doesn't know
  2. Hallucination Awareness: Uncertain predictions less likely to compound
  3. Adaptive Confidence: Can adjust based on recent performance
  4. Interpretability: Uncertainty codes provide insight into model state
  5. Low Inference Cost: The added work per generated token is one embedding lookup and one entropy computation

Implementation Details

Files Modified

  • modeling.py:
    • PrismaVLModel.__init__(): Add uncertainty embeddings
    • PrismaVLModel.forward(): Inject uncertainty signal
    • PrismaVLForConditionalGeneration.forward(): Compute uncertainty
    • Added reset_uncertainty() method

Initialization

  • Uncertainty embeddings initialized with std = config.text_config.initializer_range (typically 0.02)
  • Start neutral: first token uses code 32768 (middle of the 16-bit range)
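
A sketch of what that initialization could look like inside PrismaVLModel.__init__() (the exact code in modeling.py may differ):

import torch.nn as nn

# Initialize the uncertainty embedding table like the other embeddings
std = config.text_config.initializer_range          # typically 0.02
nn.init.normal_(self.uncertainty_embeddings.weight, mean=0.0, std=std)

# No history yet: the cache starts empty and the first token is treated as neutral (code 32768)
self.prev_uncertainty_code = None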

Compatibility

  • Fully backward compatible: model can load existing checkpoints
  • New uncertainty embeddings initialize randomly (will be trained)
  • No changes to base model weights or architecture

Comparison to Original Llama 3.2 Example

Similarities:

  • Entropy-based uncertainty measurement
  • Temporal feedback loop
  • Embedding-based uncertainty representation

Differences:

  • Quantization: 16-bit (65,536 levels) vs. 8-bit (256 levels)
  • Resolution: Fine-grained uncertainty vs. coarse-grained
  • Overhead: 3.35% parameter overhead vs. ~0.04%
  • Applied to: Vision-language model (Prisma-VL) vs. pure language model (Llama)
  • Integration: Built into core architecture vs. wrapper class
  • Scope: Uncertainty only for text generation (not vision encoding)

Future Enhancements

Potential extensions:

  1. Multi-resolution Uncertainty: Track uncertainty at token, word, and sentence levels
  2. Uncertainty-aware Generation: Sample more conservatively when uncertain, e.g. by lowering the temperature (see the sketch after this list)
  3. Visual Uncertainty: Extend mechanism to vision encoder
  4. Cross-modal Uncertainty: Track alignment confidence between vision and text
  5. Explicit Uncertainty Tokens: Add special tokens to express uncertainty in output
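
As an example of item 2, a hypothetical sampling helper that lowers the temperature as the previous step's uncertainty code grows (not part of the current model, purely illustrative):

import torch

def uncertainty_scaled_temperature(base_temp, code, max_code=65535):
    """Scale the sampling temperature down as the previous step's uncertainty rises."""
    uncertainty = code / max_code                   # back to [0, 1]
    return base_temp * (1.0 - 0.5 * uncertainty)    # halve the temperature at maximum uncertainty

def sample_next_token(logits, prev_code, base_temp=1.0):
    temp = uncertainty_scaled_temperature(base_temp, prev_code)
    probs = (logits / temp).softmax(-1)
    return torch.multinomial(probs, num_samples=1)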

Citation

Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.


Model: Prisma-VL-8B
Date: 2025
Architecture: Integrated 16-bit temporal uncertainty feedback mechanism
Parameter Overhead: 268M (3.35%)
Memory Overhead: 2 bytes/token