"""
ContextFlow RL Model Inference Example

This script demonstrates how to load the trained checkpoint and make predictions.
"""

import os
import pickle
import sys
import warnings

import numpy as np

# Put this file's directory first on sys.path so the feature_extractor import
# below resolves to the module sitting next to this script.
sys.path.insert(0, os.path.dirname(__file__))

from feature_extractor import FeatureExtractor  # noqa: E402


# Action vocabulary for the Q-network. DoubtPredictor.predict indexes this
# list with the position of each Q-value, so entry i names the doubt scored
# by output i; reordering or removing entries changes what is predicted.
DOUBT_ACTIONS = [
    "what_is_backpropagation",
    "why_gradient_descent",
    "how_overfitting_works",
    "explain_regularization",
    "what_loss_function",
    "how_optimization_works",
    "explain_learning_rate",
    "what_regularization",
    "how_batch_norm_works",
    "explain_softmax",
]


class DoubtPredictor:
    """Predict the most likely learner doubt from a pickled Q-network checkpoint.

    The only checkpoint attributes read by this class are ``policy_version``,
    ``training_stats`` (used via ``.get``) and ``q_network_weights`` (a mapping
    from layer names to weight arrays).
    """

    def __init__(self, checkpoint_path: str):
        """Load the checkpoint and print a short summary of it.

        Args:
            checkpoint_path: Path to the pickled checkpoint file.
        """
        self.extractor = FeatureExtractor()

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load checkpoints obtained from a source you trust.
        with open(checkpoint_path, 'rb') as f:
            self.checkpoint = pickle.load(f)

        print(f"Loaded checkpoint v{self.checkpoint.policy_version}")
        print(f"Training samples: {self.checkpoint.training_stats.get('total_samples', 'N/A')}")

    def extract_state(self, **kwargs) -> np.ndarray:
        """Build a state vector by forwarding *kwargs* to the feature extractor."""
        return self.extractor.extract_state(**kwargs)

    def predict(self, state: np.ndarray) -> dict:
        """Rank the doubt actions for a single state.

        When the checkpoint holds ``layer1``/``layer2``/``output`` weights the
        state is passed through two ReLU layers and a linear output layer.
        Otherwise random Q-values are used and a ``RuntimeWarning`` is emitted,
        because the resulting prediction carries no information.

        Args:
            state: Feature vector for one state. A leading batch axis of
                length 1 is accepted and squeezed away.

        Returns:
            dict with ``predicted_doubt`` (best action name), ``confidence``
            (its raw Q-value, not a probability) and ``top_predictions`` (up
            to three ``{'action', 'q_value'}`` entries, best first).

        Raises:
            ValueError: If *state* yields Q-values for more than one state.
        """
        q_weights = self.checkpoint.q_network_weights

        if 'layer1.weight' in q_weights:
            w1 = q_weights['layer1.weight']
            b1 = q_weights['layer1.bias']
            w2 = q_weights['layer2.weight']
            b2 = q_weights['layer2.bias']
            w3 = q_weights['output.weight']
            b3 = q_weights['output.bias']

            # Weights are applied transposed; np.maximum(..., 0) is the ReLU.
            h1 = np.maximum(np.dot(state, w1.T) + b1, 0)
            h2 = np.maximum(np.dot(h1, w2.T) + b2, 0)
            q_values = np.dot(h2, w3.T) + b3
        else:
            warnings.warn(
                "Checkpoint has no 'layer1.weight'; falling back to random "
                "Q-values, so the prediction is meaningless.",
                RuntimeWarning,
                stacklevel=2,
            )
            q_values = np.random.randn(len(DOUBT_ACTIONS)) * 0.5

        # A (1, n_features) state produces (1, n_actions) Q-values; without
        # flattening, argsort would rank rows instead of actions.
        q_values = np.asarray(q_values)
        if q_values.ndim > 1 and q_values.shape[-1] != q_values.size:
            raise ValueError(
                "predict() expects a single state, got Q-values of shape "
                f"{q_values.shape}"
            )
        q_values = q_values.reshape(-1)

        # Indices of the three largest Q-values, best first.
        top_indices = np.argsort(q_values)[::-1][:3]

        return {
            'predicted_doubt': DOUBT_ACTIONS[top_indices[0]],
            'confidence': float(q_values[top_indices[0]]),
            'top_predictions': [
                {
                    'action': DOUBT_ACTIONS[i],
                    'q_value': float(q_values[i])
                }
                for i in top_indices
            ]
        }


def example_inference(checkpoint_path: str = 'checkpoint.pkl'):
    """Run three canned scenarios through the predictor and print the results.

    Args:
        checkpoint_path: Location of the pickled checkpoint. Defaults to
            ``'checkpoint.pkl'``, resolved against the current working
            directory.

    Returns:
        None. If the checkpoint file does not exist, a download hint is
        printed and nothing is loaded.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Download from: https://huggingface.co/namish10/contextflow-rl")
        return

    # (title, keyword arguments for DoubtPredictor.extract_state)
    scenarios = (
        (
            "Scenario 1: Beginner ML student",
            {
                'topic': "neural networks",
                'progress': 0.3,
                'confusion_signals': {
                    'mouse_hesitation': 3.0,
                    'scroll_reversals': 6,
                    'time_on_page': 45,
                    'back_button': 3,
                    'copy_attempts': 1,
                },
                'gesture_signals': {
                    'pinch': 2,
                    'point': 5,
                },
                'time_spent': 120,
            },
        ),
        (
            "Scenario 2: Advanced learner, high confusion signals",
            {
                'topic': "deep learning",
                'progress': 0.7,
                'confusion_signals': {
                    'mouse_hesitation': 4.5,
                    'scroll_reversals': 8,
                    'time_on_page': 280,
                    'back_button': 5,
                    'copy_attempts': 2,
                    'search_usage': 3,
                },
                'gesture_signals': {
                    'pinch': 8,
                    'swipe_left': 4,
                    'point': 10,
                },
                'time_spent': 600,
            },
        ),
        (
            "Scenario 3: Quick learner, low confusion",
            {
                'topic': "python programming",
                'progress': 0.9,
                'confusion_signals': {
                    'mouse_hesitation': 0.5,
                    'scroll_reversals': 1,
                    'time_on_page': 20,
                    'back_button': 0,
                },
                'gesture_signals': {
                    'swipe_down': 5,
                    'point': 3,
                },
                'time_spent': 60,
            },
        ),
    )

    predictor = DoubtPredictor(checkpoint_path)

    print("\n" + "="*60)
    print("EXAMPLE INFERENCES")
    print("="*60)

    for title, features in scenarios:
        print(f"\n[{title}]")
        result = predictor.predict(predictor.extract_state(**features))
        print(f" Predicted doubt: {result['predicted_doubt']}")
        print(f" Q-value: {result['confidence']:.4f}")

    print("\n" + "="*60)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    example_inference()