NDStein committed
Commit 58d3955 · verified · 1 Parent(s): 6bc022c

Upload 10 files
Files changed (10)
  1. README.md +218 -3
  2. config.py +71 -0
  3. dam3.1.ckpt +3 -0
  4. featex.py +119 -0
  5. model.py +185 -0
  6. pipeline.py +43 -0
  7. requirements.txt +5 -0
  8. tuning/__init__.py +0 -0
  9. tuning/indet_roc.py +416 -0
  10. tuning/optimal_ordinal.py +510 -0
README.md CHANGED
@@ -1,3 +1,218 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - en
+ base_model:
+ - openai/whisper-small.en
+ pipeline_tag: audio-classification
+ ---
+
+
+ # Background
+
+ In the United States nearly 21M adults suffer from depression each year [1], with depression serving as the nation’s leading cause of disability [2].
+ Despite this, less than 4% of Americans receive mental health screenings from their primary care physicians during annual wellness visits.
+ The pandemic and recent public campaigns have increased awareness of mental health struggles, but a persistent stigma around depression and other mental health conditions remains.
+ The influence of this stigma is especially marked in older adults: people aged 65 and older are less likely than any other age group to seek mental health support.
+ Older adults – for whom depression significantly increases the risk of disability and morbidity – also tend to underreport mental health symptoms [3].
+
+ In the US, this outlook becomes even more troubling when coupled with the rate at which the country’s population is aging: 1 out of every 6 people will be 60 years or over by 2030 [4].
+ As widespread as depression is, identifying and treating it and other mental health conditions remains challenging, and the screening process offers limited objectivity.
+
+
+ # Depression–Anxiety Model (DAM)
+
+ ## Model Overview
+
+ DAM is a clinical-grade, speech-based model designed to screen for signs of depression and anxiety using voice biomarkers.
+ To the best of our knowledge, it is the first model developed explicitly for clinical-grade mental health assessment from speech without reliance on linguistic content or transcription.
+ The model operates exclusively on the acoustic properties of the speech signal, extracting depression- and anxiety-specific voice biomarkers rather than semantic or lexical information.
+ Numerous studies [5–7] have demonstrated that paralinguistic features – such as spectral entropy, pitch variability, fundamental frequency, and related acoustic measures – exhibit strong correlations with depression and anxiety.
+ Building on this body of evidence, DAM extends prior approaches by leveraging deep learning to learn fine-grained vocal biomarkers directly from the raw speech signal, yielding representations with greater predictive power than hand-engineered paralinguistic features.
+ DAM analyzes spoken audio to estimate depression and anxiety severity scores, which can subsequently be mapped to standardized clinical scales, such as **PHQ-9** (Patient Health Questionnaire-9) for depression and **GAD-7** (Generalized Anxiety Disorder-7) for anxiety.
+
+
+ ## Data
+
+ The model was trained and evaluated on a large-scale speech dataset collected from approximately 35,000 individuals via phone, tablet, or web app, corresponding to ~863 hours of speech.
+ Ground-truth labels were derived from both clinician-administered and self-reported PHQ-9 and GAD-7 questionnaires, ensuring strong alignment with established clinical assessment standards.
+ The data consists predominantly of American English speech; however, a broad range of accents is represented, providing robustness across diverse speaking styles.
+
+ The audio data itself cannot be shared for privacy reasons. Demographic statistics, model scores, and associated metadata for each audio stream are available for threshold tuning at https://huggingface.co/datasets/KintsugiHealth/dam-dataset.
+
+ ## Model Architecture
+
+ **Foundation model:** OpenAI Whisper-Small EN
+
+ **Training approach:** Fine-tuning + multi-task learning
+
+ **Downstream tasks:** Depression and anxiety severity estimation
+
+ Whisper serves as the backbone for extracting voice biomarkers, while a multi-task head is fine-tuned jointly on the depression and anxiety prediction tasks to leverage shared representations across mental health conditions. A minimal sketch of how these pieces compose is shown below.
+
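+ The following sketch assumes the repository files below are on the import path; only the pretrained Whisper weights are loaded here, so the task heads are randomly initialized unless the checkpoint is loaded as in `pipeline.py`. It illustrates the forward pass from log-mel features to one raw score per task:
+
+ ```python
+ import torch
+
+ from config import default_config
+ from model import Classifier
+
+ model = Classifier(**default_config).eval()
+ # One 30-second chunk of log-mel features: [num_chunks, 80 mel bands, 3000 time frames]
+ features = torch.randn(1, 80, 3000)
+ with torch.no_grad():
+     scores, _ = model(features, torch.tensor([1]))  # lengths = chunks per recording
+ print(scores)  # dict with one raw score tensor per task: {'depression': ..., 'anxiety': ...}
+ ```
+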
+ ## Input Requirements
+
+ **Preferred minimum audio length:** 30 seconds of speech after voice activity detection (VAD)
+
+ **Input modality:** Audio only
+
+ Shorter audio samples may lead to reduced prediction accuracy. A sketch of the chunking behavior for longer inputs follows.
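+
+ For longer recordings, the preprocessor splits the audio into 30-second chunks whose encoder embeddings are averaged downstream (see `model.py`). As a sketch using `featex.Preprocessor` with its defaults and placeholder audio, a 95-second recording becomes four chunks, the last zero-padded:
+
+ ```python
+ import torch
+
+ from featex import Preprocessor
+
+ pre = Preprocessor()  # defaults: 30 s chunks, no overlap, pad last chunk to full
+ audio = torch.randn(1, 95 * 16000)  # 95 s of placeholder 16 kHz mono audio
+ features = pre.preprocess_with_audio_normalization(audio)
+ print(features.shape)  # torch.Size([4, 80, 3000])
+ ```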
+
+ ## Output
+
+ The model outputs a dictionary of the form `{"depression": score, "anxiety": score}`.
+
+ If `quantize=False` (see the Usage section below), the scores are returned as raw float values which correlate monotonically with PHQ-9 and GAD-7.
+
+ If `quantize=True`, the scores are converted into integers representing the severity of depression and anxiety, using the levels listed below (a sketch of the mapping follows the lists).
+
+ **Quantization levels for depression task:**
+
+ 0 – no depression (PHQ-9 <= 9)
+
+ 1 – mild to moderate depression (10 <= PHQ-9 <= 14)
+
+ 2 – severe depression (PHQ-9 >= 15)
+
+
+ **Quantization levels for anxiety task:**
+
+ 0 – no anxiety (GAD-7 <= 4)
+
+ 1 – mild anxiety (5 <= GAD-7 <= 9)
+
+ 2 – moderate anxiety (10 <= GAD-7 <= 14)
+
+ 3 – severe anxiety (GAD-7 >= 15)
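+
+ The mapping from raw score to level is a bucketing against the per-task thresholds in `config.py`; a minimal sketch of the logic in `model.Classifier.quantize_scores`:
+
+ ```python
+ import torch
+
+ from config import default_config
+
+ thresholds = default_config['inference_thresholds']['anxiety']  # [-0.7939, -0.2173, 0.1521]
+ raw_score = torch.tensor([0.3])  # hypothetical raw anxiety score
+ level = torch.searchsorted(torch.tensor(thresholds), raw_score)
+ print(int(level))  # 3 -> severe anxiety (GAD-7 >= 15)
+ ```
+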
+ ## Intended Use
+ * Mental health research
+ * Clinical decision support
+ * Continuous monitoring of depression and anxiety
+
+ ## Limitations
+ * Not intended for diagnosis/self-diagnosis without clinical oversight
+ * Performance may degrade on speech recorded outside controlled environments or in the presence of noise
+ * Intended only for audio containing a single voice speaking English
+ * Biases related to language, accent, or demographic representation may be present
+
+
+
+ # Usage
+ 1. Check out the repo:
+
+ ```
+ git clone https://huggingface.co/KintsugiHealth/dam
+ ```
+
+ 2. Install requirements:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ 3. Load and run the pipeline:
+ ```python
+ from pipeline import Pipeline
+
+ pipeline = Pipeline()
+ result = pipeline.run_on_file("sample.wav", quantize=True)
+ print(result)
+ ```
+ The output is a dictionary such as `{'depression': 2, 'anxiety': 3}`, indicating that the analyzed audio sample exhibits voice biomarkers consistent with severe depression and severe anxiety.
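+
+ Continuing the example above, passing `quantize=False` returns the underlying raw scores instead (one score tensor per task; shapes follow `model.py`):
+
+ ```python
+ raw = pipeline.run_on_file("sample.wav", quantize=False)
+ print(raw)  # {'depression': tensor([[...]]), 'anxiety': tensor([[...]])}
+ ```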
+
+ ## Tuning Thresholds
+ As mentioned in the Data section above, the raw audio data cannot be shared, but validation and test sets of model scores with associated ground truth and demographic metadata are available for threshold tuning. Thresholds can thus be tuned for traditional binary classification, ternary classification with an indeterminate output, and multi-class classification of severity. Two modules are provided for this in the model code's `tuning` package, as illustrated below.
+
+ ### Tuning Sensitivity, Specificity, and Indeterminate Fraction
+ This module implements a generalization of ROC curve analysis wherein ground truth is binary, but model output can be negative (score below the lower threshold), positive (score above the upper threshold), or indeterminate (score between the thresholds). For the purpose of metric calculations such as sensitivity and specificity, examples marked indeterminate count towards neither the numerator nor the denominator. The budget for the fraction of examples to be marked indeterminate is configurable, as shown below.
+ ```python
+ import numpy as np
+
+ from datasets import load_dataset
+ from tuning.indet_roc import BinaryLabeledScores
+
+ val = load_dataset("KintsugiHealth/dam-dataset", split="validation")
+ val.set_format("numpy")
+ test = load_dataset("KintsugiHealth/dam-dataset", split="test")
+ test.set_format("numpy")
+
+ data = dict(val=val, test=test)
+
+ # Associate depression model scores with binarized labels based on whether the PHQ-9 sum is >= 10
+ scores_labeled = {
+     k: BinaryLabeledScores(
+         y_score=v['scores_depression'],  # Change to 'scores_anxiety' to calibrate anxiety thresholds
+         y_true=(v['phq'] >= 10).astype(int),  # Change to 'gad' to calibrate anxiety thresholds; optionally change cutoff
+     )
+     for k, v in data.items()
+ }
+
+ issa = scores_labeled['val'].indet_sn_sp_array()  # Metrics at all possible lower, upper threshold pairs
+
+ # Compute ROC curve with 20% indeterminate budget and select a point near the diagonal
+ roc_at_20 = issa.roc_curve(0.2)  # Pareto frontier of (sensitivity, specificity) pairs with at most 20% indeterminate fraction
+ print(f"Area under the ROC curve with 20% indeterminate budget: {roc_at_20.auc()=:.1%}")
+ sn_eq_sp_at_20 = roc_at_20.sn_eq_sp()  # Find where ROC comes closest to sensitivity = specificity diagonal
+ print(f"Thresholds to balance sensitivity and specificity on val set with 20% indeterminate budget: "
+       f"{sn_eq_sp_at_20.lower_thresh=:.3}, {sn_eq_sp_at_20.upper_thresh=:.3}")
+ print(f"Performance on val set with these thresholds: {sn_eq_sp_at_20.sn=:.1%}, {sn_eq_sp_at_20.sp=:.1%}")
+ test_metrics = sn_eq_sp_at_20.eval(**scores_labeled['test']._asdict())  # Thresholds evaluated on test set
+ print(f"Performance on test set with these thresholds: {test_metrics.sn=:.1%}, {test_metrics.sp=:.1%}")
+
+ # Find best specificity given sensitivity and indeterminate budget constraints
+ constrained = issa[(issa.sn >= 0.8) & (issa.indet_frac <= 0.35)]
+ optimal = constrained[np.argmax(constrained.sp)]
+ print(f"Highest specificity achievable with sensitivity >= 80% and 35% indeterminate budget is "
+       f"{optimal.sp=:.1%}, achieved at thresholds {optimal.lower_thresh=:.3}, {optimal.upper_thresh=:.3}"
+ )
+
+ # Collect optimal ways of achieving balanced sensitivity and specificity as a function of indeterminate fraction
+ sn_eq_sp = issa.sn_eq_sp_graph()
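+ # Hypothetical follow-up: read off the best balanced operating point within a 10% indeterminate budget
+ best_at_10 = sn_eq_sp.at_budget(0.1)
+ print(f"Balanced val performance at 10% budget: {best_at_10.sn=:.1%}, {best_at_10.sp=:.1%}, {best_at_10.indet_frac=:.1%}")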
+ ```
+
+ ### Optimal Tuning for Multiclass Tasks
+ The depression and anxiety models were each trained with ordinal regression to predict a scalar score monotonically correlated with the underlying PHQ-9 and GAD-7 questionnaire ground-truth sums. As such, there are efficient dynamic programming algorithms to select optimal thresholds for multi-class numeric labels under a variety of decision criteria; an example follows, with a coarser variant sketched after it.
+
+ ```python
+ from datasets import load_dataset
+ from tuning.optimal_ordinal import MinAbsoluteErrorOrdinalThresholding
+
+ val = load_dataset("KintsugiHealth/dam-dataset", split="validation")
+ val.set_format("torch")
+ test = load_dataset("KintsugiHealth/dam-dataset", split="test")
+ test.set_format("torch")
+
+ data = dict(val=val, test=test)
+
+ scores = val['scores_anxiety']  # Change to 'scores_depression' for depression threshold tuning
+ labels = val['gad']  # Change to 'phq' for depression threshold tuning; optionally change to quantized version for coarser prediction tuning
+
+ # Can change to any of
+ #   `MaxAccuracyOrdinalThresholding`
+ #   `MaxMacroRecallOrdinalThresholding`
+ #   `MaxMacroPrecisionOrdinalThresholding`
+ #   `MaxMacroF1OrdinalThresholding`
+ optimal_thresh = MinAbsoluteErrorOrdinalThresholding(num_classes=int(labels.max()) + 1)
+ best_constant_cost, best_constant = optimal_thresh.best_constant_output_classifier(labels)
+ print(f"Always predicting GAD sum = {best_constant} on val set independent of model score gives mean absolute error {best_constant_cost:.3}.")
+ mean_error = optimal_thresh.tune_thresholds(labels=labels, scores=scores)
+ print(f"Thresholds optimized on val set to predict GAD sum from anxiety score: {optimal_thresh.thresholds}")
+ print(f"Mean absolute error predicting GAD sum on val set based on thresholds optimized on val set: {mean_error:.3}")
+ test_preds = optimal_thresh(test['scores_anxiety'])
+ mean_error_test = optimal_thresh.mean_cost(labels=test['gad'], preds=test_preds)
+ print(f"Mean absolute error predicting GAD sum on test set based on thresholds optimized on val set: {mean_error_test:.3}")
+ ```
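+
+ As a sketch of the coarser variant mentioned in the comments above (same dataset columns assumed), the GAD-7 sums can first be quantized into the four severity levels from the Output section and then tuned directly, e.g. for macro F1:
+
+ ```python
+ import torch
+
+ from tuning.optimal_ordinal import MaxMacroF1OrdinalThresholding
+
+ # Quantize GAD-7 sums into the four severity levels (cutoffs 5, 10, 15)
+ severity = torch.searchsorted(torch.tensor([5, 10, 15]), val['gad'], right=True)
+ tuner = MaxMacroF1OrdinalThresholding(num_classes=4)
+ macro_f1 = tuner.tune_thresholds(labels=severity, scores=val['scores_anxiety'])
+ print(f"Macro F1 on val set: {macro_f1:.3}; thresholds: {tuner.thresholds}")
+ ```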
+
+ # Acknowledgments
+
+ This model was created through equal contributions by Oleksii Abramenko, Noah Stein, and Colin Vaz while at Kintsugi Health. For a full list of contributors to earlier modeling projects, data collection, clinical, and business matters, see the organization card at https://huggingface.co/KintsugiHealth.
+
+ # References
+
+ 1. https://www.nimh.nih.gov/health/statistics/major-depression
+ 2. https://www.hopefordepression.org/depression-facts/
+ 3. https://nndc.org/facts/
+ 4. https://www.psychiatry.org/patients-families/stigma-and-discrimination
+ 5. https://www.sciencedirect.com/science/article/pii/S1746809423004536
+ 6. https://pmc.ncbi.nlm.nih.gov/articles/PMC3409931/
+ 7. https://pmc.ncbi.nlm.nih.gov/articles/PMC11559157
config.py ADDED
@@ -0,0 +1,71 @@
+ """Configuration for running Kintsugi Depression and Anxiety model."""
+
+ import torch
+
+ EXPECTED_SAMPLE_RATE = 16000  # Audio sample rate in hertz
+
+ # Configuration for running Kintsugi Depression and Anxiety model as intended
+ default_config = {
+     # See featex.py for preprocessor config details
+     'preprocessor_config': {
+         'normalize_features': True,
+         'chunk_seconds': 30,
+         'max_overlap_frac': 0.0,
+         'pad_last_chunk_to_full': True,
+     },
+
+     # See model.py for backbone config details
+     'backbone_configs': {
+         'audio': {
+             'model': 'openai/whisper-small.en',
+             'hf_config': {
+                 'encoder_layerdrop': 0.0,
+                 'dropout': 0.0,
+                 'activation_dropout': 0.0,
+             },
+             'lora_params': {
+                 'r': 32,
+                 'lora_alpha': 64.0,
+                 'target_modules': 'all-linear',
+                 'lora_dropout': 0.4,
+                 'modules_to_save': ['conv1', 'conv2'],
+                 'bias': 'all',
+             },
+         },
+         'llma': {
+             'model': 'openai/whisper-small.en',
+             'hf_config': {
+                 'encoder_layerdrop': 0.0,
+                 'dropout': 0.0,
+                 'activation_dropout': 0.0,
+             },
+         },
+     },
+
+     # See model.py for classifier config details
+     'classifier_config': {
+         'shared_projection_dim': [256, 64],
+         'tasks': {
+             'depression': {'proj_dim': 128, 'dropout': 0.4},
+             'anxiety': {'proj_dim': 128, 'dropout': 0.4},
+         },
+     },
+
+     # Score thresholds chosen to optimize macro average F1 score on validation set
+     'inference_thresholds': {
+         # Three-level depression severity model:
+         #   depression score <= -0.6699            --> no depression (PHQ-9 <= 9)
+         #   -0.6699 < depression score <= -0.2908  --> mild to moderate depression (10 <= PHQ-9 <= 14)
+         #   -0.2908 < depression score             --> severe depression (PHQ-9 >= 15)
+         'depression': [-0.6699, -0.2908],
+         # Four-level anxiety severity model:
+         #   anxiety score <= -0.7939           --> no anxiety (GAD-7 <= 4)
+         #   -0.7939 < anxiety score <= -0.2173 --> mild anxiety (5 <= GAD-7 <= 9)
+         #   -0.2173 < anxiety score <= 0.1521  --> moderate anxiety (10 <= GAD-7 <= 14)
+         #   0.1521 < anxiety score             --> severe anxiety (GAD-7 >= 15)
+         'anxiety': [-0.7939, -0.2173, 0.1521],
+     },
+ }
+
+ # Average filter bank energies used for feature normalization
+ logmel_energies = torch.tensor([
+     0.34912264, 0.58558977, 0.7912451, 0.92767584, 0.98273695,
+     0.98439455, 0.9603633, 0.93906444, 0.9366281, 0.93200225,
+     0.916437, 0.8928787, 0.8637211, 0.83265126, 0.79977655,
+     0.7778334, 0.7561299, 0.72997606, 0.70391226, 0.6800474,
+     0.65755, 0.63536274, 0.61355984, 0.5923383, 0.5720056,
+     0.55244887, 0.53684795, 0.5221597, 0.5098636, 0.49923953,
+     0.48908615, 0.47840047, 0.46758702, 0.47343993, 0.46268672,
+     0.4475126, 0.46747103, 0.45131385, 0.4635319, 0.44889897,
+     0.45491976, 0.4373785, 0.43154317, 0.42194438, 0.41158468,
+     0.40096927, 0.3933149, 0.38795966, 0.38441542, 0.38454026,
+     0.3815766, 0.3768835, 0.3719921, 0.3654539, 0.35399568,
+     0.3425986, 0.32823247, 0.31404305, 0.30564603, 0.29617435,
+     0.29273877, 0.28560263, 0.27459458, 0.26876706, 0.25825337,
+     0.24759005, 0.24090728, 0.2344712, 0.22529823, 0.20880115,
+     0.193578, 0.18290243, 0.17621627, 0.17087021, 0.16641389,
+     0.15932252, 0.14312662, 0.11790597, 0.08030523, 0.03747071,
+ ])
dam3.1.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfa897e1b990de9377b2fb805b526a36fe6f31de01bbc8c3d288d317df2c4b0c
+ size 736180146
featex.py ADDED
@@ -0,0 +1,119 @@
+ """Preprocessing and normalization to prepare audio for Kintsugi Depression and Anxiety model."""
+ import os
+ from typing import Union, BinaryIO
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from transformers import AutoFeatureExtractor
+
+ from config import EXPECTED_SAMPLE_RATE, logmel_energies
+
+
+ def load_audio(source: Union[BinaryIO, str, os.PathLike]) -> torch.Tensor:
+     """Load audio file, verify mono channel count, and resample if necessary.
+
+     Parameters
+     ----------
+     source: open file or path to file
+
+     Returns
+     -------
+     Time domain audio samples as a 1 x num_samples float tensor sampled at 16 kHz.
+
+     """
+     audio, fs = torchaudio.load(source)
+     if audio.shape[0] != 1:
+         raise ValueError(f"Provided audio has {audio.shape[0]} != 1 channels.")
+     if fs != EXPECTED_SAMPLE_RATE:
+         audio = torchaudio.functional.resample(audio, fs, EXPECTED_SAMPLE_RATE)
+     return audio
+
+
+ class Preprocessor:
+     def __init__(self,
+                  normalize_features: bool = True,
+                  chunk_seconds: int = 30,
+                  max_overlap_frac: float = 0.0,
+                  pad_last_chunk_to_full: bool = True,
+                  ):
+         """Create preprocessor object.
+
+         Parameters
+         ----------
+         normalize_features: Whether the Whisper preprocessor should normalize features
+         chunk_seconds: Size of model's receptive field in seconds
+         max_overlap_frac: Fraction of each chunk allowed to overlap the previous chunk for inputs longer than chunk_seconds
+         pad_last_chunk_to_full: Whether to pad audio to an integer multiple of chunk_seconds
+
+         """
+         self.preprocessor = AutoFeatureExtractor.from_pretrained("openai/whisper-small.en")
+         self.normalize_features = normalize_features
+         self.chunk_seconds = chunk_seconds
+         self.max_overlap_frac = max_overlap_frac
+         self.pad_last_chunk_to_full = pad_last_chunk_to_full
+
+     def preprocess_with_audio_normalization(
+         self,
+         audio: torch.Tensor,
+     ) -> torch.Tensor:
+         """Run Whisper preprocessor and normalization expected by the model.
+
+         Note: some normalization steps could be avoided, but are included to match
+         the feature extraction used during training.
+
+         Parameters
+         ----------
+         audio: Raw audio samples as a 1 x num_samples float tensor sampled at 16 kHz
+
+         Returns
+         -------
+         Normalized mel filter bank features as a float tensor of shape
+         num_chunks x 80 mel filter bands x 3000 time frames
+
+         """
+         # Remove DC offset and scale amplitude to [-1, 1]
+         audio = torch.squeeze(audio, 0)
+         audio = audio - torch.mean(audio)
+         audio = audio / torch.max(torch.abs(audio))
+
+         chunk_samples = EXPECTED_SAMPLE_RATE * self.chunk_seconds
+
+         if self.pad_last_chunk_to_full:
+             # Pad audio so that the last chunk is not dropped
+             if self.max_overlap_frac > 0:
+                 raise ValueError(
+                     "pad_last_chunk_to_full is only supported for non-overlapping windows"
+                 )
+             num_chunks = np.ceil(len(audio) / chunk_samples)
+             pad_size = int(num_chunks * chunk_samples - len(audio))
+             audio = torch.nn.functional.pad(audio, (0, pad_size))
+
+         overflow_len = len(audio) - chunk_samples
+
+         min_hop_samples = int(
+             (1 - self.max_overlap_frac) * chunk_samples
+         )
+
+         n_windows = 1 + overflow_len // min_hop_samples
+         window_starts = np.linspace(0, overflow_len, max(n_windows, 1)).astype(int)
+
+         features = self.preprocessor(
+             [
+                 audio[start : start + chunk_samples].numpy(force=True)
+                 for start in window_starts
+             ],
+             return_tensors="pt",
+             sampling_rate=EXPECTED_SAMPLE_RATE,
+             do_normalize=self.normalize_features,
+         )
+         for key in ("input_features", "input_values"):
+             if hasattr(features, key):
+                 features = getattr(features, key)
+                 break
+
+         mean_features = torch.mean(features, dim=-1)
+         # features are [batch, n_logmel_bins, n_frames]
+         rescale_factor = logmel_energies.unsqueeze(0) - mean_features
+         rescale_factor = rescale_factor.unsqueeze(2)
+         features += rescale_factor
+         return features
model.py ADDED
@@ -0,0 +1,185 @@
+ from typing import Any, Mapping, Optional
+
+ import torch
+ from peft import LoraConfig, get_peft_model
+ from transformers import WhisperConfig, WhisperModel
+
+
+ class WhisperEncoderBackbone(torch.nn.Module):
+     def __init__(
+         self,
+         model: str = "openai/whisper-small.en",
+         hf_config: Optional[Mapping[str, Any]] = None,
+         lora_params: Optional[Mapping[str, Any]] = None,
+     ):
+         """Whisper encoder model with optional Low-Rank Adaptation.
+
+         Parameters
+         ----------
+         model: Name of WhisperModel whose encoder to load from HuggingFace
+         hf_config: Optional config for HuggingFace model
+         lora_params: Parameters for Low-Rank Adaptation
+
+         """
+         super().__init__()
+         hf_config = hf_config if hf_config is not None else dict()
+         backbone_config = WhisperConfig.from_pretrained(model, **hf_config)
+         self.backbone = (
+             WhisperModel.from_pretrained(
+                 model,
+                 config=backbone_config,
+             )
+             .get_encoder()
+             .train()
+         )
+         if lora_params is not None and len(lora_params) > 0:
+             lora_config = LoraConfig(**lora_params)
+             self.backbone = get_peft_model(self.backbone, lora_config)
+         self.backbone_dim = backbone_config.hidden_size
+
+     def forward(self, whisper_feature_batch):
+         return self.backbone(whisper_feature_batch).last_hidden_state.mean(dim=1)
+
+
+ class SharedLayers(torch.nn.Module):
+     def __init__(self, input_dim: int, proj_dims: list[int]):
+         """Fully connected network with Mish nonlinearities between linear layers. No nonlinearity at input or output.
+
+         Parameters
+         ----------
+         input_dim: Dimension of input features
+         proj_dims: Dimensions of layers to create
+
+         """
+         super().__init__()
+         modules = []
+         for output_dim in proj_dims[:-1]:
+             modules.extend([torch.nn.Linear(input_dim, output_dim), torch.nn.Mish()])
+             input_dim = output_dim
+         modules.append(torch.nn.Linear(input_dim, proj_dims[-1]))
+         self.shared_layers = torch.nn.Sequential(*modules)
+
+     def forward(self, x):
+         return self.shared_layers(x)
+
+
+ class TaskHead(torch.nn.Module):
+     def __init__(self, input_dim: int, proj_dim: int, dropout: float = 0.0):
+         """Fully connected network with one hidden layer, dropout, and a scalar output."""
+         super().__init__()
+
+         self.linear = torch.nn.Linear(input_dim, proj_dim)
+         self.activation = torch.nn.Mish()
+         self.dropout = torch.nn.Dropout(dropout)
+         self.final_layer = torch.nn.Linear(proj_dim, 1, bias=False)
+
+     def forward(self, x):
+         x = self.linear(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         x = self.final_layer(x)
+         return x
+
+
+ class MultitaskHead(torch.nn.Module):
+     def __init__(
+         self,
+         backbone_dim: int,
+         shared_projection_dim: list[int],
+         tasks: Mapping[str, Mapping[str, Any]],
+     ):
+         """Fully connected network with multiple named scalar outputs."""
+         super().__init__()
+
+         # Initialize the shared network and task-specific networks
+         self.shared_layers = SharedLayers(backbone_dim, shared_projection_dim)
+         self.classifier_head = torch.nn.ModuleDict(
+             {
+                 task: TaskHead(shared_projection_dim[-1], **task_config)
+                 for task, task_config in tasks.items()
+             }
+         )
+
+     def forward(self, x):
+         x = self.shared_layers(x)
+         return {task: head(x) for task, head in self.classifier_head.items()}
+
+
+ def average_tensor_in_segments(tensor: torch.Tensor, lengths: list[int] | torch.Tensor):
+     """Average segments of a `tensor` along dimension 0 based on a list of `lengths`.
+
+     For example, with input tensor `t` and `lengths` [1, 3, 2], the output would be
+     [t[0], (t[1] + t[2] + t[3]) / 3, (t[4] + t[5]) / 2]
+
+     Parameters
+     ----------
+     tensor : torch.Tensor
+         The tensor to average
+     lengths : list of ints
+         The lengths of each segment to average in the tensor, in order
+
+     Returns
+     -------
+     torch.Tensor
+         The tensor with relevant segments averaged
+     """
+     if not torch.is_tensor(lengths):
+         lengths = torch.tensor(lengths, device=tensor.device)
+     index = torch.repeat_interleave(
+         torch.arange(len(lengths), device=tensor.device), lengths
+     )
+     out = torch.zeros(
+         lengths.shape + tensor.shape[1:], device=tensor.device, dtype=tensor.dtype
+     )
+     out.index_add_(0, index, tensor)
+     broadcastable_lengths = lengths.view((-1,) + (1,) * (len(out.shape) - 1))
+     return out / broadcastable_lengths
+
+
+ class Classifier(torch.nn.Module):
+     def __init__(
+         self,
+         backbone_configs: Mapping[str, Mapping[str, Any]],
+         classifier_config: Mapping[str, Any],
+         inference_thresholds: Mapping[str, Any],
+         preprocessor_config: Mapping[str, Any],
+     ):
+         """Full Kintsugi Depression and Anxiety model.
+
+         Whisper encoder -> Mean pooling over time -> Layers shared across tasks -> Per-task heads
+
+         Parameters
+         ----------
+         backbone_configs: Per-backbone kwargs for `WhisperEncoderBackbone`
+         classifier_config: Kwargs for `MultitaskHead`
+         inference_thresholds: Per-task score thresholds used by `quantize_scores`
+         preprocessor_config: Kwargs for `featex.Preprocessor`, stored for use by the pipeline
+
+         """
+         super().__init__()
+
+         self.backbone = torch.nn.ModuleDict(
+             {
+                 key: WhisperEncoderBackbone(**backbone_configs[key])
+                 for key in sorted(backbone_configs.keys())
+             }
+         )
+
+         backbone_dim = sum(layer.backbone_dim for layer in self.backbone.values())
+         self.head = MultitaskHead(backbone_dim, **classifier_config)
+         self.inference_thresholds = inference_thresholds
+         self.preprocessor_config = preprocessor_config
+
+     def forward(self, x, lengths):
+         backbone_outputs = {
+             key: average_tensor_in_segments(layer(x), lengths)
+             for key, layer in self.backbone.items()
+         }
+         backbone_output = torch.cat(list(backbone_outputs.values()), dim=1)
+         return self.head(backbone_output), torch.ones_like(lengths)
+
+     def quantize_scores(self, scores: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
+         """Map per-task scores to discrete predictions per `inference_thresholds` config."""
+         return {
+             key: torch.searchsorted(
+                 torch.tensor(self.inference_thresholds[key], device=value.device),
+                 value.mean(),
+                 out_int32=True,
+             )
+             for key, value in scores.items()
+         }
pipeline.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ from pathlib import Path
+ from typing import Any, BinaryIO, Mapping, Optional, Union
+
+ import torch
+
+ from config import default_config
+ from featex import load_audio, Preprocessor
+ from model import Classifier
+
+
+ class Pipeline:
+     def __init__(
+         self,
+         checkpoint: Optional[str | Path] = None,
+         config: Optional[Mapping[str, Any]] = None,
+         device: Optional[torch.device] = None,
+     ):
+         if checkpoint is None:
+             file_dir = Path(__file__).parent.resolve()
+             checkpoint = file_dir / "dam3.1.ckpt"
+         if config is None:
+             config = default_config
+         if device is None:
+             if torch.cuda.is_available():
+                 device = torch.device("cuda:0")
+             else:
+                 device = torch.device("cpu")
+         self.device = device
+         self.model = Classifier(**config)
+         self.preprocessor = Preprocessor(**self.model.preprocessor_config)
+         state_dict = torch.load(checkpoint, map_location=device)
+         self.model.load_state_dict(state_dict)
+         self.model.to(self.device)
+         self.model.eval()
+
+     def run_on_features(self, features: torch.Tensor, quantize: bool = True):
+         scores = self.model(features, torch.tensor([features.shape[0]], device=self.device))[0]
+         if quantize:
+             return {k: int(v.item()) for k, v in self.model.quantize_scores(scores).items()}
+         else:
+             return scores
+
+     def run_on_audio(self, audio: torch.Tensor, quantize: bool = True):
+         features = self.preprocessor.preprocess_with_audio_normalization(audio)
+         return self.run_on_features(features.to(self.device), quantize=quantize)
+
+     def run_on_file(self, source: Union[BinaryIO, str, os.PathLike], quantize: bool = True):
+         audio = load_audio(source)
+         return self.run_on_audio(audio, quantize=quantize)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch~=2.7.0
+ soundfile~=0.13.1
+ torchaudio~=2.7.0
+ transformers~=4.52.3
+ peft~=0.15.2
tuning/__init__.py ADDED
File without changes
tuning/indet_roc.py ADDED
@@ -0,0 +1,416 @@
+ """Tools for tuning pairs of scalar thresholds to trade off sensitivity, specificity, and indeterminate rate.
+
+ See `IndetSnSpArray` and subclass docstrings for details.
+
+ """
+ from dataclasses import asdict, dataclass
+ from typing import NamedTuple, Optional
+ from typing_extensions import Self  # in typing in python3.11
+
+ import numpy as np
+
+
+ def running_argmax_indices(a):
+     """Return indices of a where the value is larger than all previous values.
+
+     >>> running_argmax_indices([1, 0, 3, 4, 4, 2, 5, 7, 1])
+     array([0, 2, 3, 6, 7])
+
+     """
+     m = np.maximum.accumulate(a)
+     return np.flatnonzero(np.r_[True, m[:-1] < m[1:]])
+
+
+ def pareto_2d_indices(x, y):
+     """Compute indices of the Pareto frontier maximizing x and y, sorted in increasing x and decreasing y.
+
+     e.g. the Pareto frontier of the point set below is [A, G]
+
+         B A
+          C
+         E D
+         F  G
+          H
+
+     >>> u = [2, 0, 1, 2, 0, 0, 3, 1]
+     >>> v = [4, 4, 3, 2, 2, 1, 1, 0]
+     >>> pareto_2d_indices(np.array(u), np.array(v))
+     array([0, 6])
+
+     """
+     sort_indices = np.lexsort((-x, -y))  # last element is primary sort key
+     return sort_indices[running_argmax_indices(x[sort_indices])]
+
+
+ def midpoints_with_infs(x):
+     """Return the midpoints between the sorted unique elements of x, along with +/-inf."""
+     unique_scores = np.unique(np.r_[-np.inf, x, np.inf])
+     return (unique_scores[1:] + unique_scores[:-1]) / 2
+
+
+ def kde_disc_mass(
+     data: np.ndarray,
+     edges: np.ndarray,
+     bandwidth: float,
+     weights: Optional[np.ndarray] = None,
+ ):
+     """Perform Kernel Density Estimation (KDE) on data & weights, then compute the probability mass between edges."""
+     import scipy.stats
+
+     z_score = (edges[:, None] - data[None, :]) / bandwidth
+     component_cdfs = scipy.stats.norm.cdf(z_score)
+     if weights is None:
+         weights = np.ones_like(data)
+     cdf = np.dot(component_cdfs, weights / weights.sum())
+     return np.diff(cdf)
+
+
+ class BinaryLabeledScores(NamedTuple):
+     """An array of numeric scores along with associated 0/1 ground truth and optional numeric weights."""
+
+     y_score: np.ndarray
+     y_true: np.ndarray
+     weights: Optional[np.ndarray] = None
+
+     def smooth(
+         self, num_points: int, bandwidth: float, padding_bandwidths: float = 5.0
+     ) -> "BinaryLabeledScores":
+         """KDE-smooth positive and negative scores separately and discretize each to equally spaced weighted points.
+
+         Args:
+             num_points: number of points to use each for positive and negative score discretizations
+             bandwidth: bandwidth of kernel density estimation, i.e. standard deviation of noise to be added
+             padding_bandwidths: number of bandwidths to extend past lowest and highest scores when selecting
+                 discretization endpoints
+
+         Returns:
+             `BinaryLabeledScores` object representing the smoothed and re-discretized weighted labeled scores
+
+         """
+         pos = self.y_true == 1
+         neg = self.y_true == 0
+         if self.weights is not None:
+             pos_weights = self.weights[pos]
+             neg_weights = self.weights[neg]
+         else:
+             pos_weights = None
+             neg_weights = None
+         padding = padding_bandwidths * bandwidth
+         all_points = np.linspace(
+             self.y_score.min() - padding,
+             self.y_score.max() + padding,
+             2 * num_points + 1,
+         )
+         edges = all_points[::2]
+         centers = all_points[1::2]
+         pos_kde_weights = kde_disc_mass(
+             self.y_score[pos], edges, bandwidth, pos_weights
+         )
+         neg_kde_weights = kde_disc_mass(
+             self.y_score[neg], edges, bandwidth, neg_weights
+         )
+         return BinaryLabeledScores(
+             y_true=np.r_[np.zeros_like(centers), np.ones_like(centers)],
+             y_score=np.r_[centers, centers],
+             weights=np.r_[neg_kde_weights, pos_kde_weights],
+         )
+
+     def indet_sn_sp_array(self) -> "IndetSnSpArray":
+         """Build `IndetSnSpArray`."""
+         return IndetSnSpArray.build(**self._asdict())
+
+
+ def fake_vectorized_binom_ci(
+     k: np.ndarray, n: np.ndarray, p: float | np.ndarray = 0.95
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """Compute binomial confidence intervals on arrays of parameters inefficiently."""
+     import scipy.stats
+
+     k, n, p = np.broadcast_arrays(k, n, p)
+     # If speed is needed this can be rewritten with the statsmodels package, which is vectorized.
+     flat_out = [
+         scipy.stats.binomtest(k_, n_).proportion_ci(p_)
+         for k_, n_, p_ in zip(k.flatten(), n.flatten(), p.flatten())
+     ]
+     low = np.array([ci.low for ci in flat_out]).reshape(k.shape)
+     high = np.array([ci.high for ci in flat_out]).reshape(k.shape)
+     return low, high
+
+
+ @dataclass
+ class IndetSnSpArray:
+     """An array of metrics at different lower and upper threshold values.
+
+     This class and subclasses are for selecting pairs of model thresholds based on sensitivity,
+     specificity, and indeterminate rate. Throughout we assume scores are scalars and ground truth is binary.
+
+     Each member `lower_thresh`, `upper_thresh`, `sn`, `sp`, and `indet_frac` must be a numpy array, and they all must
+     have the same shape. Corresponding entries of these arrays specify a pair of thresholds and the metrics when a
+     common dataset is evaluated using those thresholds. The thresholding logic is that scores less than the lower
+     threshold count as negative outputs, scores greater than or equal to the upper threshold count as positive outputs,
+     and scores in between are indeterminate outputs.
+
+     Indeterminate fraction is defined as the proportion of scores in between the two thresholds. All other
+     metrics are interpreted as conditioned on the scores not being indeterminate. For example, sensitivity
+     is defined as usual as (true positives) / (total positives) *except that examples with indeterminate
+     scores do not count towards the numerator or the denominator*.
+
+     """
+
+     lower_thresh: np.ndarray
+     upper_thresh: np.ndarray
+     tp: np.ndarray
+     fp: np.ndarray
+     tn: np.ndarray
+     fn: np.ndarray
+     indet: np.ndarray
+     weighted: bool = False
+     min_weight: float = 1.0
+     eps: float = 1e-8
+
+     def __post_init__(self):
+         self.sn = self.tp / np.maximum(self.tp + self.fn, self.min_weight)
+         self.sp = self.tn / np.maximum(self.tn + self.fp, self.min_weight)
+         self.ppv = self.tp / np.maximum(self.tp + self.fp, self.min_weight)
+         self.npv = self.tn / np.maximum(self.tn + self.fn, self.min_weight)
+         total = self.indet + self.fn + self.fp + self.tn + self.tp
+         self.indet_frac = self.indet / np.maximum(total, self.min_weight)
+         for attr in ("sn", "sp", "ppv", "npv", "indet_frac"):
+             value = getattr(self, attr)
+             if value.size:
+                 if value.max() > 1 + self.eps:
+                     raise ValueError(
+                         f"Numerical precision issues produced invalid value {attr} = {value.max()}."
+                     )
+                 if value.min() < -self.eps:
+                     raise ValueError(
+                         f"Numerical precision issues produced invalid value {attr} = {value.min()}."
+                     )
+             setattr(self, attr, np.clip(value, 0.0, 1.0))
+
+     @property
+     def min_sn_sp(self):
+         return np.minimum(self.sn, self.sp)
+
+     @classmethod
+     def build(
+         cls,
+         lower_thresh: Optional[np.ndarray] = None,
+         upper_thresh: Optional[np.ndarray] = None,
+         *,
+         y_true: np.ndarray,
+         y_score: np.ndarray,
+         weights: Optional[np.ndarray] = None,
+         eps: float = 1e-8,
+     ) -> Self:
+         """Find `IndetSnSpArray` values for given truth and scores as thresholds vary (à la sklearn.metrics.roc_curve).
+
+         The output object contains arrays for `sn`, `sp`, `indet_frac`, `lower_thresh`, and `upper_thresh`, all with
+         the same shape. What these arrays contain and what their common shape is depends on the input as follows.
+
+         If both lower_thresh and upper_thresh are provided, they must have the same shape and this method computes
+         metrics at the pairs given by corresponding entries in these arrays. The common output shape will be the same
+         as this common input shape.
+
+         If only one set of thresholds is provided, this method computes metrics at all sorted pairs of these
+         thresholds (along with +/- inf). If neither is provided, sort scores and allow thresholds between each pair
+         (along with +/- inf). In both of these cases, the common output shape is a 1-d vector of length equal to the
+         number of such pairs.
+
+         """
+         weights = weights if weights is not None else np.ones_like(y_true)
+         y_true = y_true[weights > 0]
+         y_score = y_score[weights > 0]
+         weights = weights[weights > 0]
+
+         # Find all threshes and include +/- inf so np.histogram does the right thing
+         if lower_thresh is not None and upper_thresh is not None:
+             threshes = np.unique(np.r_[-np.inf, lower_thresh, upper_thresh, np.inf])
+             lower_indices = np.searchsorted(threshes, lower_thresh)
+             upper_indices = np.searchsorted(threshes, upper_thresh)
+         else:
+             if lower_thresh is not None:
+                 threshes = np.unique(np.r_[-np.inf, lower_thresh, np.inf])
+             elif upper_thresh is not None:
+                 threshes = np.unique(np.r_[-np.inf, upper_thresh, np.inf])
+             else:
+                 unique_scores = np.unique(np.r_[-np.inf, y_score, np.inf])
+                 threshes = (unique_scores[1:] + unique_scores[:-1]) / 2
+                 threshes = np.unique(threshes)
+             lower_indices, upper_indices = np.triu_indices(len(threshes))
+
+         count_by_bin = np.histogram(y_score, bins=threshes, weights=weights)[0]
+         pos_by_bin = np.histogram(y_score, bins=threshes, weights=y_true * weights)[0]
+         count_by_thresh = np.pad(np.cumsum(count_by_bin), (1, 0))
+         pos_by_thresh = np.pad(np.cumsum(pos_by_bin), (1, 0))
+         tn_plus_fn = count_by_thresh[lower_indices]
+         total_minus_tp_minus_fp = count_by_thresh[upper_indices]
+         tp_plus_fp = count_by_thresh[-1] - total_minus_tp_minus_fp
+         fn = pos_by_thresh[lower_indices]
+         total_pos = pos_by_thresh[-1]  # last thresh is +inf
+         tp = total_pos - pos_by_thresh[upper_indices]
+         fp = tp_plus_fp - tp
+         tn = tn_plus_fn - fn
+         min_weight = weights.min()
+         indet = total_minus_tp_minus_fp - tn_plus_fn
+         return cls(
+             lower_thresh=threshes[lower_indices],
+             upper_thresh=threshes[upper_indices],
+             tp=tp,
+             fp=fp,
+             tn=tn,
+             fn=fn,
+             indet=indet,
+             weighted=not all(weights == 1.0),
+             min_weight=min_weight,
+             eps=eps,
+         )
+
+     def eval(
+         self,
+         *,
+         y_true,
+         y_score,
+         weights: Optional[np.ndarray] = None,
+     ) -> "IndetSnSpArray":
+         """Evaluate the given data on the thresholds of `self`."""
+         return IndetSnSpArray.build(
+             lower_thresh=self.lower_thresh,
+             upper_thresh=self.upper_thresh,
+             y_true=y_true,
+             y_score=y_score,
+             weights=weights,
+         )
+
+     def __getitem__(self, item) -> "IndetSnSpArray":
+         """Extract a subarray with numpy-style indexing."""
+         return IndetSnSpArray(
+             lower_thresh=self.lower_thresh[item],
+             upper_thresh=self.upper_thresh[item],
+             tp=self.tp[item],
+             fp=self.fp[item],
+             fn=self.fn[item],
+             tn=self.tn[item],
+             indet=self.indet[item],
+             weighted=self.weighted,
+             min_weight=self.min_weight,
+             eps=self.eps,
+         )
+
+     def __add__(self, other: "IndetSnSpArray") -> "IndetSnSpArray":
+         if not isinstance(other, IndetSnSpArray):
+             raise TypeError(f"Cannot add {type(other)} to IndetSnSpArray.")
+         tp = self.tp + other.tp
+         if np.array_equal(self.lower_thresh, other.lower_thresh):
+             lower_thresh = self.lower_thresh
+         else:
+             lower_thresh = np.nan * np.ones_like(tp)
+         if np.array_equal(self.upper_thresh, other.upper_thresh):
+             upper_thresh = self.upper_thresh
+         else:
+             upper_thresh = np.nan * np.ones_like(tp)
+         return IndetSnSpArray(
+             lower_thresh=lower_thresh,
+             upper_thresh=upper_thresh,
+             tp=tp,
+             fp=self.fp + other.fp,
+             fn=self.fn + other.fn,
+             tn=self.tn + other.tn,
+             indet=self.indet + other.indet,
+             weighted=self.weighted or other.weighted,
+             min_weight=min(self.min_weight, other.min_weight),
+             eps=self.eps,
+         )
+
+     def confidence_interval_bound(self, p: float = 0.95) -> "IndetSnSpArray":
+         """Compute two-sided confidence interval bounds: upper for indet_frac and lower for other metrics."""
+         if self.weighted:
+             raise NotImplementedError(
+                 "Confidence intervals only implemented for unweighted confusion matrices."
+             )
+         copy = IndetSnSpArray(**asdict(self))
+         copy.sn, _ = fake_vectorized_binom_ci(self.tp, self.tp + self.fn, p=p)
+         copy.sp, _ = fake_vectorized_binom_ci(self.tn, self.tn + self.fp, p=p)
+         copy.ppv, _ = fake_vectorized_binom_ci(self.tp, self.fp + self.tp, p=p)
+         copy.npv, _ = fake_vectorized_binom_ci(self.tn, self.tn + self.fn, p=p)
+         _, copy.indet_frac = fake_vectorized_binom_ci(
+             self.indet, self.tp + self.fn + self.fp + self.tn + self.indet, p=p
+         )
+         return copy
+
+     def roc_curve(self, indet_budget=0.0) -> "IndetRocCurve":
+         """Compute ROC curve with indeterminate budget, sorted by increasing sn and decreasing sp.
+
+         Restrict `self` to Pareto-optimal pairs (sn, sp) for which `indet_frac <= indet_budget`. Other points are
+         worse than the points on the curve in the sense of having worse Sn, worse Sp, or not meeting the
+         indeterminate budget.
+
+         """
+         within_budget = self[self.indet_frac <= indet_budget]
+         frontier = pareto_2d_indices(within_budget.sn, within_budget.sp)
+         return IndetRocCurve(**asdict(within_budget[frontier]))
+
+     def sn_eq_sp_graph(self) -> "IndetSnEqSpGraph":
+         """Compute sn=sp as a function of indet_frac, returning both sorted in increasing order.
+
+         Method: restrict to Pareto-optimal pairs (s, indet_frac) where s = min(sn, sp).
+
+         Pareto-optimality means that if (s, indet_frac) is in the output, there is no point (sn', sp', indet_frac')
+         in the input with indet_frac' <= indet_frac and sn', sp' > s. In other words, s is the maximum value such
+         that the quadrant {sn, sp >= s} intersects `self.roc_curve(indet_frac)`. This maximum occurs where the ROC
+         curve intersects the diagonal, up to an error bounded by the distance between points on the ROC curve.
+
+         """
+         frontier = pareto_2d_indices(self.min_sn_sp, -self.indet_frac)
+         return IndetSnEqSpGraph(**asdict(self[frontier]))
+
+
+ class IndetRocCurve(IndetSnSpArray):
+     """Sn, Sp achievable within some indeterminate budget and associated lower and upper thresholds.
+
+     `sn` is assumed to be sorted in increasing order and `sp` decreasing.
+
+     """
+
+     def sn_eq_sp(self) -> IndetSnSpArray:
+         """Locate the point on the ROC curve closest to the diagonal."""
+         return self[np.argmax(self.min_sn_sp)]
+
+     def auc(self) -> float:
+         """Compute the area under the ROC curve."""
+         # `auc` does not automatically include the trivial points (0, 1) and (1, 0)
+         # and will underestimate the AUC if these are not explicitly added
+         from sklearn.metrics import auc
+
+         return auc(1 - np.r_[1.0, self.sp, 0.0], np.r_[0.0, self.sn, 1.0])
+
+     @classmethod
+     def build(
+         cls,
+         thresh=None,
+         *,
+         y_true: np.ndarray,
+         y_score: np.ndarray,
+         weights: Optional[np.ndarray] = None,
+     ) -> Self:
+         """Build an indeterminate=0 ROC curve in n log n time (vs n**2 for IndetSnSpArray.build().roc_curve())."""
+         if thresh is None:
+             thresh = midpoints_with_infs(y_score)[::-1]  # reverse for proper output sorting
+         issa = IndetSnSpArray.build(
+             lower_thresh=thresh,
+             upper_thresh=thresh,
+             y_true=y_true,
+             y_score=y_score,
+             weights=weights,
+         )
+         return cls(**asdict(issa))
+
+
+ class IndetSnEqSpGraph(IndetSnSpArray):
+     """Sn=Sp achievable as a function of indeterminate budget and associated lower and upper thresholds.
+
+     Both min(self.sn, self.sp) and self.indet_frac are assumed to be sorted in non-decreasing order.
+
+     """
+
+     def at_budget(self, indet_budget: float = 0.0) -> IndetSnSpArray:
+         """Locate the best point on the graph within the given budget."""
+         return self[np.searchsorted(self.indet_frac, indet_budget, side="right") - 1]
tuning/optimal_ordinal.py ADDED
@@ -0,0 +1,510 @@
1
+ """Tools for choosing multiple thresholds optimally under various decision criteria via dynamic programming.
2
+
3
+ The abstract machinery is contained in the classes:
4
+ - `OrdinalThresholding`
5
+ - `OptimalOrdinalThresholdingViaDynamicProgramming`
6
+ - `OptimalCostPerSampleOrdinalThresholding`
7
+ - `ClassWeightedOptimalCostPerSampleOrdinalThresholding`
8
+ - `OptimalCostPerClassOrdinalThresholding`
9
+ These can be subclassed to efficiently implement new decision criteria, depending on their structure.
10
+
11
+ The main intended user-facing classes are the subclasses implementing different decision criteria:
12
+ - `MaxAccuracyOrdinalThresholding`
13
+ - `MaxMacroRecallOrdinalThresholding`
14
+ - `MinAbsoluteErrorOrdinalThresholding`
15
+ - `MaxMacroPrecisionOrdinalThresholding`
16
+ - `MaxMacroF1OrdinalThresholding`
17
+
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+ from typing import Literal, Optional, Union
22
+
23
+ import torch
24
+
25
+
26
+ class OrdinalThresholding(torch.nn.Module):
27
+ """Basic 1d thresholding logic."""
28
+
29
+ def __init__(self, num_classes: int):
30
+ """Init thresholding module with the specified number of classes (one more than the number of thresholds)."""
31
+ super().__init__()
32
+ self.num_classes = num_classes
33
+ self.register_buffer("thresholds", torch.zeros(num_classes - 1))
34
+ self.thresholds: torch.Tensor
35
+
36
+ def is_valid(self) -> bool:
37
+ """Check whether the thresholds are monotone non-decreasing."""
38
+ return all(torch.greater_equal(self.thresholds[1:], self.thresholds[:-1]))
39
+
40
+ def forward(self, scores) -> torch.Tensor:
41
+ """Find which thresholds each score lies between."""
42
+ return torch.searchsorted(self.thresholds, scores)
43
+
44
+ def tune_thresholds(
45
+ self,
46
+ *,
47
+ scores: torch.Tensor,
48
+ labels: torch.Tensor,
49
+ available_thresholds: Optional[torch.Tensor] = None,
50
+ ) -> torch.Tensor:
51
+ """Adapt the thresholds to the given data.
52
+
53
+ This is essentially an abstract method, but for testing purposes it's helpful to be able to instantiate the
54
+ class with a no-op version.
55
+
56
+ Parameters
57
+ ----------
58
+ scores : a vector of `float` scores for each example in the validation set
59
+ labels : a vector of `int` labels having the same shape as `scores` containing the corresponding labels
60
+ available_thresholds : a vector of `float` score values over which to optimize choice of thresholds;
61
+ `None`, then thresholds between every score in the validation set are allowed. +/- inf are always allowed.
62
+
63
+ Returns
64
+ -------
65
+ scalar `float` mean cost on the validation set using optimal thresholds
66
+
67
+ """
68
+
69
+
70
+ class OptimalOrdinalThresholdingViaDynamicProgramming(OrdinalThresholding, ABC):
71
+ """Super-class for general dynamic programming implementations of ordinal threshold tuning.
72
+
73
+ Subclasses implement different ways of computing the mean cost and corresponding DP step.
74
+
75
+ """
76
+
77
+ direction: Literal["min", "max"] # provided by subclasses
78
+
79
+ def __init__(self, num_classes: int):
80
+ super().__init__(num_classes=num_classes)
81
+ if self.direction not in ("min", "max"):
82
+ raise ValueError(
83
+ f"Got direction {self.direction!r}, expected 'min' or 'max'."
84
+ )
85
+
86
+ @abstractmethod
87
+ def mean_cost(
88
+ self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]
89
+ ) -> torch.Tensor:
90
+ """Compute the mean cost of assigning label(s) `preds` when the ground truth is `labels`."""
91
+
92
+ def best_constant_output_classifier(self, labels: torch.Tensor):
93
+ """Find the optimal mean cost of a constant-output classifier for given `labels` and the associated constant."""
94
+ if self.direction == "min":
95
+ optimize = torch.min
96
+ else:
97
+ optimize = torch.max
98
+ optimum = optimize(
99
+ torch.tensor(
100
+ [
101
+ self.mean_cost(labels=labels, preds=c)
102
+ for c in range(self.num_classes)
103
+ ],
104
+ device=labels.device,
105
+ ),
106
+ 0,
107
+ )
108
+ return optimum.values, optimum.indices
109
+
110
+ @abstractmethod
111
+ def dp_step(
112
+ self,
113
+ c_idx: int,
114
+ *,
115
+ scores: torch.Tensor,
116
+ labels: torch.Tensor,
117
+ available_thresholds: torch.Tensor,
118
+ prev_cost: Optional[torch.Tensor] = None,
119
+ ) -> (torch.Tensor, Optional[torch.Tensor]):
120
+ """Given optimal cost `prev_cost` of classes < `c_idx`, optimize cost of `c_idx` as a function of threshold.
121
+
122
+ Arguments
123
+ ---------
124
+ c_idx : current class index
125
+ scores, labels, available_thresholds : see `tune_thresholds`
126
+ prev_cost (optional float tensor) : optimal cost of classes < `c_idx` as a function of upper threshold
127
+ for class `c_idx - 1`; ignored if `c_idx == 0`
128
+
129
+ Returns
130
+ -------
131
+ cost: `cost[i]` is for choosing upper threshold of class `c_idx` equal to `available_thresholds[i]`
132
+ when thresholds for lower classes are chosen optimally
133
+ indices : to achieve `cost[i]`, optimal upper threshold for class `c_idx - 1` is
134
+ `available_thresholds[indices[i]]`; `None` if `c_idx == 0`
135
+
136
+ """
137
+
138
+ def tune_thresholds(
139
+ self,
140
+ *,
141
+ scores: torch.Tensor,
142
+ labels: torch.Tensor,
143
+ available_thresholds: Optional[torch.Tensor] = None,
144
+ ) -> torch.Tensor:
145
+ """Set `self.thresholds` to optimize mean cost of given `scores` and `labels`.
146
+
147
+ Arguments
148
+ ---------
149
+ scores (1d float tensor) : scores of examples on tuning dataset
150
+ labels (1d int tensor) : labels in {0, ..., self.num_classes - 1} of same shape as scores
151
+ available_thresholds (optional 1d float tensor) : thresholds which will be considered when tuning.
152
+ +/-inf will be added automatically to ensure all examples are classified. If omitted, will
153
+ insert thresholds between each element of sorted(unique(scores)).
154
+
155
+ Returns
156
+ -------
157
+ float tensor : optimal mean cost achieved on the provided dataset at the tuned `self.thresholds`
158
+
159
+ """
+        inf = torch.tensor([torch.inf], device=scores.device)
+        if available_thresholds is None:  # use all possible thresholds
+            unique_scores = torch.unique(scores)
+            available_thresholds = (unique_scores[:-1] + unique_scores[1:]) / 2.0
+        # Always allow some classes to be omitted entirely by setting thresholds to +/- inf.
+        # This simplifies the algorithm and also guarantees that the baseline constant-output
+        # classifiers are feasible choices for tuning, which ensures that the optimum is at
+        # least as good as a constant-output classifier.
+        available_thresholds = torch.concatenate(
+            [
+                -inf,
+                available_thresholds,
+                inf,
+            ]
+        )
+        indices = torch.empty(
+            (self.num_classes - 2, len(available_thresholds)),
+            dtype=torch.int,
+            device=scores.device,
+        )
+
+        # cost[j] = optimal total cost of items assigned pred <= c when the threshold between
+        # class c and c+1 is available_thresholds[j] (optimizing over the lower thresholds).
+        cost, _ = self.dp_step(
+            c_idx=0,
+            scores=scores,
+            labels=labels,
+            available_thresholds=available_thresholds,
+        )
+        for c in range(1, self.num_classes - 1):
+            cost, indices[c - 1, :] = self.dp_step(
+                c_idx=c,
+                scores=scores,
+                labels=labels,
+                available_thresholds=available_thresholds,
+                prev_cost=cost,
+            )
+        cost, best_index = self.dp_step(
+            c_idx=self.num_classes - 1,
+            scores=scores,
+            labels=labels,
+            available_thresholds=available_thresholds,
+            prev_cost=cost,
+        )
+        if self.direction == "min":
+            cost *= -1
+
+        # Follow the DP path backwards to find the thresholds which optimized the cost
+        self.thresholds[self.num_classes - 2] = available_thresholds[
+            best_index
+        ]  # final threshold
+        for c in range(self.num_classes - 3, -1, -1):  # counting down to zero
+            best_index = indices[c, best_index.long()]
+            self.thresholds[c] = available_thresholds[best_index.long()]
+
+        return cost
+
+
+def cumsum_with_0(t: torch.Tensor):
+    """Cumulative sum of `t` with a leading zero, so output[k] = sum of the first k items."""
+    return torch.nn.functional.pad(torch.cumsum(t, dim=0), (1, 0))
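+
+
+# For example, cumsum_with_0(torch.tensor([1.0, 2.0, 3.0])) returns tensor([0., 1., 3., 6.]);
+# the leading zero lets prefix differences express bin sums such as t[i:j].sum() directly.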
+
+
+class OptimalCostPerSampleOrdinalThresholding(
+    OptimalOrdinalThresholdingViaDynamicProgramming, ABC
+):
+    """Optimal 1d thresholding based on tuning thresholds to optimize the mean of a sample-wise cost function."""
+
+    @abstractmethod
+    def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
+        """Compute the sample-wise cost of assigning label(s) `preds` when the ground truth is `labels`."""
+
+    def mean_cost(
+        self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]
+    ) -> torch.Tensor:
+        """Compute the mean cost of assigning label(s) `preds` when the ground truth is `labels`."""
+        return torch.mean(self.cost(labels=labels, preds=preds))
+
+    def dp_step(
+        self,
+        c_idx: int,
+        *,
+        scores: torch.Tensor,
+        labels: torch.Tensor,
+        available_thresholds: torch.Tensor,
+        prev_cost: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """O(len(scores)) implementation for per-sample cost."""
+        # Compute running_cost[i] = sum of costs of elements with score less than
+        # available_thresholds[i] if assigned label c_idx
+        item_costs = self.cost(labels=labels, preds=c_idx) / len(scores)
+        if self.direction == "min":
+            item_costs *= -1
+        # move tensors to and from CPU because histogram has no CUDA implementation
+        cost_new_class_by_thresh, _ = torch.histogram(
+            scores.cpu().float(),
+            weight=item_costs.cpu().float(),
+            bins=available_thresholds.cpu().float(),
+        )
+        running_cost = cumsum_with_0(cost_new_class_by_thresh.to(labels.device))
+
+        # Combine running_cost with prev_cost
+        if c_idx == 0:
+            return running_cost, None
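+        # If class c_idx covers scores in [available_thresholds[i], available_thresholds[j]),
+        # its marginal cost is running_cost[j] - running_cost[i], so
+        #     cost[j] = max_{i <= j} (prev_cost[i] + running_cost[j] - running_cost[i])
+        #             = running_cost[j] + max_{i <= j} (prev_cost[i] - running_cost[i]),
+        # which torch.cummax evaluates for all j in one linear pass.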
+        diff = prev_cost - running_cost
+        cummax = torch.cummax(diff, dim=0)
+        cost = running_cost + cummax.values
+        if c_idx == self.num_classes - 1:
+            # [-1] always sets the *upper* threshold of class `num_classes - 1` to +inf
+            # so that it includes the rest of the data
+            return cost[-1], cummax.indices[-1]
+        return cost, cummax.indices
+
+
+class MaxAccuracyOrdinalThresholding(OptimalCostPerSampleOrdinalThresholding):
+    """Threshold to maximize accuracy."""
+
+    direction = "max"
+
+    def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
+        return torch.eq(labels, preds).float()
+
+
+class MaxMacroRecallOrdinalThresholding(OptimalCostPerSampleOrdinalThresholding):
+    """Threshold to maximize macro-averaged recall."""
+
+    direction = "max"
+
+    def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
+        counts = torch.bincount(labels, minlength=self.num_classes).float()
+        ratios = counts.sum() / (self.num_classes * counts)
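+        # e.g. with counts = [10., 30., 60.] and num_classes = 3, ratios = [10/3, 10/9, 5/9]:
+        # correct predictions on rarer classes weigh more, so the sample mean of this
+        # cost equals macro-averaged recall.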
+        return torch.eq(labels, preds).float() * torch.gather(
+            ratios, 0, labels.type(torch.int64)
+        )
+
+
+class MinAbsoluteErrorOrdinalThresholding(OptimalCostPerSampleOrdinalThresholding):
+    """Threshold to minimize mean absolute error."""
+
+    direction = "min"
+
+    def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
+        return torch.abs(preds - labels).float()
300
+
301
+
302
+ class ClassWeightedOptimalCostPerSampleOrdinalThresholding(
303
+ OptimalCostPerSampleOrdinalThresholding
304
+ ):
305
+ """Compute cost weighted equally over classes instead of equally over samples.
306
+
307
+ This class takes another instance of OptimalCostPerSampleOrdinalThresholding
308
+ which computes its cost independently for each sample and reweights the cost
309
+ based on label frequencies.
310
+
311
+ Note: this class depends on an implementation detail of its superclass:
312
+ namely calling `self.cost` with the full tuning or eval set of labels,
313
+ rather than a single label. This is required to do the re-weighting properly.
314
+
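+    Example (sketch, assuming the base constructor takes `num_classes`)::
+
+        tuner = ClassWeightedOptimalCostPerSampleOrdinalThresholding(
+            MinAbsoluteErrorOrdinalThresholding(4)
+        )
+        # tunes thresholds for a class-balanced mean absolute error
+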
+    """
+
+    def __init__(self, unweighted_instance: OptimalCostPerSampleOrdinalThresholding):
+        self.direction = unweighted_instance.direction
+        super().__init__(unweighted_instance.num_classes)
+        self.unweighted_instance = unweighted_instance
+
+    def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
+        counts = torch.bincount(labels, minlength=self.num_classes)
+        (indices,) = torch.where(counts == 0)
+        if len(indices) > 0:
+            raise ValueError(
+                f"Cannot compute class-weighted cost because classes {set(indices.tolist())} are missing."
+            )
+        unweighted_cost = self.unweighted_instance.cost(labels=labels, preds=preds)
+        weights = len(labels) / (self.num_classes * counts[labels].float())
+        return weights * unweighted_cost
+
+
+class OptimalCostPerClassOrdinalThresholding(
+    OptimalOrdinalThresholdingViaDynamicProgramming, ABC
+):
+    """General DP case for when the linear algorithm for per-sample costs is not applicable.
+
+    Complexity depends on the implementation of `cost_matrix`.
+
+    """
+
+    @abstractmethod
+    def cost_matrix(
+        self,
+        c_idx: int,
+        *,
+        scores: torch.Tensor,
+        labels: torch.Tensor,
+        available_thresholds: torch.Tensor,
+        start: bool,
+        end: bool,
+    ) -> torch.Tensor:
+        """Each output[i, j] = cost for when scores in range `available_thresholds[i:j]` are assigned label `c_idx`."""
+
+    def mean_cost(
+        self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]
+    ) -> torch.Tensor:
+        """Compute the mean (macro-averaged over classes) cost of assigning label(s) `preds`."""
+
+        if isinstance(preds, int) or preds.numel() == 1:
+            preds = preds * torch.ones_like(labels, dtype=torch.int)
+
+        total_cost = torch.tensor(0.0, device=labels.device)
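+        # Reuse cost_matrix by treating `preds` as the "scores": the single bin
+        # [c_idx - 0.5, c_idx + 0.5] selects exactly the samples with pred == c_idx.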
+        for c_idx in range(self.num_classes):
+            thresholds = torch.tensor([c_idx - 0.5, c_idx + 0.5], device=labels.device)
+            total_cost += self.cost_matrix(
+                c_idx,
+                scores=preds.float(),
+                labels=labels,
+                available_thresholds=thresholds,
+                start=True,
+                end=True,
+            )[0, 0]
+        return total_cost / self.num_classes
+
+    def dp_step(
+        self,
+        c_idx: int,
+        *,
+        scores: torch.Tensor,
+        labels: torch.Tensor,
+        available_thresholds: torch.Tensor,
+        prev_cost: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        cost_matrix = (
+            self.cost_matrix(
+                c_idx,
+                scores=scores,
+                labels=labels,
+                available_thresholds=available_thresholds,
+                start=c_idx == 0,
+                end=c_idx == self.num_classes - 1,
+            )
+            / self.num_classes
+        )
+        if self.direction == "min":
+            cost_matrix *= -1
+        if prev_cost is not None:
+            cost_matrix += prev_cost[:, None]
+        max_ = torch.max(cost_matrix, dim=0)
+        return max_.values, max_.indices
+
+
+def _compute_metrics_matrices(
+    scores: torch.Tensor,
+    binary_labels: torch.Tensor,
+    thresholds: torch.Tensor,
+    start: bool = False,
+    end: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Each output[i, j] = stats for when scores between thresholds[i] and thresholds[j] are assigned `True`.
+
+    Helper function for `MaxMacroPrecisionOrdinalThresholding` and `MaxMacroF1OrdinalThresholding`.
+
+    Computed in O(len(thresholds)**2 + len(scores)*log(len(thresholds))) operations instead of the naive
+    O(len(scores)*len(thresholds)**2) operations needed to compute each element of the output independently.
+
+    Arguments
+    ---------
+    scores (float Tensor) : scores of labeled examples for which to compute metrics
+    binary_labels (bool Tensor) : corresponding binary labels of the same shape as `scores`
+    thresholds (float Tensor) : thresholds between which to compute metrics
+    start : compute only the first row of the output (lower threshold at its minimum value)
+    end : compute only the last column of the output (upper threshold at its maximum value)
+
+    Returns
+    -------
+    tp : tp[i, j] = number of true positives if scores between thresholds[i:j] are classified as positive
+    tp_plus_fp : tp_plus_fp[i, j] = number of scores between thresholds[i:j]
+
+    """
+    # move tensors to and from CPU because histogram has no CUDA implementation
+    scores = scores.float().cpu()
+    thresholds = thresholds.float().cpu()
+    labeled_true_by_thresh, _ = torch.histogram(
+        scores,
+        weight=binary_labels.float().cpu(),
+        bins=thresholds,
+    )
+    count_by_thresh, _ = torch.histogram(
+        scores,
+        bins=thresholds,
+    )
+    running_labeled_true_by_thresh = cumsum_with_0(
+        labeled_true_by_thresh.to(binary_labels.device)
+    )
+    running_count_by_thresh = cumsum_with_0(
+        count_by_thresh.to(binary_labels.device).float()
+    )
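+
+    # Each entry is a difference of prefix sums: tp[i, j] = prefix[j] - prefix[i],
+    # computed for all (i, j) at once by broadcasting; `start`/`end` restrict the
+    # output to its first row / last column when only those are needed.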
+
+    def start_slice(t):
+        return t[: (1 if start else None), None]
+
+    def end_slice(t):
+        return t[None, (-1 if end else None) :]
+
+    tp = end_slice(running_labeled_true_by_thresh) - start_slice(
+        running_labeled_true_by_thresh
+    )
+    tp_plus_fp = end_slice(running_count_by_thresh) - start_slice(
+        running_count_by_thresh
+    )
+    return tp, tp_plus_fp
+
+
+class MaxMacroPrecisionOrdinalThresholding(OptimalCostPerClassOrdinalThresholding):
+    """Threshold to maximize macro-averaged precision."""
+
+    direction = "max"
+
+    def cost_matrix(
+        self,
+        c_idx: int,
+        *,
+        scores: torch.Tensor,
+        labels: torch.Tensor,
+        available_thresholds: torch.Tensor,
+        start: bool,
+        end: bool,
+    ) -> torch.Tensor:
+        tp, tp_plus_fp = _compute_metrics_matrices(
+            scores, torch.eq(labels, c_idx), available_thresholds, start=start, end=end
+        )
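+        # tp_plus_fp is a count, so the torch.ge condition below always holds; an empty
+        # prediction range therefore scores 0 (tp = 0 over the clamped denominator),
+        # which keeps constant-output baselines feasible rather than -inf.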
+        safe_tp_plus_fp = torch.maximum(
+            tp_plus_fp, torch.ones(1, device=tp_plus_fp.device)
+        )
+        return torch.where(torch.ge(tp_plus_fp, 0.0), tp / safe_tp_plus_fp, -torch.inf)
483
+
484
+
485
+ class MaxMacroF1OrdinalThresholding(OptimalCostPerClassOrdinalThresholding):
486
+ """Threshold to maximize macro-averaged F1 score."""
487
+
488
+ direction = "max"
489
+
490
+ def cost_matrix(
491
+ self,
492
+ c_idx: int,
493
+ scores: torch.Tensor,
494
+ labels: torch.Tensor,
495
+ available_thresholds: torch.Tensor,
496
+ start: bool,
497
+ end: bool,
498
+ ) -> torch.Tensor:
499
+ tp, tp_plus_fp = _compute_metrics_matrices(
500
+ scores, torch.eq(labels, c_idx), available_thresholds, start=start, end=end
501
+ )
502
+ tp_plus_fn = torch.eq(labels, c_idx).float().sum() # scalar
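+        # when tp_plus_fp >= 1 the clamp below is a no-op; when tp_plus_fp == 0, tp is
+        # also 0, so clamping only avoids 0/0 for a class absent from labels and preds.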
+        safe_tp_plus_fp = torch.maximum(
+            tp_plus_fp, torch.ones(1, device=tp_plus_fp.device)
+        )
+        return torch.where(
+            torch.ge(tp_plus_fp, 0.0),
+            2 * tp / (safe_tp_plus_fp + tp_plus_fn),
+            -torch.inf,
+        )
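+
+
+# Minimal end-to-end usage sketch (illustrative only; assumes the base constructor
+# takes `num_classes`, consistent with the __init__ call above):
+#
+#     scores = torch.randn(1000)
+#     labels = torch.randint(0, 5, (1000,))
+#     tuner = MaxMacroF1OrdinalThresholding(5)
+#     best = tuner.tune_thresholds(scores=scores, labels=labels)
+#     print(best, tuner.thresholds)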