Introspective Prisma-VL-8B Architecture
Overview
Prisma-VL-8B includes an introspective feedback mechanism that gives the model fine-grained, self-monitored awareness of the uncertainty in its own predictions.
Core Innovation
The model now tracks its own prediction uncertainty and uses this as a feedback signal for subsequent predictions. This creates a temporal awareness loop:
Token t-1: "What's next?" → Prediction + Uncertainty measurement
Token t: [Previous uncertainty signal] + "What's next?" → Better-calibrated prediction
Architecture Changes
1. Uncertainty Embeddings (PrismaVLModel)
Added to PrismaVLModel.__init__():
# 65,536-level uncertainty embedding table
self.n_bits = 16 # 16-bit quantization
self.n_uncertainty_levels = 65536 # 2^16
# Learned embeddings: one vector per uncertainty level
self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)
# Cache for uncertainty codes from previous step
self.prev_uncertainty_code = None # [batch_size, seq_len] with values [0-65535]
Parameter cost: 65,536 × 4096 = 268,435,456 parameters (3.35% overhead)
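A quick check of these figures (assuming hidden_dim = 4096 and an 8B-parameter base, as stated above):
extra_params = 65_536 * 4_096            # 268,435,456 embedding parameters
overhead = extra_params / 8_000_000_000  # ≈ 0.034, the ~3.35% quoted above (exact value depends on the base's precise size)
code_bytes_per_token = 16 // 8           # a 16-bit code costs 2 bytes per token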
2. Uncertainty Injection (PrismaVLModel.forward)
During forward pass, after creating input embeddings:
# Look up uncertainty embeddings from previous step
uncertainty_embeds = self.uncertainty_embeddings(prev_uncertainty_code)
# Shift right: position i gets uncertainty from position i-1
uncertainty_shifted = F.pad(uncertainty_embeds[:, :-1, :], (0, 0, 1, 0))  # F = torch.nn.functional
# Inject into input
inputs_embeds = inputs_embeds + uncertainty_shifted
Now the model sees: [Token embedding] + [How uncertain was I last time?]
3. Uncertainty Computation (PrismaVLForConditionalGeneration.forward)
After computing logits, during training:
# Compute entropy (uncertainty) of predictions
probs = logits.softmax(-1)
entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)
# Normalize to [0, 1] by the maximum possible entropy
entropy_norm = entropy / math.log(vocab_size)  # requires: import math
# Quantize to 16 bits (0-65535)
uncertainty_code = (entropy_norm * 65535).long()
# Store for next step
self.model.prev_uncertainty_code = uncertainty_code
How It Works (Step by Step)
Inference/Generation:
- Token 0: No previous uncertainty → Use the neutral code (32768)
- Token 1: Predict → Measure confidence → Encode as 0-65535
- Token 2: Inject uncertainty signal from Token 1 → Predict (now calibrated)
- Token 3: Inject uncertainty from Token 2 → Predict
- ... and so on (a minimal sketch of this loop follows)
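A minimal sketch of this loop at generation time, assuming greedy decoding; model, input_ids, and the loop structure are placeholders here, since the real integration lives inside forward() and generate():
import math
import torch
max_new_tokens = 32
prev_code = None  # token 0: no previous uncertainty (forward() falls back to the neutral code 32768)
for _ in range(max_new_tokens):
    model.model.prev_uncertainty_code = prev_code             # signal measured at the previous step
    logits = model(input_ids=input_ids).logits[:, -1, :]      # predict the next token
    probs = logits.softmax(-1)
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)  # measure confidence
    entropy_norm = entropy / math.log(probs.size(-1))         # normalize to [0, 1]
    prev_code = (entropy_norm * 65535).long().unsqueeze(-1)   # encode as 0-65535
    next_token = probs.argmax(-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)    # append and continue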
Training:
The model learns the uncertainty embeddings through backpropagation (a short illustration follows the list):
- Embedding #0-16383: "I was very confident" → Model learns to stay confident
- Embedding #16384-32767: "I had medium confidence" → Model learns moderate caution
- Embedding #32768-49151: "I was uncertain" → Model learns to hedge
- Embedding #49152-65535: "I was very uncertain" → Model learns to be conservative
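For illustration, with hypothetical code values, only the rows that were actually injected receive gradient updates:
import torch
codes = torch.tensor([[1200, 20000, 50000]])        # very confident, medium confidence, very uncertain
embeds = model.model.uncertainty_embeddings(codes)   # [1, 3, hidden_dim]
# These vectors are added to the input embeddings, so the next-token loss
# backpropagates only into rows 1200, 20000, and 50000 of the embedding table.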
Key Properties
1. Moderate Overhead
- Parameters: 268M additional (3.35% of 8B base)
- Memory: 2 bytes per token (uncertainty code)
- Compute: Negligible (one embedding lookup per token)
2. Temporal Awareness
- Model builds a "confidence history" across generation
- Can detect when it's going into unfamiliar territory
- Can recover calibration after uncertain predictions
3. Self-Calibration
- No external signals needed
- Model learns its own uncertainty language
- Improves through standard supervised training
4. Architecture-Agnostic
- Works with any transformer-based model
- Doesn't modify attention, FFN, or other core components
- Clean separation: uncertainty mechanism vs. base model
Usage
Standard Inference
import torch
from modeling import PrismaVLForConditionalGeneration
from transformers import AutoProcessor
# Load model (introspective mechanism is built-in)
model = PrismaVLForConditionalGeneration.from_pretrained(
    ".",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)
# Use normally - uncertainty tracking happens automatically
messages = [{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]}]
inputs = processor.apply_chat_template(messages, ...)
outputs = model.generate(**inputs)
Training
# Train normally - the uncertainty mechanism learns automatically
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
# The uncertainty embeddings will learn to represent
# "how to adjust predictions based on previous confidence"
Resetting Uncertainty (Between Sequences)
# Reset uncertainty cache between independent generations
model.model.reset_uncertainty()
# Generate
outputs = model.generate(...)
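The body of reset_uncertainty() is not shown in this document; given the cache described above, a plausible minimal implementation is:
def reset_uncertainty(self):
    # Clear the cached codes so the next sequence starts from the neutral state
    self.prev_uncertainty_code = None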
What Gets Learned
The 65,536 uncertainty embedding vectors learn to encode:
Confidence Continuation:
- "Last token was confident" → Maintain confidence (if appropriate)
Uncertainty Propagation:
- "Last token was uncertain" → Be more conservative
Domain Shifts:
- Sequence of low uncertainty → sudden high uncertainty → Domain boundary detected (see the sketch after this list)
Recovery Patterns:
- High uncertainty → Gradual return to confidence → Model finding its footing
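As an illustration of the domain-shift pattern above, a hypothetical analysis helper (not part of the model) could flag a boundary whenever consecutive codes jump sharply:
def find_domain_shifts(codes, jump=20000):
    # codes: uncertainty codes (0-65535) collected during one generation
    return [i for i in range(1, len(codes)) if codes[i] - codes[i - 1] > jump]
# e.g. find_domain_shifts([900, 1100, 1400, 52000, 51000]) -> [3]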
Benefits
- Better Calibration: Model knows when it doesn't know
- Hallucination Awareness: Uncertain predictions less likely to compound
- Adaptive Confidence: Can adjust based on recent performance
- Interpretability: Uncertainty codes provide insight into model state
- Low Inference Cost: The only added work at generation time is one embedding lookup and one entropy computation per token
Implementation Details
Files Modified
modeling.py:
- PrismaVLModel.__init__(): Add uncertainty embeddings
- PrismaVLModel.forward(): Inject uncertainty signal
- PrismaVLForConditionalGeneration.forward(): Compute uncertainty
- Added reset_uncertainty() method
Initialization
- Uncertainty embeddings initialized with std = config.text_config.initializer_range (typically 0.02)
- Start neutral: the first token uses code 32768 (the middle of the 0-65535 range), as sketched below
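A minimal sketch of this initialization, assuming it lives in PrismaVLModel.__init__() alongside the embedding table:
import torch.nn as nn
std = config.text_config.initializer_range                       # typically 0.02
nn.init.normal_(self.uncertainty_embeddings.weight, mean=0.0, std=std)
NEUTRAL_CODE = 32768                                              # midpoint of 0-65535, used when no previous code exists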
Compatibility
- Fully backward compatible: model can load existing checkpoints
- New uncertainty embeddings initialize randomly (will be trained)
- No changes to base model weights or architecture
Comparison to Original Llama 3.2 Example
Similarities:
- Entropy-based uncertainty measurement
- Temporal feedback loop
- Embedding-based uncertainty representation
Differences:
- Quantization: 16-bit (65,536 levels) vs. 8-bit (256 levels); see the illustration after this list
- Resolution: Fine-grained uncertainty vs. coarse-grained
- Overhead: 3.35% parameter overhead vs. ~0.04%
- Applied to: Vision-language model (Prisma-VL) vs. pure language model (Llama)
- Integration: Built into core architecture vs. wrapper class
- Scope: Uncertainty only for text generation (not vision encoding)
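To make the resolution difference concrete (an illustrative calculation, not code from either implementation):
entropy_norm = 0.5                        # example normalized entropy
code_8bit = int(entropy_norm * 255)       # -> 127 (one of 256 coarse levels)
code_16bit = int(entropy_norm * 65535)    # -> 32767 (one of 65,536 fine-grained levels)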
Future Enhancements
Potential extensions:
- Multi-resolution Uncertainty: Track uncertainty at token, word, and sentence levels
- Uncertainty-aware Generation: Sample less when uncertain (lower temperature)
- Visual Uncertainty: Extend mechanism to vision encoder
- Cross-modal Uncertainty: Track alignment confidence between vision and text
- Explicit Uncertainty Tokens: Add special tokens to express uncertainty in output
Citation
Inspired by temporal feedback loop patterns, enhanced with 16-bit high-resolution quantization for fine-grained uncertainty representation.
Model: Prisma-VL-8B
Date: 2025
Architecture: Integrated 16-bit temporal uncertainty feedback mechanism
Parameter Overhead: 268M (3.35%)
Memory Overhead: 2 bytes/token