Topic Drift Detector Model

Version: v20241225_184257

This model detects topic drift in conversations using an enhanced hierarchical attention-based architecture. It was trained on the leonvanbokhorst/topic-drift-v2 dataset.

Model Architecture

  • Multi-head attention mechanism (4 heads, head dimension 128)
  • Hierarchical pattern detection with multi-scale analysis
  • Explicit transition point detection with linguistic markers
  • Pattern-aware self-attention mechanism
  • Dynamic window augmentation
  • Contrastive learning with pattern-aware sampling
  • Adversarial training with pattern-aware perturbations

Key Components:

  1. Embedding Processor:

    • Input dimension: 1024
    • Hidden dimension: 512
    • Dropout rate: 0.35
    • PreNorm layers with residual connections
  2. Attention Blocks:

    • 3 layers of attention
    • 4 attention heads
    • Feed-forward dimension: 2048
    • Learned position encodings
  3. Pattern Detection:

    • Hierarchical LSTM layers
    • Bidirectional processing
    • Multi-scale pattern analysis
    • Classification into 7 pattern types
  4. Transition Detection:

    • Linguistic marker attention
    • Explicit transition scoring
    • Marker-based context integration
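
The components above could be wired together roughly as follows. This is a minimal PyTorch sketch, not the released implementation. It reproduces only the dimensions listed on this card: 1024-dim inputs projected to 512, 4 heads of 128, feed-forward 2048, 3 layers, dropout 0.35, learned positions over an 8-turn window, and 7 pattern classes. Everything else is an assumption: the class names (`PreNormAttentionBlock`, `TopicDriftSketch`), the GELU activation, the single bidirectional LSTM standing in for the hierarchical multi-scale stack, mean pooling over turns, and the sigmoid drift head. Linguistic-marker transition attention is not modeled, and the released model's forward signature and return value may differ (the usage example treats its output as a single score).

```python
import torch
import torch.nn as nn


class PreNormAttentionBlock(nn.Module):
    """Hypothetical PreNorm block: x + Attn(LN(x)), then x + FFN(LN(x))."""

    def __init__(self, dim=512, heads=4, ff_dim=2048, dropout=0.35):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),  # activation is an assumption
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop(attn_out)                 # residual around attention
        x = x + self.drop(self.ff(self.norm2(x)))   # residual around feed-forward
        return x


class TopicDriftSketch(nn.Module):
    """Hypothetical end-to-end wiring using the card's hyperparameters."""

    def __init__(self, input_dim=1024, hidden_dim=512, window=8, heads=4,
                 ff_dim=2048, layers=3, dropout=0.35, num_patterns=7):
        super().__init__()
        self.window = window
        self.input_dim = input_dim
        # Embedding processor: normalize, then project 1024 -> 512.
        self.proj = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.Dropout(dropout),
        )
        # Learned position encoding: one vector per turn in the window.
        self.pos = nn.Parameter(torch.zeros(1, window, hidden_dim))
        self.blocks = nn.ModuleList(
            [PreNormAttentionBlock(hidden_dim, heads, ff_dim, dropout) for _ in range(layers)]
        )
        # Bidirectional LSTM over turns; 256 units per direction -> 512 outputs.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
        self.drift_head = nn.Linear(hidden_dim, 1)
        self.pattern_head = nn.Linear(hidden_dim, num_patterns)

    def forward(self, flat_embeddings):
        # [batch, window * input_dim] -> [batch, window, input_dim]
        x = flat_embeddings.reshape(-1, self.window, self.input_dim)
        x = self.proj(x) + self.pos
        for block in self.blocks:
            x = block(x)
        x, _ = self.lstm(x)                     # [batch, window, hidden_dim]
        pooled = x.mean(dim=1)                  # average over turns
        drift = torch.sigmoid(self.drift_head(pooled)).squeeze(-1)  # [batch]
        pattern_logits = self.pattern_head(pooled)                  # [batch, 7]
        return drift, pattern_logits


torch.manual_seed(0)
sketch = TopicDriftSketch().eval()
with torch.no_grad():
    drift, logits = sketch(torch.randn(2, 8 * 1024))
print(drift.shape, logits.shape)  # torch.Size([2]) torch.Size([2, 7])
```

Using PreNorm (normalizing before each sublayer) keeps the residual path as an identity, which tends to make stacked attention blocks easier to train.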

Performance Metrics

=== Full Training Results ===
Best Validation RMSE: 0.0142
Best Validation R²: 0.8711

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
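
For reference, RMSE is the square root of the mean squared error, and R² is 1 minus the ratio of residual to total sum of squares. The reported test loss of 0.0002 is consistent with a mean-squared-error objective (0.0144² ≈ 0.00021), although the card does not state which loss was used. A minimal pure-Python sketch; the function names are illustrative:

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y - y_hat)^2))."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [0.10, 0.20, 0.30, 0.40]
y_pred = [0.12, 0.18, 0.33, 0.38]
print(f"RMSE: {rmse(y_true, y_pred):.4f}, R²: {r2_score(y_true, y_pred):.4f}")
# RMSE: 0.0229, R²: 0.9580
```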

Training Details

  • Dataset: 6400 conversations (5120 train, 640 val, 640 test)
  • Window size: 8 turns
  • Batch size: 32
  • Learning rate: 0.0001 with cosine decay
  • Warmup steps: 100
  • Early stopping patience: 15
  • Max gradient norm: 1.0
  • Mixed precision training (AMP)
  • Base embeddings: BAAI/bge-m3
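
The schedule and stopping settings above can be expressed as small helpers. With 5120 training conversations and a batch size of 32, one epoch is 5120 / 32 = 160 optimizer steps. This sketch assumes linear warmup from zero and cosine decay to zero; the card does not state the warmup shape, the final learning rate, the total number of steps (the `total_steps=16_000` default is a hypothetical placeholder), or which validation metric early stopping monitored. `lr_multiplier` and `EarlyStopping` are hypothetical names.

```python
import math


def lr_multiplier(step, warmup_steps=100, total_steps=16_000):
    """Factor in [0, 1] applied to the peak learning rate at a given optimizer step.

    Assumed shape: linear warmup from 0, then cosine decay to 0 at total_steps.
    total_steps=16_000 is a hypothetical placeholder (100 epochs x 160 steps).
    """
    if step < warmup_steps:
        return step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * (1.0 + math.cos(math.pi * progress))


class EarlyStopping:
    """Signal a stop after `patience` consecutive evaluations without improvement.

    The monitored metric is an assumption; any lower-is-better value works.
    """

    def __init__(self, patience=15, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_evals = 0

    def step(self, value):
        """Record a validation value; return True when training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


PEAK_LR = 1e-4
STEPS_PER_EPOCH = 5120 // 32  # 160 optimizer steps per epoch
for step in (0, 50, 100, 8_050, 16_000):
    print(step, PEAK_LR * lr_multiplier(step))
```

In PyTorch, `lr_multiplier` can be passed as the `lr_lambda` argument of `torch.optim.lr_scheduler.LambdaLR` with the optimizer's learning rate set to 1e-4, calling `scheduler.step()` once per optimizer step. The max gradient norm of 1.0 corresponds to `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)` before each optimizer step.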

Training Enhancements:

  1. Dynamic Window Augmentation:

    • Adaptive window sizes
    • Interpolation-based resizing
    • Maintains temporal consistency
  2. Contrastive Learning:

    • Pattern-aware positive/negative sampling
    • Temperature-scaled similarities
    • Weighted combination of embeddings
  3. Adversarial Training:

    • Pattern-aware perturbations
    • Self-distillation loss
    • Epsilon ball projection
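
Each enhancement above names a standard technique. The NumPy sketch below shows one common form of each, under assumptions the card does not confirm. Window resizing uses per-dimension linear interpolation along the turn axis. The contrastive loss is InfoNCE-style, dividing cosine similarities by a temperature; 0.07 is a placeholder value. The perturbation is projected back onto an L2 epsilon ball. Pattern-aware sampling and the self-distillation loss are not shown, and all function names are hypothetical.

```python
import numpy as np


def resize_window(window, new_len):
    """Resize a [turns, dim] window to [new_len, dim] by linear interpolation over turns.

    Turn order is preserved, so the temporal sequence stays consistent.
    """
    old_pos = np.linspace(0.0, 1.0, window.shape[0])
    new_pos = np.linspace(0.0, 1.0, new_len)
    columns = [np.interp(new_pos, old_pos, window[:, d]) for d in range(window.shape[1])]
    return np.stack(columns, axis=1)


def info_nce(anchor, positive, negatives, temperature=0.07):
    """Contrastive loss: -log softmax of the positive among temperature-scaled cosine similarities.

    anchor, positive: [dim]; negatives: [k, dim].
    """
    def unit(v):
        return v / np.linalg.norm(v, axis=-1, keepdims=True)

    a, p, n = unit(anchor), unit(positive), unit(negatives)
    logits = np.concatenate([[a @ p], n @ a]) / temperature
    logits -= logits.max()  # shift for numerical stability
    return -(logits[0] - np.log(np.exp(logits).sum()))


def project_to_l2_ball(delta, epsilon):
    """Scale a perturbation so its L2 norm is at most epsilon."""
    norm = np.linalg.norm(delta)
    return delta if norm <= epsilon else delta * (epsilon / norm)


rng = np.random.default_rng(0)
window = rng.normal(size=(8, 1024))
print(resize_window(window, 6).shape)  # (6, 1024)
```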

Usage Example

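import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3').to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load topic drift detector.
# The checkpoint is loaded as a full pickled nn.Module, so its class definition
# must be importable. PyTorch >= 2.6 defaults to weights_only=True, which rejects
# pickled modules; weights_only=False is required. Only do this for trusted files.
model = torch.load('models/v20241225_184257/topic_drift_model.pt',
                   map_location=device, weights_only=False)
model.eval()

# Prepare conversation window (8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Get embeddings
with torch.no_grad():
    inputs = tokenizer(conversation, padding=True, truncation=True,
                       max_length=512, return_tensors='pt').to(device)
    hidden = base_model(**inputs).last_hidden_state  # [8, seq_len, 1024]

    # Mean-pool over real tokens only (padding is masked out); this should
    # match the pooling used to build the training embeddings.
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # [8, seq_len, 1]
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # [8, 1024]

    # Flatten the window into one input vector: [1, 8*1024]
    conversation_embeddings = embeddings.reshape(1, -1)

    # Get drift score
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift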

Pattern Types

The model detects 7 distinct pattern types:

  1. "maintain" - No significant drift
  2. "gentle_wave" - Subtle topic evolution
  3. "single_peak" - One clear transition
  4. "multi_peak" - Multiple transitions
  5. "ascending" - Gradually increasing drift
  6. "descending" - Gradually decreasing drift
  7. "abrupt" - Sudden topic change
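
If pattern-classification logits are available, they can be decoded into the names above. This sketch assumes the class indices follow the order listed here (0 = "maintain" through 6 = "abrupt"); the card specifies neither the released model's label order nor whether the usage example's forward pass returns these logits. `PATTERN_TYPES` and `decode_pattern` are hypothetical names.

```python
import math

# Assumed index order, matching the numbered list above.
PATTERN_TYPES = [
    "maintain",
    "gentle_wave",
    "single_peak",
    "multi_peak",
    "ascending",
    "descending",
    "abrupt",
]


def decode_pattern(logits):
    """Return (pattern_name, softmax probability) for the highest-scoring of 7 logits."""
    if len(logits) != len(PATTERN_TYPES):
        raise ValueError(f"expected {len(PATTERN_TYPES)} logits, got {len(logits)}")
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # softmax numerators, shifted for stability
    best = max(range(len(logits)), key=lambda i: logits[i])
    return PATTERN_TYPES[best], exps[best] / sum(exps)


print(decode_pattern([0.1, 0.2, 2.5, 0.0, -1.0, 0.3, 1.0]))
```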

Limitations

  • Works best with English conversations
  • Requires exactly 8 turns of conversation
  • Each turn should be between 1 and 512 tokens
  • Relies on BAAI/bge-m3 embeddings
  • May be sensitive to conversation style variations
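
A pre-flight check for the turn-count and turn-length constraints can catch malformed inputs before they reach the model. This is a minimal sketch: `validate_window` is a hypothetical helper, and `count_tokens` stands in for whatever tokenizer you use, for example `lambda t: len(tokenizer(t)["input_ids"])` with the BAAI/bge-m3 tokenizer (which also counts special tokens; whether those count toward the limit is not specified on this card).

```python
def validate_window(turns, count_tokens, window_size=8, max_tokens=512):
    """Raise ValueError if `turns` violates the turn-count or turn-length limits."""
    if len(turns) != window_size:
        raise ValueError(f"expected exactly {window_size} turns, got {len(turns)}")
    for i, turn in enumerate(turns):
        n = count_tokens(turn)
        if not 1 <= n <= max_tokens:
            raise ValueError(f"turn {i} has {n} tokens; must be between 1 and {max_tokens}")


def whitespace_count(text):
    """Stand-in token counter for illustration only."""
    return len(text.split())


validate_window(["hi there"] * 8, whitespace_count)  # passes silently
```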

Training Curves

[Figure: training curves]
