leonvanbokhorst committed
Commit f9d091d · verified · 1 Parent(s): 958a8fb

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +31 -55
README.md CHANGED
@@ -5,7 +5,6 @@ tags:
  - conversation-analysis
  - pytorch
  - attention
- - lstm
  license: mit
  datasets:
  - leonvanbokhorst/topic-drift-v2
@@ -35,18 +34,16 @@ model-index:

  # Topic Drift Detector Model

- ## Version: v20241225_184257
+ ## Version: v20241226_105737

- This model detects topic drift in conversations using an enhanced hierarchical attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
+ This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

  ## Model Architecture
- - Multi-head attention mechanism (4 heads, head dimension 128)
- - Hierarchical pattern detection with multi-scale analysis
- - Explicit transition point detection with linguistic markers
- - Pattern-aware self-attention mechanism
- - Dynamic window augmentation
- - Contrastive learning with pattern-aware sampling
- - Adversarial training with pattern-aware perturbations
+ - Efficient single-layer attention mechanism
+ - Direct pattern recognition
+ - Streamlined processing pipeline
+ - Optimized scaling factor (4.0)
+ - PreNorm layers with residual connections

  ### Key Components:
  1. **Embedding Processor**:
@@ -55,22 +52,16 @@ This model detects topic drift in conversations using an enhanced hierarchical a
     - Dropout rate: 0.35
     - PreNorm layers with residual connections

- 2. **Attention Blocks**:
-    - 3 layers of attention
-    - 4 attention heads
-    - Feed-forward dimension: 2048
+ 2. **Attention Block**:
+    - Single attention layer
+    - Feed-forward dimension: 512
     - Learned position encodings
+    - Residual connections

- 3. **Pattern Detection**:
-    - Hierarchical LSTM layers
-    - Bidirectional processing
-    - Multi-scale pattern analysis
-    - Pattern classification with 7 types
-
- 4. **Transition Detection**:
-    - Linguistic marker attention
-    - Explicit transition scoring
-    - Marker-based context integration
+ 3. **Pattern Recognition**:
+    - Direct feature extraction
+    - Efficient tensor operations
+    - Optimized memory usage

  ## Performance Metrics
  ```txt
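
The "Attention Block" component in the hunk above is described only in bullets. Below is a minimal sketch of what a single-layer PreNorm attention block with a feed-forward dimension of 512, learned position encodings, and residual connections could look like. It is an illustration, not this repository's actual code: the hidden size of 1024 (the BAAI/bge-m3 embedding width) and all class and argument names are assumptions, and since the README does not say where the 4.0 scaling factor is applied, it is left out.

```python
import torch
import torch.nn as nn

class PreNormAttentionBlock(nn.Module):
    """Illustrative single-layer PreNorm attention block (names assumed)."""

    def __init__(self, hidden_dim: int = 1024, ff_dim: int = 512,
                 num_turns: int = 8, dropout: float = 0.35):
        super().__init__()
        # Learned position encodings for the fixed 8-turn window
        self.pos_encoding = nn.Parameter(torch.zeros(1, num_turns, hidden_dim))
        self.norm1 = nn.LayerNorm(hidden_dim)
        # Single attention layer, per the component list above
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads=1,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)
        # Feed-forward sublayer with the README's 512 hidden dimension
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 8 turns, hidden_dim) turn embeddings
        x = x + self.pos_encoding
        # PreNorm: normalize before each sublayer, add the residual after
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x

block = PreNormAttentionBlock()
window = torch.randn(2, 8, 1024)   # two 8-turn windows of bge-m3 embeddings
print(block(window).shape)         # torch.Size([2, 8, 1024])
```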
@@ -88,28 +79,24 @@ R²: 0.8666
  - Dataset: 6400 conversations (5120 train, 640 val, 640 test)
  - Window size: 8 turns
  - Batch size: 32
- - Learning rate: 0.0001 with cosine decay
- - Warmup steps: 100
+ - Learning rate: 0.0001
  - Early stopping patience: 15
- - Max gradient norm: 1.0
- - Mixed precision training (AMP)
+ - Distribution regularization weight: 0.1
+ - Target standard deviation: 0.2
  - Base embeddings: BAAI/bge-m3

- ### Training Enhancements:
- 1. **Dynamic Window Augmentation**:
-    - Adaptive window sizes
-    - Interpolation-based resizing
-    - Maintains temporal consistency
-
- 2. **Contrastive Learning**:
-    - Pattern-aware positive/negative sampling
-    - Temperature-scaled similarities
-    - Weighted combination of embeddings
-
- 3. **Adversarial Training**:
-    - Pattern-aware perturbations
-    - Self-distillation loss
-    - Epsilon ball projection
+ ## Key Improvements
+ 1. **Simplified Architecture**:
+    - Reduced complexity
+    - Focused pattern detection
+    - Efficient processing
+    - Optimized memory usage
+
+ 2. **Performance Benefits**:
+    - Improved RMSE (0.0144)
+    - Strong R² score (0.8666)
+    - Consistent predictions
+    - Wide score range

  ## Usage Example
  ```python
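
The new training details list a distribution regularization weight (0.1) and a target standard deviation (0.2) without showing the loss. A plausible, hedged reading is an MSE objective plus a penalty that pulls the batch standard deviation of the predicted drift scores toward 0.2, discouraging collapsed near-constant predictions. The sketch below assumes that form; `drift_loss` and its exact formulation are illustrative, not the training code from this commit.

```python
import torch
import torch.nn.functional as F

def drift_loss(predictions: torch.Tensor, targets: torch.Tensor,
               reg_weight: float = 0.1, target_std: float = 0.2) -> torch.Tensor:
    """Assumed form: MSE plus a penalty keeping the batch std near 0.2."""
    mse = F.mse_loss(predictions, targets)
    # Squared gap between the batch std of predicted scores and the target
    std_penalty = (predictions.std() - target_std) ** 2
    return mse + reg_weight * std_penalty

preds = torch.rand(32)    # batch of predicted drift scores in [0, 1]
labels = torch.rand(32)   # matching ground-truth scores
print(drift_loss(preds, labels))
```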
@@ -121,7 +108,7 @@ base_model = AutoModel.from_pretrained('BAAI/bge-m3')
  tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

  # Load topic drift detector
- model = torch.load('models/v20241225_184257/topic_drift_model.pt')
+ model = torch.load('models/v20241226_105737/topic_drift_model.pt')
  model.eval()

  # Prepare conversation window (8 turns)
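
One caveat on the `torch.load` call in this hunk: PyTorch 2.6 changed the default of `weights_only` to `True`, which refuses to unpickle a full model object such as this checkpoint.

```python
# On PyTorch >= 2.6, torch.load defaults to weights_only=True and will
# refuse to unpickle a full model object. Pass weights_only=False only
# for checkpoints you trust.
model = torch.load('models/v20241226_105737/topic_drift_model.pt',
                   weights_only=False)
model.eval()
```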
@@ -151,22 +138,11 @@ print(f"Topic drift score: {drift_scores.item():.4f}")
  # Higher scores indicate more topic drift
  ```

- ## Pattern Types
- The model detects 7 distinct pattern types:
- 1. "maintain" - No significant drift
- 2. "gentle_wave" - Subtle topic evolution
- 3. "single_peak" - One clear transition
- 4. "multi_peak" - Multiple transitions
- 5. "ascending" - Gradually increasing drift
- 6. "descending" - Gradually decreasing drift
- 7. "abrupt" - Sudden topic change
-
  ## Limitations
  - Works best with English conversations
  - Requires exactly 8 turns of conversation
  - Each turn should be between 1-512 tokens
  - Relies on BAAI/bge-m3 embeddings
- - May be sensitive to conversation style variations

  ## Training Curves
- ![Training Curves](plots/v20241225_184257/training_curves.png)
+ ![Training Curves](plots/v20241226_105737/training_curves.png)
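
The limitations kept in the final hunk are strict input contracts (exactly 8 turns, 1-512 tokens per turn). A small pre-flight check makes them concrete; `validate_window` is a hypothetical helper, not part of the released code.

```python
from transformers import AutoTokenizer

def validate_window(turns: list[str], tokenizer) -> None:
    """Hypothetical check for the input contract listed under Limitations."""
    if len(turns) != 8:
        raise ValueError(f"model expects exactly 8 turns, got {len(turns)}")
    for i, turn in enumerate(turns):
        n_tokens = len(tokenizer(turn)["input_ids"])
        if not 1 <= n_tokens <= 512:
            raise ValueError(f"turn {i} has {n_tokens} tokens (must be 1-512)")

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
validate_window(["Hello!"] * 8, tokenizer)  # passes: 8 short turns
```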
 