# Topic Drift Detector Model

**Version:** v20241225_184257

This model detects topic drift in conversations using an enhanced hierarchical attention-based architecture. It was trained on the leonvanbokhorst/topic-drift-v2 dataset.
## Model Architecture
- Multi-head attention mechanism (4 heads, head dimension 128)
- Hierarchical pattern detection with multi-scale analysis
- Explicit transition point detection with linguistic markers
- Pattern-aware self-attention mechanism
- Dynamic window augmentation
- Contrastive learning with pattern-aware sampling
- Adversarial training with pattern-aware perturbations
### Key Components

**Embedding Processor:**
- Input dimension: 1024
- Hidden dimension: 512
- Dropout rate: 0.35
- PreNorm layers with residual connections (see the sketch after this list)
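
The released checkpoint defines its own modules; purely as an illustration of the pre-norm residual idea above, a minimal block might look like the following (class name, activation, and layer layout are assumptions, not the published code):

```python
import torch.nn as nn

class PreNormResidual(nn.Module):
    """Minimal pre-norm residual block: x + f(LayerNorm(x))."""

    def __init__(self, dim: int = 1024, hidden: int = 512, dropout: float = 0.35):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),   # 1024 -> 512, per the card
            nn.GELU(),                # activation choice is an assumption
            nn.Dropout(dropout),      # dropout 0.35, per the card
            nn.Linear(hidden, dim),   # project back so the residual adds cleanly
        )

    def forward(self, x):
        # Normalize first (PreNorm), transform, then add the residual.
        return x + self.ff(self.norm(x))
```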
**Attention Blocks:**
- 3 layers of attention
- 4 attention heads
- Feed-forward dimension: 2048
- Learned position encodings (sketched after this list)
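
A rough equivalent built from PyTorch's standard encoder, for orientation only: 4 heads at head dimension 128 imply a 512-dimensional model, and `window=8` matches the training window; everything else is an assumption.

```python
import torch
import torch.nn as nn

class AttentionStack(nn.Module):
    """Sketch: learned per-turn position encodings + 3 encoder layers."""

    def __init__(self, dim=512, heads=4, ff_dim=2048, layers=3, window=8):
        super().__init__()
        # One learned position vector per conversation turn.
        self.pos = nn.Parameter(torch.zeros(1, window, dim))
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=ff_dim,
            batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=layers)

    def forward(self, x):  # x: [batch, 8, 512]
        return self.encoder(x + self.pos)
```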
**Pattern Detection:**
- Hierarchical LSTM layers
- Bidirectional processing
- Multi-scale pattern analysis
- Pattern classification with 7 types (see the sketch after this list)
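
The hierarchical, multi-scale part cannot be reconstructed from the card alone; a bare-bones version of the bidirectional LSTM plus 7-way pattern head might look like this (dimensions and pooling are assumptions):

```python
import torch.nn as nn

class PatternDetector(nn.Module):
    """Sketch: BiLSTM over per-turn features + 7-way pattern classifier."""

    def __init__(self, dim=512, num_patterns=7):
        super().__init__()
        # Bidirectional, so hidden size is halved to keep outputs at `dim`.
        self.lstm = nn.LSTM(dim, dim // 2, num_layers=2,
                            bidirectional=True, batch_first=True)
        self.head = nn.Linear(dim, num_patterns)

    def forward(self, x):  # x: [batch, turns, dim]
        feats, _ = self.lstm(x)
        # Pool over turns, then score the 7 pattern types listed below.
        return self.head(feats.mean(dim=1))
```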
**Transition Detection:**
- Linguistic marker attention
- Explicit transition scoring
- Marker-based context integration (see the sketch after this list)
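
One plausible reading of marker-based transition detection is a learned bank of marker embeddings that each turn attends over; the marker bank, its size, and the sigmoid scoring below are all assumptions, not the released code.

```python
import torch
import torch.nn as nn

class TransitionScorer(nn.Module):
    """Sketch: attend over a learned linguistic-marker bank, score each turn."""

    def __init__(self, dim=512, num_markers=16):
        super().__init__()
        self.markers = nn.Parameter(torch.randn(num_markers, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.score = nn.Linear(dim, 1)

    def forward(self, turns):  # turns: [batch, 8, dim]
        bank = self.markers.unsqueeze(0).expand(turns.size(0), -1, -1)
        # Turns query the marker bank; output folds marker context back in.
        ctx, _ = self.attn(turns, bank, bank)
        return torch.sigmoid(self.score(ctx)).squeeze(-1)  # per-turn transition score
```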
## Performance Metrics

| Metric | Best Validation | Test   |
|--------|-----------------|--------|
| RMSE   | 0.0142          | 0.0144 |
| R²     | 0.8711          | 0.8666 |
| Loss   | n/a             | 0.0002 |
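
For reference, RMSE and R² here are the standard definitions over target drift scores $y_i$ and predictions $\hat{y}_i$:

$$
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2},
\qquad
R^2 = 1 - \frac{\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}{\sum_{i=1}^{n}\left(y_i - \bar{y}\right)^2}
$$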
## Training Details

- Dataset: 6400 conversations (5120 train, 640 validation, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001 with cosine decay (a scheduler sketch follows this list)
- Warmup steps: 100
- Early-stopping patience: 15
- Max gradient norm: 1.0
- Mixed-precision training (AMP)
- Base embeddings: BAAI/bge-m3
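
The exact schedule is not published; a common warmup-plus-cosine implementation consistent with the numbers above would be the following (the total step count is a placeholder, not a stated value):

```python
import math
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine(optimizer, warmup_steps=100, total_steps=10_000):
    """Linear warmup to the base LR (1e-4), then cosine decay toward zero.
    `total_steps` is illustrative; the card does not state it."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)
```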
### Training Enhancements

**Dynamic Window Augmentation:**
- Adaptive window sizes
- Interpolation-based resizing (sketched after this list)
- Maintains temporal consistency
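
Interpolation-based resizing can be done along the turn axis, which keeps turn order (and hence temporal consistency) intact. A sketch of the idea, not the training code:

```python
import torch
import torch.nn.functional as F

def resize_window(window: torch.Tensor, new_turns: int) -> torch.Tensor:
    """Resize a [turns, dim] window to [new_turns, dim] by linear
    interpolation along the turn axis, preserving temporal order."""
    x = window.t().unsqueeze(0)  # F.interpolate wants [batch, channels, length]
    x = F.interpolate(x, size=new_turns, mode="linear", align_corners=True)
    return x.squeeze(0).t()      # back to [new_turns, dim]
```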
**Contrastive Learning:**
- Pattern-aware positive/negative sampling
- Temperature-scaled similarities (see the sketch after this list)
- Weighted combination of embeddings
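
Temperature-scaled similarities typically feed an InfoNCE-style objective. A sketch assuming positives and negatives have already been sampled by pattern type; the temperature value is an assumption:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(anchor, positive, negatives, temperature=0.1):
    """InfoNCE-style loss. anchor/positive: [batch, dim];
    negatives: [batch, k, dim]. Pattern-aware sampling happens upstream."""
    a = F.normalize(anchor, dim=-1)
    pos = (a * F.normalize(positive, dim=-1)).sum(-1, keepdim=True)
    neg = torch.einsum("bd,bkd->bk", a, F.normalize(negatives, dim=-1))
    logits = torch.cat([pos, neg], dim=1) / temperature
    # The true positive sits at column 0 of every row.
    target = torch.zeros(len(a), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, target)
```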
**Adversarial Training:**
- Pattern-aware perturbations
- Self-distillation loss
- Epsilon-ball projection (see the sketch after this list)
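
Epsilon-ball projection usually means clipping the adversarial perturbation's norm after a gradient ascent step. A minimal gradient-based sketch; the step sizes are illustrative, not published values:

```python
import torch

def adversarial_perturb(embeddings, grad, step=1e-3, epsilon=1e-2):
    """Ascend the loss gradient, then project the perturbation back
    onto an L2 ball of radius epsilon around the clean embeddings."""
    delta = step * grad / (grad.norm(dim=-1, keepdim=True) + 1e-12)
    delta = delta.renorm(p=2, dim=0, maxnorm=epsilon)  # epsilon-ball projection
    return embeddings + delta
```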
## Usage Example

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load the base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load the topic drift detector. The checkpoint stores the full module,
# so the model class must be importable when unpickling.
model = torch.load('models/v20241225_184257/topic_drift_model.pt',
                   weights_only=False)
model.eval()

# Prepare a conversation window (exactly 8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal.",
]

with torch.no_grad():
    # Embed each turn (simple mean pooling over tokens, padding included)
    inputs = tokenizer(conversation, padding=True, truncation=True,
                       return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)  # [8, 1024]

    # Flatten the window for the detector: [1, 8 * 1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Higher scores indicate more topic drift
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
```
## Pattern Types

The model detects 7 distinct pattern types (an index-to-label decoding sketch follows the list):

- `maintain`: no significant drift
- `gentle_wave`: subtle topic evolution
- `single_peak`: one clear transition
- `multi_peak`: multiple transitions
- `ascending`: gradually increasing drift
- `descending`: gradually decreasing drift
- `abrupt`: sudden topic change
## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be 1-512 tokens
- Relies on BAAI/bge-m3 embeddings
- May be sensitive to variations in conversation style