---
language: en
license: apache-2.0
tags:
- world-model
- rssm
- tutoring
- predictive-model
- pytorch
- kat
library_name: pytorch
pipeline_tag: reinforcement-learning
model-index:
- name: kat-2-RSSM
  results:
  - task:
      type: world-modeling
      name: Tutoring State Prediction
    metrics:
    - name: Eval Loss (best)
      type: loss
      value: 0.3124
    - name: Reconstruction Loss
      type: loss
      value: 0.1389
    - name: KL Divergence
      type: loss
      value: 0.0104
    - name: Reward Loss
      type: loss
      value: 0.082
    - name: Done Loss
      type: loss
      value: 0.064
---
# KAT-2-RSSM

A Recurrent State-Space Model trained for tutoring state prediction, part of the KAT system by Progga AI.
## Model Description

This is a complete world model for predicting tutoring session dynamics: student state transitions, reward signals, and session termination. It uses a DreamerV3-inspired RSSM architecture with VL-JEPA-style EMA target encoding.
## Architecture

```
TutoringRSSM (2,802,838 params)
├── ObservationEncoder: obs_dim(20) → encoder_hidden(256) → latent_dim(128)
├── ActionEmbedding: action_dim(8) → embed_dim(32)
├── DeterministicTransition: GRU(hidden_dim=512)
├── StochasticLatent: Diagonal Gaussian prior/posterior (latent_dim=128)
├── ObservationDecoder: feature_dim(640) → decoder_hidden(256) → obs_dim(20)
├── RewardPredictor: feature_dim(640) → 1
├── DonePredictor: feature_dim(640) → 1
└── EMATargetEncoder: momentum=0.996 (VL-JEPA heritage)
```

Feature dimension: hidden_dim + latent_dim = 512 + 128 = 640
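The modules above compose into a single recurrent step: the GRU advances the deterministic state, Gaussian heads produce the prior and posterior over the stochastic latent, and downstream heads consume the 640-dim concatenation. Below is a minimal sketch of that step with the listed dimensions; all module and function names (`rssm_step`, `prior_head`, etc.) are illustrative, not the checkpoint's actual API, and the simplified posterior here conditions only on the observation rather than on both observation and hidden state.

```python
import torch
import torch.nn as nn

# Dimensions from the architecture diagram above.
OBS_DIM, ACT_EMB, LATENT, HIDDEN = 20, 32, 128, 512

encoder = nn.Linear(OBS_DIM, 2 * LATENT)    # posterior mean/logstd (simplified: obs only)
gru = nn.GRUCell(LATENT + ACT_EMB, HIDDEN)  # deterministic transition
prior_head = nn.Linear(HIDDEN, 2 * LATENT)  # prior mean/logstd from h

def rssm_step(h, z, action_emb, obs):
    # Deterministic path: advance the hidden state with previous latent + action.
    h_next = gru(torch.cat([z, action_emb], dim=-1), h)
    # Prior p(z|h) and posterior q(z|obs) as diagonal Gaussians.
    prior_mu, prior_logstd = prior_head(h_next).chunk(2, dim=-1)
    post_mu, post_logstd = encoder(obs).chunk(2, dim=-1)
    z_next = post_mu + post_logstd.exp() * torch.randn_like(post_mu)
    # Decoder/reward/done heads consume the concatenated feature vector.
    feature = torch.cat([h_next, z_next], dim=-1)  # 512 + 128 = 640
    return h_next, z_next, feature

h = torch.zeros(1, HIDDEN)
z = torch.zeros(1, LATENT)
a = torch.zeros(1, ACT_EMB)
obs = torch.randn(1, OBS_DIM)
h, z, feat = rssm_step(h, z, a, obs)
print(feat.shape)  # torch.Size([1, 640])
```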
## Observation Space (20-dim)

The 20-dimensional observation vector encodes tutoring session state:
| Dims | Signal |
|---|---|
| 0-3 | Mastery estimates (per-topic confidence) |
| 4-7 | Engagement signals (attention, participation) |
| 8-11 | Response quality (accuracy, depth, speed) |
| 12-15 | Emotional state (frustration, confidence, curiosity) |
| 16-19 | Session context (time, hint level, attempt count) |
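The layout above maps cleanly onto named slices of the flat vector. A hypothetical helper (the slice names and `split_observation` function are illustrative, not part of the released code) makes the grouping explicit:

```python
# Named views onto the 20-dim observation, following the table above.
OBS_SLICES = {
    "mastery":    slice(0, 4),    # per-topic confidence
    "engagement": slice(4, 8),    # attention, participation
    "quality":    slice(8, 12),   # accuracy, depth, speed
    "emotion":    slice(12, 16),  # frustration, confidence, curiosity
    "context":    slice(16, 20),  # time, hint level, attempt count
}

def split_observation(obs):
    """Split a flat 20-element observation into named signal groups."""
    assert len(obs) == 20, "expected a 20-dim observation"
    return {name: obs[s] for name, s in OBS_SLICES.items()}

groups = split_observation(list(range(20)))
print(groups["emotion"])  # [12, 13, 14, 15]
```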
## Action Space (8 discrete actions)
| Index | Strategy |
|---|---|
| 0 | SOCRATIC – Guided questioning |
| 1 | SCAFFOLDED – Structured support |
| 2 | DIRECT – Direct instruction |
| 3 | EXPLORATORY – Open exploration |
| 4 | REMEDIAL – Error correction |
| 5 | ASSESSMENT – Knowledge check |
| 6 | MOTIVATIONAL – Encouragement |
| 7 | METACOGNITIVE – Reflection |
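For readability when driving the model, the action table can be mirrored as an enum. This is purely a convenience sketch (the checkpoint itself only consumes integer indices; `TutoringAction` is not a name from the released code):

```python
from enum import IntEnum

# The 8 discrete tutoring strategies, mirroring the table above.
class TutoringAction(IntEnum):
    SOCRATIC = 0       # guided questioning
    SCAFFOLDED = 1     # structured support
    DIRECT = 2         # direct instruction
    EXPLORATORY = 3    # open exploration
    REMEDIAL = 4       # error correction
    ASSESSMENT = 5     # knowledge check
    MOTIVATIONAL = 6   # encouragement
    METACOGNITIVE = 7  # reflection

print(TutoringAction.EXPLORATORY.value)  # 3
print(TutoringAction(5).name)            # ASSESSMENT
```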
## Training Details

- Data: 100,901 synthetic tutoring trajectories (95,856 train / 5,045 eval)
- Epochs: 100 (best at epoch 93)
- Hardware: NVIDIA A100-SXM4-40GB
- Optimizer: Adam (lr=3e-4)
- Training time: ~45 minutes
- Framework: PyTorch 2.x
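The optimizer configuration above translates directly into code. The skeleton below is a hedged sketch only: the placeholder model and MSE loss stand in for the real RSSM and its composite loss, and details such as loss weights or gradient clipping are not documented in this card.

```python
import torch
from torch import nn

model = nn.Linear(20, 20)  # placeholder for TutoringRSSM
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # as listed above

def train_step(obs_batch):
    """One optimization step; the MSE stands in for recon + KL + reward + done."""
    optimizer.zero_grad()
    pred = model(obs_batch)
    loss = nn.functional.mse_loss(pred, obs_batch)
    loss.backward()
    optimizer.step()
    return loss.item()

loss = train_step(torch.randn(8, 20))
print(loss >= 0.0)  # True
```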
### Training Metrics (Best Checkpoint – Epoch 93)
| Metric | Value |
|---|---|
| Total Loss | 0.3124 |
| Reconstruction Loss | 0.1389 |
| KL Divergence | 0.0104 |
| Reward Loss | 0.0820 |
| Done Loss | 0.0640 |
| Rollout Loss | 0.3294 |
### Training Curve
Training converged smoothly over 100 epochs with consistent eval loss improvement. No catastrophic forgetting or training instability observed.
## Files

| File | Description | Size |
|---|---|---|
| `tutoring_rssm_best.pt` | Best checkpoint (epoch 93, eval loss 0.3124) | 11 MB |
| `tutoring_rssm_final.pt` | Final checkpoint (epoch 100) | 11 MB |
| `tutoring_rssm_epoch{N}.pt` | Snapshots every 10 epochs | 11 MB each |
| `v1-backup/` | RSSM v1 checkpoints (smaller model) | ~800 KB each |
| `training_log.txt` | Full training log | ~8 KB |
| `config.json` | Model configuration | <1 KB |
| `architecture.py` | Standalone model definition | ~20 KB |
## Usage

```python
import torch
from architecture import TutoringRSSM, TutoringWorldModelConfig

# Load model
config = TutoringWorldModelConfig(
    obs_dim=20, action_dim=8,
    latent_dim=128, hidden_dim=512,
    encoder_hidden=256, decoder_hidden=256,
)
model = TutoringRSSM(config).cuda()
ckpt = torch.load("tutoring_rssm_best.pt", map_location="cuda")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Initialize state
h, z = model.initial_state(batch_size=1)

# Observe a tutoring step
obs = torch.randn(1, 20).cuda()    # Student observation
action = torch.tensor([0]).cuda()  # SOCRATIC strategy
result = model.observe_step(h, z, action, obs)

h_new, z_new = result["h"], result["z"]
pred_obs = result["pred_obs"]        # Predicted next observation
pred_reward = result["pred_reward"]  # Predicted reward
pred_done = result["pred_done"]      # Predicted session end

# Imagination (planning without observation)
imagined = model.imagine_step(h_new, z_new, torch.tensor([3]).cuda())
# Returns predicted state without requiring a real observation
```
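Repeated `imagine_step` calls let a planner score candidate action sequences without any real observations. The sketch below shows that pattern; it assumes the `imagine_step` interface returns a dict with `"h"`, `"z"`, and `"pred_reward"`, and the stub model is purely for illustration:

```python
def evaluate_plan(model, h, z, actions, gamma=0.99):
    """Sum discounted predicted rewards along an imagined rollout."""
    total, discount = 0.0, 1.0
    for action in actions:
        result = model.imagine_step(h, z, action)
        h, z = result["h"], result["z"]
        total += discount * result["pred_reward"]
        discount *= gamma
    return total

class StubModel:
    # Toy dynamics for the sketch: predicted reward equals the action index.
    def imagine_step(self, h, z, action):
        return {"h": h, "z": z, "pred_reward": float(action)}

score = evaluate_plan(StubModel(), None, None, [3, 1, 0], gamma=1.0)
print(score)  # 4.0
```

A real planner would run this over many candidate sequences (as the MCTS/beam planners in the Ecosystem section do) and pick the best-scoring one.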
## Evaluation Results (94/94 tests pass)
| Component | Tests | Status |
|---|---|---|
| Predictive Student Model | 44/44 | ALL PASS |
| Cognition World Model Eval | 2/2 | ALL ACCEPTANCE MET |
| Core PyTorch RSSM | 10/10 | ALL PASS |
| Physics/Causality Micro-Modules | 23/23 | ALL PASS |
| Trained Checkpoint Inference | 7/7 | ALL PASS |
| Advanced Planners (MCTS/Beam) | 8/8 | ALL PASS |
## Acceptance Criteria

- Prediction accuracy: 12.08% error at horizon (target <20%) ✅
- Planning improvement: +14.5% vs reactive baseline (target >+10%) ✅
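The percentage-error figure above is a relative error at the rollout horizon. The card does not specify the exact formula, but one common choice (an assumption here, not the documented metric) is the mean relative error between predicted and true observations:

```python
def mean_relative_error(pred, true, eps=1e-8):
    """Mean of |pred - true| / (|true| + eps), as a fraction."""
    assert len(pred) == len(true)
    return sum(abs(p - t) / (abs(t) + eps) for p, t in zip(pred, true)) / len(pred)

err = mean_relative_error([0.9, 1.1], [1.0, 1.0])
print(err)  # ≈ 0.1, i.e. 10% error
```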
## Heritage

This model inherits from the Abigail3 cognitive architecture, specifically:
- RSSM design from `abigail/core/world_model.py`
- VL-JEPA EMA target encoding from Meta AI's Joint-Embedding Predictive Architecture
- DreamerV3-inspired training with KL balancing and rollout losses
- Governance-first design: generation separated from governance
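The KL balancing mentioned above can be sketched in the DreamerV3 style: the KL term is split so that one part trains the prior toward a stop-gradient copy of the posterior while the other lightly regularizes the posterior toward a stop-gradient copy of the prior. The balancing coefficient below (α=0.8) is an assumption for illustration; this model's actual coefficients are not documented in this card.

```python
import torch
from torch.distributions import Normal, kl_divergence

def balanced_kl(post_mu, post_std, prior_mu, prior_std, alpha=0.8):
    """DreamerV3-style KL balancing with stop-gradients (alpha is assumed)."""
    post = Normal(post_mu, post_std)
    prior = Normal(prior_mu, prior_std)
    post_sg = Normal(post_mu.detach(), post_std.detach())
    prior_sg = Normal(prior_mu.detach(), prior_std.detach())
    # Train the prior toward the (frozen) posterior...
    dyn = kl_divergence(post_sg, prior).sum(-1)
    # ...and lightly regularize the posterior toward the (frozen) prior.
    rep = kl_divergence(post, prior_sg).sum(-1)
    return alpha * dyn + (1 - alpha) * rep

kl = balanced_kl(torch.zeros(1, 128), torch.ones(1, 128),
                 torch.zeros(1, 128), torch.ones(1, 128))
print(kl.item())  # 0.0 for identical distributions
```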
## Ecosystem

This world model is part of the broader KAT system:
- 23 physics/causality micro-modules (67M params total) – intuitive physics simulation
- MCTS Planner – Monte Carlo Tree Search for action planning
- Beam Search Planner – anytime approximate planning
- Causal World Model – structural causal model with do-calculus
- Predictive Student Model – VL-JEPA/RSSM adapted for tutoring personalization
## License

Apache 2.0

## Author

Preston Mills – Progga AI
- Built for the KAT-2 framework
- Designed by Progga AI
- February 2026