Upload 10 files
Browse files- README.md +218 -3
- config.py +71 -0
- dam3.1.ckpt +3 -0
- featex.py +119 -0
- model.py +185 -0
- pipeline.py +43 -0
- requirements.txt +5 -0
- tuning/__init__.py +0 -0
- tuning/indet_roc.py +416 -0
- tuning/optimal_ordinal.py +510 -0
README.md
CHANGED
|
@@ -1,3 +1,218 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
base_model:
|
| 6 |
+
- openai/whisper-small.en
|
| 7 |
+
pipeline_tag: audio-classification
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Background
|
| 12 |
+
|
| 13 |
+
In the United States nearly 21M adults suffer from depression each year [1], with depression serving as the nation’s leading cause of disability [2].
|
| 14 |
+
Despite this, less than 4% of Americans receive mental health screenings from their primary care physicians during annual wellness visits.
|
| 15 |
+
The pandemic and public campaigns of late have made strides toward positively increasing awareness of mental health struggles, but there remains a persisting stigma around depression and other mental health conditions.
|
| 16 |
+
The influence of this stigma is especially marked in older adults. People aged 65 and older are less likely than any other age group to seek mental health support.
|
| 17 |
+
Older adults – for whom depression significantly increases the risk of disability and morbidity – also tend to underreport mental health symptoms [3].
|
| 18 |
+
|
| 19 |
+
In the US, this outlook becomes even more troubling when coupled with the rate at which the country’s population is aging: 1 out of every 6 people will be 60 years or over by 2030 [4].
|
| 20 |
+
As widespread and prevalent as depression is, identifying and treating depression and other mental health conditions remains challenging and there is limited objectivity in the screening processes.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Depression–Anxiety Model (DAM)
|
| 24 |
+
|
| 25 |
+
## Model Overview
|
| 26 |
+
|
| 27 |
+
DAM is a clinical-grade, speech-based model designed to screen for signs of depression and anxiety using voice biomarkers.
|
| 28 |
+
To the best of our knowledge, it is the first model developed explicitly for clinical-grade mental health assessment from speech without reliance on linguistic content or transcription.
|
| 29 |
+
The model operates exclusively on the acoustic properties of the speech signal, extracting depression- and anxiety-specific voice biomarkers rather than semantic or lexical information.
|
| 30 |
+
Numerous studies [5–7] have demonstrated that paralinguistic features – such as spectral entropy, pitch variability, fundamental frequency, and related acoustic measures – exhibit strong correlations with depression and anxiety.
|
| 31 |
+
Building on this body of evidence, DAM extends prior approaches by leveraging deep learning to learn fine-grained vocal biomarkers directly from the raw speech signal, yielding representations that demonstrate greater predictive power than hand-engineered paralinguistic features.
|
| 32 |
+
DAM analyzes spoken audio to estimate depression and anxiety severity scores which can be subsequently mapped to standardized clinical scales, such as **PHQ-9** (Patient Health Questionnaire-9) for depression and **GAD-7** (Generalized Anxiety Disorder-7) for anxiety.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## Data
|
| 36 |
+
|
| 37 |
+
The model was trained and evaluated on a large-scale speech dataset collected from approximately 35,000 individuals via phone, tablet, or web app, which corresponds to ~863 hours of speech data.
|
| 38 |
+
Ground-truth labels were derived from both clinician-administered and self-reported PHQ-9 and GAD-7 questionnaires, ensuring strong alignment with established clinical assessment standards.
|
| 39 |
+
The data consists predominantly of American English speech. However, a broad range of accents is represented, providing robustness across diverse speaking styles.
|
| 40 |
+
|
| 41 |
+
The audio data itself cannot be shared for privacy reasons. Demographic statistics, model scores, and associated metadata for each audio stream are available for threshold tuning at https://huggingface.co/datasets/KintsugiHealth/dam-dataset.
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
## Model Architecture
|
| 45 |
+
|
| 46 |
+
**Foundation model:** OpenAI Whisper-Small EN
|
| 47 |
+
|
| 48 |
+
**Training approach:** Fine-tuning + Multi-task learning
|
| 49 |
+
|
| 50 |
+
**Downstream tasks:** Depression and anxiety severity estimation
|
| 51 |
+
|
| 52 |
+
Whisper serves as the backbone for extracting voice biomarkers, while multi-task head is fine-tuned jointly on depression and anxiety prediction tasks to leverage shared representations across mental health conditions.
|
| 53 |
+
|
| 54 |
+
## Input Requirements
|
| 55 |
+
|
| 56 |
+
**Preferred minimum audio length:** 30 seconds of speech after Voice Activity Detector
|
| 57 |
+
|
| 58 |
+
**Input modality:** Audio only
|
| 59 |
+
|
| 60 |
+
Shorter audio samples may lead to reduced prediction accuracy.
|
| 61 |
+
|
| 62 |
+
## Output
|
| 63 |
+
|
| 64 |
+
The model outputs a dictionary of the following form `{"depression":score, "anxiety": score}`.
|
| 65 |
+
|
| 66 |
+
If `quantized=False` (see the Usage section below), the scores are returned as raw float values which correlate monotonically with PHQ-9 and GAD-7.
|
| 67 |
+
|
| 68 |
+
If `quantized=True` the scores are converted into integers representing the severity of depression and anxiety.
|
| 69 |
+
|
| 70 |
+
**Quantization levels for depression task:**
|
| 71 |
+
|
| 72 |
+
0 – no depression (PHQ-9 <= 9)
|
| 73 |
+
|
| 74 |
+
1 – mild to moderate depression (10 <= PHQ-9 <= 14)
|
| 75 |
+
|
| 76 |
+
2 – severe depression (PHQ-9 >= 15)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
**Quantization levels for anxiety task:**
|
| 80 |
+
|
| 81 |
+
0 – no anxiety (GAD-7 <= 4)
|
| 82 |
+
|
| 83 |
+
1 – mild anxiety (5 <= GAD-7 <= 9)
|
| 84 |
+
|
| 85 |
+
2 – moderate anxiety (10 <= GAD-7 <= 14)
|
| 86 |
+
|
| 87 |
+
3 – severe anxiety (GAD-7 >= 15)
|
| 88 |
+
|
| 89 |
+
## Intended Use
|
| 90 |
+
* Mental health research
|
| 91 |
+
* Clinical decision support
|
| 92 |
+
* Continuous monitoring of depression and anxiety
|
| 93 |
+
|
| 94 |
+
## Limitations
|
| 95 |
+
* Not intended for diagnosis/self-diagnosis without clinical oversight
|
| 96 |
+
* Performance may degrade on speech recorded outside controlled environments or in the presence of noise
|
| 97 |
+
* Intended only for audio containing a single voice speaking English
|
| 98 |
+
* Biases related to language, accent, or demographic representation may be present
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Usage
|
| 102 |
+
1. Checkout the repo:
|
| 103 |
+
|
| 104 |
+
```
|
| 105 |
+
git clone https://huggingface.co/KintsugiHealth/dam
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
2. Install requirements:
|
| 109 |
+
```python
|
| 110 |
+
pip install -r requirements.txt
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
3. Load and run pipeline
|
| 114 |
+
```python
|
| 115 |
+
from pipeline import Pipeline
|
| 116 |
+
|
| 117 |
+
pipeline = Pipeline()
|
| 118 |
+
result = pipeline.run_on_file("sample.wav", quantized=True)
|
| 119 |
+
print(result)
|
| 120 |
+
```
|
| 121 |
+
The output will resemble a dictionary, for example {'depression': 2, 'anxiety': 3}, indicating that the analyzed audio sample exhibits voice biomarkers consistent with severe depression and severe anxiety.
|
| 122 |
+
|
| 123 |
+
## Tuning Thresholds
|
| 124 |
+
As mentioned in the Data section above, the raw audio data cannot be shared, but validation and test sets of model scores associated with ground truth and demographic metadata are available for threshold tuning. This way thresholds can be tuned for traditional binary classification, ternary classification with an indeterminate output, and multi-class classification of severity. Two modules are provided for this in the model code's `tuning` package, as illustrated below.
|
| 125 |
+
|
| 126 |
+
### Tuning Sensitivity, Specificity, and Indeterminate Fraction
|
| 127 |
+
This module implements a generalization of ROC curve analysis wherein ground truth is binary, but model output can be negative (score below lower threshold), positive (score above upper threshold), or indeterminate (score between thresholds). For the purpose of metric calculations such as sensitivity and specificity, examples marked indeterminate do not count towards either the numerator or denominator. The budget for fraction of examples to be marked indeterminate is configurable as shown below.
|
| 128 |
+
```
|
| 129 |
+
import numpy as np
|
| 130 |
+
|
| 131 |
+
from datasets import load_dataset
|
| 132 |
+
from tuning.indet_roc import BinaryLabeledScores
|
| 133 |
+
|
| 134 |
+
val = load_dataset("KintsugiHealth/dam-dataset", split="validation")
|
| 135 |
+
val.set_format("numpy")
|
| 136 |
+
test = load_dataset("KintsugiHealth/dam-dataset", split="test")
|
| 137 |
+
test.set_format("numpy")
|
| 138 |
+
|
| 139 |
+
data = dict(val=val, test=test)
|
| 140 |
+
|
| 141 |
+
# Associate depression model scores with binarized labels based on whether the PHQ-9 sum is >= 10
|
| 142 |
+
scores_labeled = {
|
| 143 |
+
k: BinaryLabeledScores(
|
| 144 |
+
y_score=v['scores_depression'], # Change to 'scores_anxiety' to calibrate anxiety thresholds
|
| 145 |
+
y_true=(v['phq'] >= 10).astype(int) # Change to 'gad' to calibrate anxiety thresholds; optionally change cutoff
|
| 146 |
+
)
|
| 147 |
+
for k, v in data.items()
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
issa = scores_labeled['val'].indet_sn_sp_array() # Metrics at all possible lower, upper threshold pairs
|
| 151 |
+
|
| 152 |
+
# Compute ROC curve with 20% indeterminate budget and select a point near the diagonal
|
| 153 |
+
roc_at_20 = issa.roc_curve(0.2) # Pareto frontier of (sensitivity, specificity) pairs with at most 20% indeterminate fraction
|
| 154 |
+
print(f"Area under the ROC curve with 20% indeterminate budget: {roc_at_20.auc()=:.1%}") #
|
| 155 |
+
sn_eq_sp_at_20 = roc_at_20.sn_eq_sp() # Find where ROC comes closest to sensitivity = specificity diagonal
|
| 156 |
+
print(f"Thresholds to balance sensitivity and specificity on val set with 20% indeterminate budget: "
|
| 157 |
+
f"{sn_eq_sp_at_20.lower_thresh=:.3}, {sn_eq_sp_at_20.upper_thresh=:.3}")
|
| 158 |
+
print(f"Performance on val set with these thresholds: {sn_eq_sp_at_20.sn=:.1%}, {sn_eq_sp_at_20.sp=:.1%}") #
|
| 159 |
+
test_metrics = sn_eq_sp_at_20.eval(**scores_labeled['test']._asdict()) # Thresholds evaluated on test set
|
| 160 |
+
print(f"Performance on test set with these thresholds: {test_metrics.sn=:.1%}, {test_metrics.sp=:.1%}") #
|
| 161 |
+
|
| 162 |
+
# Find best specificity given sensitivity and indeterminate budget constraints
|
| 163 |
+
constrained = issa[(issa.sn >= 0.8) & (issa.indet_frac <= 0.35)]
|
| 164 |
+
optimal = constrained[np.argmax(constrained.sp)]
|
| 165 |
+
print(f"Highest specificity achievable with sensitivity >= 80% and 35% indeterminate budget is "
|
| 166 |
+
f"{optimal.sp=:.1%}, achieved at thresholds {optimal.lower_thresh=:.3}, {optimal.upper_thresh=:.3}"
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Collect optimal ways of achieving balanced sensitivity and specificity as a function of indeterminate fraction
|
| 170 |
+
sn_eq_sp = issa.sn_eq_sp_graph()
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
### Optimal Tuning for Multiclass Tasks
|
| 174 |
+
The depression and anxiety models were each trained with ordinal regression to predict a scalar score monotonically correlated with the underlying PHQ-9 and GAD-7 questionnaire ground truth sums. As such there are efficient dynamic programming algorithms to select optimal thresholds for multi-class numeric labels under a variety of decision criteria.
|
| 175 |
+
|
| 176 |
+
```
|
| 177 |
+
from datasets import load_dataset
|
| 178 |
+
from tuning.optimal_ordinal import MinAbsoluteErrorOrdinalThresholding
|
| 179 |
+
|
| 180 |
+
val = load_dataset("KintsugiHealth/dam-dataset", split="validation")
|
| 181 |
+
val.set_format("torch")
|
| 182 |
+
test = load_dataset("KintsugiHealth/dam-dataset", split="test")
|
| 183 |
+
test.set_format("torch")
|
| 184 |
+
|
| 185 |
+
data = dict(val=val, test=test)
|
| 186 |
+
|
| 187 |
+
scores = val['scores_anxiety'] # Change to 'scores_depression' for depression threshold tuning
|
| 188 |
+
labels = val['gad'] # Change to 'phq' for depression threshold tuning; optionally change to quantized version for coarser prediction tuning
|
| 189 |
+
|
| 190 |
+
# Can change to any of
|
| 191 |
+
# `MaxAccuracyOrdinalThresholding`
|
| 192 |
+
# `MaxMacroRecallOrdinalThresholding`
|
| 193 |
+
# `MaxMacroPrecisionOrdinalThresholding`
|
| 194 |
+
# `MaxMacroF1OrdinalThresholding`
|
| 195 |
+
optimal_thresh = MinAbsoluteErrorOrdinalThresholding(num_classes=int(labels.max()) + 1)
|
| 196 |
+
best_constant_cost, best_constant = optimal_thresh.best_constant_output_classifier(labels)
|
| 197 |
+
print(f"Always predicting GAD sum = {best_constant} on val set independent of model score gives mean absolute error {best_constant_cost:.3}.")
|
| 198 |
+
mean_error = optimal_thresh.tune_thresholds(labels=labels, scores=scores)
|
| 199 |
+
print(f"Thresholds optimized on val set to predict GAD sum from anxiety score: {optimal_thresh.thresholds}")
|
| 200 |
+
print(f"Mean absolute error predicting GAD sum on val set based on thresholds optimized on val set: {mean_error:.3}")
|
| 201 |
+
test_preds = optimal_thresh(test['scores_anxiety'])
|
| 202 |
+
mean_error_test = optimal_thresh.mean_cost(labels=test['gad'], preds=test_preds)
|
| 203 |
+
print(f"Mean absolute error predicting GAD sum on test set based on thresholds optimized on val set: {mean_error_test:.3}")
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
# Acknowledgments
|
| 207 |
+
|
| 208 |
+
This model was created through equal contributions by Oleksii Abramenko, Noah Stein, and Colin Vaz while at Kintsugi Health. For a full list of contributors to earlier modeling projects, data collection, clinical, and business matters, see the organization card at https://huggingface.co/KintsugiHealth.
|
| 209 |
+
|
| 210 |
+
# References
|
| 211 |
+
|
| 212 |
+
1. https://www.nimh.nih.gov/health/statistics/major-depression
|
| 213 |
+
2. https://www.hopefordepression.org/depression-facts/
|
| 214 |
+
3. https://nndc.org/facts/
|
| 215 |
+
4. https://www.psychiatry.org/patients-families/stigma-and-discrimination
|
| 216 |
+
5. https://www.sciencedirect.com/science/article/pii/S1746809423004536
|
| 217 |
+
6. https://pmc.ncbi.nlm.nih.gov/articles/PMC3409931/
|
| 218 |
+
7. https://pmc.ncbi.nlm.nih.gov/articles/PMC11559157
|
config.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration for running Kintsugi Depression and Anxiety model."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
EXPECTED_SAMPLE_RATE = 16000 # Audio sample rate in hertz
|
| 6 |
+
|
| 7 |
+
# Configuration for running Kintsugi Depression and Anxiety model as intended
|
| 8 |
+
default_config = {
|
| 9 |
+
# See featex.py for preprocessor config details
|
| 10 |
+
'preprocessor_config': {
|
| 11 |
+
'normalize_features': True,
|
| 12 |
+
'chunk_seconds': 30,
|
| 13 |
+
'max_overlap_frac': 0.0,
|
| 14 |
+
'pad_last_chunk_to_full': True,
|
| 15 |
+
},
|
| 16 |
+
|
| 17 |
+
# See model.py for backbone config details
|
| 18 |
+
'backbone_configs': {'audio': {'model': 'openai/whisper-small.en',
|
| 19 |
+
'hf_config': {'encoder_layerdrop': 0.0,
|
| 20 |
+
'dropout': 0.0,
|
| 21 |
+
'activation_dropout': 0.0},
|
| 22 |
+
'lora_params': {'r': 32,
|
| 23 |
+
'lora_alpha': 64.0,
|
| 24 |
+
'target_modules': 'all-linear',
|
| 25 |
+
'lora_dropout': 0.4,
|
| 26 |
+
'modules_to_save': ['conv1', 'conv2'],
|
| 27 |
+
'bias': 'all'}},
|
| 28 |
+
'llma': {'model': 'openai/whisper-small.en',
|
| 29 |
+
'hf_config': {'encoder_layerdrop': 0.0,
|
| 30 |
+
'dropout': 0.0,
|
| 31 |
+
'activation_dropout': 0.0}}},
|
| 32 |
+
|
| 33 |
+
# See model.py for classifier config details
|
| 34 |
+
'classifier_config': {'shared_projection_dim': [256, 64],
|
| 35 |
+
'tasks': {'depression': {'proj_dim': 128, 'dropout': 0.4},
|
| 36 |
+
'anxiety': {'proj_dim': 128, 'dropout': 0.4}}},
|
| 37 |
+
|
| 38 |
+
# Score thresholds chosen to optimize macro average F1 score on validation set
|
| 39 |
+
'inference_thresholds': {
|
| 40 |
+
# Three-level depression severity model:
|
| 41 |
+
# depression score <= -0.6699 --> no depression (PHQ-9 <= 9)
|
| 42 |
+
# -0.6699 < depression score <= -0.2908 --> mild to moderate depression (10 <= PHQ-9 <= 14)
|
| 43 |
+
# -0.2908 < depression score --> severe depression (PHQ-9 >= 15)
|
| 44 |
+
'depression': [-0.6699, -0.2908],
|
| 45 |
+
# Four-level anxiety severity model:
|
| 46 |
+
# anxiety score <= -0.7939 --> no anxiety (GAD-7 <= 4)
|
| 47 |
+
# -0.7939 < anxiety score <= -0.2173 --> mild anxiety (5 <= GAD-7 <= 9)
|
| 48 |
+
# -0.2173 < anxiety score <= 0.1521 --> moderate anxiety (10 <= GAD-7 <= 14)
|
| 49 |
+
# 0.1521 < anxiety score --> severe anxiety (GAD-7 >= 15)
|
| 50 |
+
'anxiety': [-0.7939, -0.2173, 0.1521]
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
# Average filter bank energies used for feature normalization
|
| 55 |
+
logmel_energies = torch.tensor([0.34912264, 0.58558977, 0.7912451 , 0.92767584, 0.98273695,
|
| 56 |
+
0.98439455, 0.9603633 , 0.93906444, 0.9366281 , 0.93200225,
|
| 57 |
+
0.916437 , 0.8928787 , 0.8637211 , 0.83265126, 0.79977655,
|
| 58 |
+
0.7778334 , 0.7561299 , 0.72997606, 0.70391226, 0.6800474 ,
|
| 59 |
+
0.65755 , 0.63536274, 0.61355984, 0.5923383 , 0.5720056 ,
|
| 60 |
+
0.55244887, 0.53684795, 0.5221597 , 0.5098636 , 0.49923953,
|
| 61 |
+
0.48908615, 0.47840047, 0.46758702, 0.47343993, 0.46268672,
|
| 62 |
+
0.4475126 , 0.46747103, 0.45131385, 0.4635319 , 0.44889897,
|
| 63 |
+
0.45491976, 0.4373785 , 0.43154317, 0.42194438, 0.41158468,
|
| 64 |
+
0.40096927, 0.3933149 , 0.38795966, 0.38441542, 0.38454026,
|
| 65 |
+
0.3815766 , 0.3768835 , 0.3719921 , 0.3654539 , 0.35399568,
|
| 66 |
+
0.3425986 , 0.32823247, 0.31404305, 0.30564603, 0.29617435,
|
| 67 |
+
0.29273877, 0.28560263, 0.27459458, 0.26876706, 0.25825337,
|
| 68 |
+
0.24759005, 0.24090728, 0.2344712 , 0.22529823, 0.20880115,
|
| 69 |
+
0.193578 , 0.18290243, 0.17621627, 0.17087021, 0.16641389,
|
| 70 |
+
0.15932252, 0.14312662, 0.11790597, 0.08030523, 0.03747071],
|
| 71 |
+
)
|
dam3.1.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cfa897e1b990de9377b2fb805b526a36fe6f31de01bbc8c3d288d317df2c4b0c
|
| 3 |
+
size 736180146
|
featex.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocessing and normalization to prepare audio for Kintsugi Depression and Anxiety model."""
|
| 2 |
+
from typing import Union, BinaryIO
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
from transformers import AutoFeatureExtractor
|
| 8 |
+
|
| 9 |
+
from config import EXPECTED_SAMPLE_RATE, logmel_energies
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_audio(source: Union[BinaryIO, str, os.PathLike]) -> torch.Tensor:
|
| 13 |
+
"""Load audio file, verify mono channel count, and resample if necessary.
|
| 14 |
+
|
| 15 |
+
Parameters
|
| 16 |
+
----------
|
| 17 |
+
source: open file or path to file
|
| 18 |
+
|
| 19 |
+
Returns
|
| 20 |
+
-------
|
| 21 |
+
Time domain audio samples as a 1 x num_samples float tensor sampled at 16 kHz.
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
audio, fs = torchaudio.load(source)
|
| 25 |
+
if audio.shape[0] != 1:
|
| 26 |
+
raise ValueError(f"Provided audio has {audio.shape[0]} != 1 channels.")
|
| 27 |
+
if fs != EXPECTED_SAMPLE_RATE:
|
| 28 |
+
audio = torchaudio.functional.resample(audio, fs, EXPECTED_SAMPLE_RATE)
|
| 29 |
+
return audio
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Preprocessor:
|
| 33 |
+
def __init__(self,
|
| 34 |
+
normalize_features: bool = True,
|
| 35 |
+
chunk_seconds: int = 30,
|
| 36 |
+
max_overlap_frac: float = 0.0,
|
| 37 |
+
pad_last_chunk_to_full: bool = True,
|
| 38 |
+
):
|
| 39 |
+
"""Create preprocessor object.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
normalize_features: Whether the Whisper preprocessor should normalize features
|
| 44 |
+
chunk_seconds: Size of model's receptive field in seconds
|
| 45 |
+
max_overlap_frac: Fraction of each chunk allowed to overlap previous chunk for inputs longer than chunk_seconds
|
| 46 |
+
pad_last_chunk_to_full: Whether to pad audio to an integer multiple of chunk_seconds
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
self.preprocessor = AutoFeatureExtractor.from_pretrained("openai/whisper-small.en")
|
| 50 |
+
self.normalize_features = normalize_features
|
| 51 |
+
self.chunk_seconds = chunk_seconds
|
| 52 |
+
self.max_overlap_frac = max_overlap_frac
|
| 53 |
+
self.pad_last_chunk_to_full = pad_last_chunk_to_full
|
| 54 |
+
|
| 55 |
+
def preprocess_with_audio_normalization(
|
| 56 |
+
self,
|
| 57 |
+
audio: torch.Tensor,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""Run Whisper preprocessor and normalization expected by the model.
|
| 60 |
+
|
| 61 |
+
Note: some normalization steps can be avoided, but are included to match
|
| 62 |
+
feature extraction used during training.
|
| 63 |
+
|
| 64 |
+
Parameters
|
| 65 |
+
----------
|
| 66 |
+
audio: Raw audio samples as a 1 x num_samples float tensor sampled at 16 kHz
|
| 67 |
+
|
| 68 |
+
Returns
|
| 69 |
+
-------
|
| 70 |
+
Normalized mel filter bank features as a float tensor of shape
|
| 71 |
+
num_chunks x 80 mel filter bands x 3000 time frames
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
# Remove DC offset and scale amplitude to [-1, 1]
|
| 75 |
+
audio = torch.squeeze(audio, 0)
|
| 76 |
+
audio = audio - torch.mean(audio)
|
| 77 |
+
audio = audio / torch.max(torch.abs(audio))
|
| 78 |
+
|
| 79 |
+
chunk_samples = EXPECTED_SAMPLE_RATE * self.chunk_seconds
|
| 80 |
+
|
| 81 |
+
if self.pad_last_chunk_to_full:
|
| 82 |
+
# pad audio so that the last chunk is not dropped
|
| 83 |
+
if self.max_overlap_frac > 0:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
f"pad_last_chunk_to_full is only supported for non-overlapping windows"
|
| 86 |
+
)
|
| 87 |
+
num_chunks = np.ceil(len(audio) / chunk_samples)
|
| 88 |
+
pad_size = int(num_chunks * chunk_samples - len(audio))
|
| 89 |
+
audio = torch.nn.functional.pad(audio, (0, pad_size))
|
| 90 |
+
|
| 91 |
+
overflow_len = len(audio) - chunk_samples
|
| 92 |
+
|
| 93 |
+
min_hop_samples = int(
|
| 94 |
+
(1 - self.max_overlap_frac) * chunk_samples
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
n_windows = 1 + overflow_len // min_hop_samples
|
| 98 |
+
window_starts = np.linspace(0, overflow_len, max(n_windows, 1)).astype(int)
|
| 99 |
+
|
| 100 |
+
features = self.preprocessor(
|
| 101 |
+
[
|
| 102 |
+
audio[start : start + chunk_samples].numpy(force=True)
|
| 103 |
+
for start in window_starts
|
| 104 |
+
],
|
| 105 |
+
return_tensors="pt",
|
| 106 |
+
sampling_rate=EXPECTED_SAMPLE_RATE,
|
| 107 |
+
do_normalize=self.normalize_features,
|
| 108 |
+
)
|
| 109 |
+
for key in ("input_features", "input_values"):
|
| 110 |
+
if hasattr(features, key):
|
| 111 |
+
features = getattr(features, key)
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
mean_features = torch.mean(features, dim=-1)
|
| 115 |
+
# features are [batch, n_logmel_bins, n_frames]
|
| 116 |
+
rescale_factor = logmel_energies.unsqueeze(0) - mean_features
|
| 117 |
+
rescale_factor = rescale_factor.unsqueeze(2)
|
| 118 |
+
features += rescale_factor
|
| 119 |
+
return features
|
model.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Mapping, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from peft import LoraConfig, get_peft_model
|
| 5 |
+
from transformers import WhisperConfig, WhisperModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class WhisperEncoderBackbone(torch.nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
model: str = "openai/whisper-small.en",
|
| 12 |
+
hf_config: Optional[Mapping[str, Any]] = None,
|
| 13 |
+
lora_params: Optional[Mapping[str, Any]] = None,
|
| 14 |
+
):
|
| 15 |
+
"""Whisper encoder model with optional Low-Rank Adaptation.
|
| 16 |
+
|
| 17 |
+
Parameters
|
| 18 |
+
----------
|
| 19 |
+
model: Name of WhisperModel whose encoder to load from HuggingFace
|
| 20 |
+
hf_config: Optional config for HuggingFace model
|
| 21 |
+
lora_params: Parameters for Low-Rank Adaptation
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
super().__init__()
|
| 25 |
+
hf_config = hf_config if hf_config is not None else dict()
|
| 26 |
+
backbone_config = WhisperConfig.from_pretrained(model, **hf_config)
|
| 27 |
+
self.backbone = (
|
| 28 |
+
WhisperModel.from_pretrained(
|
| 29 |
+
model,
|
| 30 |
+
config=backbone_config,
|
| 31 |
+
)
|
| 32 |
+
.get_encoder()
|
| 33 |
+
.train()
|
| 34 |
+
)
|
| 35 |
+
if lora_params is not None and len(lora_params) > 0:
|
| 36 |
+
lora_config = LoraConfig(**lora_params)
|
| 37 |
+
self.backbone = get_peft_model(self.backbone, lora_config)
|
| 38 |
+
self.backbone_dim = backbone_config.hidden_size
|
| 39 |
+
|
| 40 |
+
def forward(self, whisper_feature_batch):
|
| 41 |
+
return self.backbone(whisper_feature_batch).last_hidden_state.mean(dim=1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SharedLayers(torch.nn.Module):
|
| 45 |
+
def __init__(self, input_dim: int, proj_dims: list[int]):
|
| 46 |
+
"""Fully connected network with Mish nonlinearities between linear layers. No nonlinearity at input or output.
|
| 47 |
+
|
| 48 |
+
Parameters
|
| 49 |
+
----------
|
| 50 |
+
input_dim: Dimension of input features
|
| 51 |
+
proj_dims: Dimensions of layers to create
|
| 52 |
+
|
| 53 |
+
"""
|
| 54 |
+
super().__init__()
|
| 55 |
+
modules = []
|
| 56 |
+
for output_dim in proj_dims[:-1]:
|
| 57 |
+
modules.extend([torch.nn.Linear(input_dim, output_dim), torch.nn.Mish()])
|
| 58 |
+
input_dim = output_dim
|
| 59 |
+
modules.append(torch.nn.Linear(input_dim, proj_dims[-1]))
|
| 60 |
+
self.shared_layers = torch.nn.Sequential(*modules)
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
return self.shared_layers(x)
|
| 64 |
+
|
| 65 |
+
class TaskHead(torch.nn.Module):
|
| 66 |
+
def __init__(self, input_dim: int, proj_dim: int, dropout: float = 0.0):
|
| 67 |
+
"""Fully connected network with one hidden layer, dropout, and a scalar output."""
|
| 68 |
+
super().__init__()
|
| 69 |
+
|
| 70 |
+
self.linear = torch.nn.Linear(input_dim, proj_dim)
|
| 71 |
+
self.activation = torch.nn.Mish()
|
| 72 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 73 |
+
self.final_layer = torch.nn.Linear(proj_dim, 1, bias=False)
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
x = self.linear(x)
|
| 77 |
+
x = self.activation(x)
|
| 78 |
+
x = self.dropout(x)
|
| 79 |
+
x = self.final_layer(x)
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class MultitaskHead(torch.nn.Module):
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
backbone_dim: int,
|
| 87 |
+
shared_projection_dim: list[int],
|
| 88 |
+
tasks: Mapping[str, Mapping[str, Any]],
|
| 89 |
+
):
|
| 90 |
+
"""Fully connected network with multiple named scalar outputs."""
|
| 91 |
+
super().__init__()
|
| 92 |
+
|
| 93 |
+
# Initialize the shared network and task-specific networks
|
| 94 |
+
self.shared_layers = SharedLayers(backbone_dim, shared_projection_dim)
|
| 95 |
+
self.classifier_head = torch.nn.ModuleDict(
|
| 96 |
+
{
|
| 97 |
+
task: TaskHead(shared_projection_dim[-1], **task_config)
|
| 98 |
+
for task, task_config in tasks.items()
|
| 99 |
+
}
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def forward(self, x):
|
| 103 |
+
x = self.shared_layers(x)
|
| 104 |
+
return {task: head(x) for task, head in self.classifier_head.items()}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def average_tensor_in_segments(tensor: torch.Tensor, lengths: list[int] | torch.Tensor):
|
| 108 |
+
"""Average segments of a `tensor` along dimension 0 based on a list of `lengths`
|
| 109 |
+
|
| 110 |
+
For example, with input tensor `t` and `lengths` [1, 3, 2], the output would be
|
| 111 |
+
[t[0], (t[1] + t[2] + t[3]) / 3, (t[4] + t[5]) / 2]
|
| 112 |
+
|
| 113 |
+
Parameters
|
| 114 |
+
----------
|
| 115 |
+
tensor : torch.Tensor
|
| 116 |
+
The tensor to average
|
| 117 |
+
lengths : list of ints
|
| 118 |
+
The lengths of each segment to average in the tensor, in order
|
| 119 |
+
|
| 120 |
+
Returns
|
| 121 |
+
-------
|
| 122 |
+
torch.Tensor
|
| 123 |
+
The tensor with relevant segments averaged
|
| 124 |
+
"""
|
| 125 |
+
if not torch.is_tensor(lengths):
|
| 126 |
+
lengths = torch.tensor(lengths, device=tensor.device)
|
| 127 |
+
index = torch.repeat_interleave(
|
| 128 |
+
torch.arange(len(lengths), device=tensor.device), lengths
|
| 129 |
+
)
|
| 130 |
+
out = torch.zeros(
|
| 131 |
+
lengths.shape + tensor.shape[1:], device=tensor.device, dtype=tensor.dtype
|
| 132 |
+
)
|
| 133 |
+
out.index_add_(0, index, tensor)
|
| 134 |
+
broadcastable_lengths = lengths.view((-1,) + (1,) * (len(out.shape) - 1))
|
| 135 |
+
return out / broadcastable_lengths
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Classifier(torch.nn.Module):
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
backbone_configs: Mapping[str, Mapping[str, Any]],
|
| 142 |
+
classifier_config: Mapping[str, Any],
|
| 143 |
+
inference_thresholds: Mapping[str, Any],
|
| 144 |
+
preprocessor_config: Mapping[str, Any],
|
| 145 |
+
):
|
| 146 |
+
"""Full Kintsugi Depression and Anxiety model.
|
| 147 |
+
|
| 148 |
+
Whisper encoder -> Mean pooling over time -> Layers shared across tasks -> Per-task heads
|
| 149 |
+
|
| 150 |
+
Parameters
|
| 151 |
+
----------
|
| 152 |
+
backbone_configs:
|
| 153 |
+
classifier_config:
|
| 154 |
+
inference_thresholds:
|
| 155 |
+
preprocessor_config:
|
| 156 |
+
|
| 157 |
+
"""
|
| 158 |
+
super().__init__()
|
| 159 |
+
|
| 160 |
+
self.backbone = torch.nn.ModuleDict(
|
| 161 |
+
{
|
| 162 |
+
key: WhisperEncoderBackbone(**backbone_configs[key])
|
| 163 |
+
for key in sorted(backbone_configs.keys())
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
backbone_dim = sum(layer.backbone_dim for layer in self.backbone.values())
|
| 168 |
+
self.head = MultitaskHead(backbone_dim, **classifier_config)
|
| 169 |
+
self.inference_thresholds = inference_thresholds
|
| 170 |
+
self.preprocessor_config = preprocessor_config
|
| 171 |
+
|
| 172 |
+
def forward(self, x, lengths):
|
| 173 |
+
backbone_outputs = {
|
| 174 |
+
key: average_tensor_in_segments(layer(x), lengths)
|
| 175 |
+
for key, layer in self.backbone.items()
|
| 176 |
+
}
|
| 177 |
+
backbone_output = torch.cat(list(backbone_outputs.values()), dim=1)
|
| 178 |
+
return self.head(backbone_output), torch.ones_like(lengths)
|
| 179 |
+
|
| 180 |
+
def quantize_scores(self, scores: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
|
| 181 |
+
"""Map per-task scores to discrete predictions per `inference_thresholds` config."""
|
| 182 |
+
return {
|
| 183 |
+
key: torch.searchsorted(torch.tensor(self.inference_thresholds[key], device=value.device), value.mean(), out_int32=True)
|
| 184 |
+
for key, value in scores.items()
|
| 185 |
+
}
|
pipeline.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, BinaryIO, Mapping, Optional, Union
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from config import default_config
|
| 7 |
+
from featex import load_audio, Preprocessor
|
| 8 |
+
from model import Classifier
|
| 9 |
+
|
| 10 |
+
class Pipeline:
|
| 11 |
+
def __init__(self, checkpoint: Optional[str | Path] = None, config: Optional[Mapping[str, Any]] = None, device: Optional[torch.device] = None):
|
| 12 |
+
if checkpoint is None:
|
| 13 |
+
file_dir = Path(__file__).parent.resolve()
|
| 14 |
+
checkpoint = file_dir / "dam3.1.ckpt"
|
| 15 |
+
if config is None:
|
| 16 |
+
config = default_config
|
| 17 |
+
if device is None:
|
| 18 |
+
if torch.cuda.is_available():
|
| 19 |
+
device = torch.device("cuda:0")
|
| 20 |
+
else:
|
| 21 |
+
device = torch.device("cpu")
|
| 22 |
+
self.device = device
|
| 23 |
+
self.model = Classifier(**config)
|
| 24 |
+
self.preprocessor = Preprocessor(**self.model.preprocessor_config)
|
| 25 |
+
state_dict = torch.load(checkpoint, map_location=device)
|
| 26 |
+
self.model.load_state_dict(state_dict)
|
| 27 |
+
self.model.to(self.device)
|
| 28 |
+
self.model.eval()
|
| 29 |
+
|
| 30 |
+
def run_on_features(self, features: torch.Tensor, quantize: bool = True):
|
| 31 |
+
scores = self.model(features, torch.tensor([features.shape[0]], device=self.device))[0]
|
| 32 |
+
if quantize:
|
| 33 |
+
return {k: int(v.item()) for k, v in self.model.quantize_scores(scores).items()}
|
| 34 |
+
else:
|
| 35 |
+
return scores
|
| 36 |
+
|
| 37 |
+
def run_on_audio(self, audio: torch.Tensor, quantize: bool = True):
|
| 38 |
+
features = self.preprocessor.preprocess_with_audio_normalization(audio)
|
| 39 |
+
return self.run_on_features(features.to(self.device), quantize=quantize)
|
| 40 |
+
|
| 41 |
+
def run_on_file(self, source: Union[BinaryIO, str, os.PathLike], quantize=True):
|
| 42 |
+
audio = load_audio(source)
|
| 43 |
+
return self.run_on_audio(audio, quantize=quantize)
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytorch~=2.6.0
|
| 2 |
+
pysoundfile~=0.13.1
|
| 3 |
+
torchaudio~=2.7.0
|
| 4 |
+
transformers~=4.52.3
|
| 5 |
+
peft~=0.15.2
|
tuning/__init__.py
ADDED
|
File without changes
|
tuning/indet_roc.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tools for tuning pairs of scalar thresholds to trade off sensitivity, specificity, and indeterminate rate.
|
| 2 |
+
|
| 3 |
+
See `IndetSnSpArray` and subclass docstrings for details.
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
from dataclasses import asdict, dataclass
|
| 7 |
+
from typing import NamedTuple, Optional
|
| 8 |
+
from typing_extensions import Self # in typing in python3.11
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def running_argmax_indices(a):
|
| 14 |
+
"""Return indices of a where the value is larger than all previous values.
|
| 15 |
+
|
| 16 |
+
>>> running_argmax_indices([1, 0, 3, 4, 4, 2, 5, 7, 1])
|
| 17 |
+
array([0, 2, 3, 6, 7])
|
| 18 |
+
|
| 19 |
+
"""
|
| 20 |
+
m = np.maximum.accumulate(a)
|
| 21 |
+
return np.flatnonzero(np.r_[True, m[:-1] < m[1:]])
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def pareto_2d_indices(x, y):
|
| 25 |
+
"""Compute indices of the Pareto frontier maximizing x and y, sorted in increasing x and decreasing y.
|
| 26 |
+
|
| 27 |
+
e.g. the Pareto frontier of the point set below is [A, G]
|
| 28 |
+
|
| 29 |
+
B A
|
| 30 |
+
C
|
| 31 |
+
E D
|
| 32 |
+
F G
|
| 33 |
+
H
|
| 34 |
+
|
| 35 |
+
>>> u = [2, 0, 1, 2, 0, 0, 3, 1]
|
| 36 |
+
>>> v = [4, 4, 3, 2, 2, 1, 1, 0]
|
| 37 |
+
>>> pareto_2d_indices(np.array(u), np.array(v))
|
| 38 |
+
array([0, 6])
|
| 39 |
+
|
| 40 |
+
"""
|
| 41 |
+
sort_indices = np.lexsort((-x, -y)) # last element is primary sort key
|
| 42 |
+
return sort_indices[running_argmax_indices(x[sort_indices])]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def midpoints_with_infs(x):
|
| 46 |
+
"""Return the midpoints between the sorted unique elements of x, along with +/-inf."""
|
| 47 |
+
unique_scores = np.unique(np.r_[-np.inf, x, np.inf])
|
| 48 |
+
return (unique_scores[1:] + unique_scores[:-1]) / 2
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def kde_disc_mass(
|
| 52 |
+
data: np.ndarray,
|
| 53 |
+
edges: np.ndarray,
|
| 54 |
+
bandwidth: float,
|
| 55 |
+
weights: Optional[np.ndarray] = None,
|
| 56 |
+
):
|
| 57 |
+
"""Perform Kernel Density Estimation (KDE) on data & weights, then compute the probability mass between edges."""
|
| 58 |
+
import scipy
|
| 59 |
+
z_score = (edges[:, None] - data[None, :]) / bandwidth
|
| 60 |
+
component_cdfs = scipy.stats.norm.cdf(z_score)
|
| 61 |
+
if weights is None:
|
| 62 |
+
weights = np.ones_like(data)
|
| 63 |
+
cdf = np.dot(component_cdfs, weights / weights.sum())
|
| 64 |
+
return np.diff(cdf)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class BinaryLabeledScores(NamedTuple):
|
| 68 |
+
"""An array of numeric scores along with associated 0/1 ground truth and optional numeric weights."""
|
| 69 |
+
|
| 70 |
+
y_score: np.ndarray
|
| 71 |
+
y_true: np.ndarray
|
| 72 |
+
weights: Optional[np.ndarray] = None
|
| 73 |
+
|
| 74 |
+
def smooth(
|
| 75 |
+
self, num_points: int, bandwidth: float, padding_bandwidths: float = 5.0
|
| 76 |
+
) -> "BinaryLabeledScores":
|
| 77 |
+
"""KDE-smooth positive and negative scores separately and discretize each to equally spaced weighted points.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
num_points: number of points to use each for positive and negative score discretizations
|
| 81 |
+
bandwidth: bandwidth of kernel density estimation, i.e. standard deviation of noise to be added
|
| 82 |
+
padding_bandwidths: number of bandwidths to extend past lowest and highest scores when selecting
|
| 83 |
+
discretization endpoints
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
`BinaryLabeledScores` object representing the smoothed and re-discretized weighted labeled scores
|
| 87 |
+
|
| 88 |
+
"""
|
| 89 |
+
pos = self.y_true == 1
|
| 90 |
+
neg = self.y_true == 0
|
| 91 |
+
if self.weights is not None:
|
| 92 |
+
pos_weights = self.weights[pos]
|
| 93 |
+
neg_weights = self.weights[neg]
|
| 94 |
+
else:
|
| 95 |
+
pos_weights = None
|
| 96 |
+
neg_weights = None
|
| 97 |
+
padding = padding_bandwidths * bandwidth
|
| 98 |
+
all_points = np.linspace(
|
| 99 |
+
self.y_score.min() - padding,
|
| 100 |
+
self.y_score.max() + padding,
|
| 101 |
+
2 * num_points + 1,
|
| 102 |
+
)
|
| 103 |
+
edges = all_points[::2]
|
| 104 |
+
centers = all_points[1::2]
|
| 105 |
+
pos_kde_weights = kde_disc_mass(
|
| 106 |
+
self.y_score[pos], edges, bandwidth, pos_weights
|
| 107 |
+
)
|
| 108 |
+
neg_kde_weights = kde_disc_mass(
|
| 109 |
+
self.y_score[neg], edges, bandwidth, neg_weights
|
| 110 |
+
)
|
| 111 |
+
return BinaryLabeledScores(
|
| 112 |
+
y_true=np.r_[np.zeros_like(centers), np.ones_like(centers)],
|
| 113 |
+
y_score=np.r_[centers, centers],
|
| 114 |
+
weights=np.r_[neg_kde_weights, pos_kde_weights],
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def indet_sn_sp_array(self) -> "IndetSnSpArray":
|
| 118 |
+
"""Build `IndetSnSpArray`."""
|
| 119 |
+
return IndetSnSpArray.build(**self._asdict())
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def fake_vectorized_binom_ci(
|
| 123 |
+
k: np.ndarray, n: np.ndarray, p: float | np.ndarray = 0.95
|
| 124 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 125 |
+
"""Compute binomial confidence intervals on arrays of parameters inefficiently."""
|
| 126 |
+
import scipy
|
| 127 |
+
k, n, p = np.broadcast_arrays(k, n, p)
|
| 128 |
+
# If speed is needed this can be rewritten with the statsmodels package, which is vectorized.
|
| 129 |
+
flat_out = [
|
| 130 |
+
scipy.stats.binomtest(k_, n_).proportion_ci(p_)
|
| 131 |
+
for k_, n_, p_ in zip(k.flatten(), n.flatten(), p.flatten())
|
| 132 |
+
]
|
| 133 |
+
low = np.array([ci.low for ci in flat_out]).reshape(k.shape)
|
| 134 |
+
high = np.array([ci.high for ci in flat_out]).reshape(k.shape)
|
| 135 |
+
return low, high
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
|
| 139 |
+
class IndetSnSpArray:
|
| 140 |
+
"""An array of metrics at different lower and upper threshold values.
|
| 141 |
+
|
| 142 |
+
This class and subclasses are for selecting pairs of model thresholds based on sensitivity,
|
| 143 |
+
specificity, and indeterminate rate. Throughout we assume scores are scalars and ground truth is binary.
|
| 144 |
+
|
| 145 |
+
Each member `lower_thresh`, `upper_thresh`, `sn`, `sp`, and `indet_frac` must be a numpy array, and they all must
|
| 146 |
+
have the same shape. Corresponding entries of these arrays specify a pair of thresholds and the metrics when a
|
| 147 |
+
common dataset is evaluated using those thresholds. The thresholding logic is that scores less than the lower
|
| 148 |
+
threshold count as negative outputs, scores greater than or equal to the upper threshold count as positive outputs,
|
| 149 |
+
and scores in between are indeterminate outputs.
|
| 150 |
+
|
| 151 |
+
Indeterminate fraction is defined as the proportion of scores in between the two thresholds. All other
|
| 152 |
+
metrics are interpreted as conditioned on the scores not being indeterminate. For example, sensitivity
|
| 153 |
+
is defined as usual as (true positives) / (total positives) *except that examples with indeterminate
|
| 154 |
+
scores do not count towards the numerator or the denominator*.
|
| 155 |
+
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
lower_thresh: np.ndarray
|
| 159 |
+
upper_thresh: np.ndarray
|
| 160 |
+
tp: np.ndarray
|
| 161 |
+
fp: np.ndarray
|
| 162 |
+
tn: np.ndarray
|
| 163 |
+
fn: np.ndarray
|
| 164 |
+
indet: np.ndarray
|
| 165 |
+
weighted: bool = False
|
| 166 |
+
min_weight: float = 1.0
|
| 167 |
+
eps: float = 1e-8
|
| 168 |
+
|
| 169 |
+
def __post_init__(self):
|
| 170 |
+
self.sn = self.tp / np.maximum(self.tp + self.fn, self.min_weight)
|
| 171 |
+
self.sp = self.tn / np.maximum(self.tn + self.fp, self.min_weight)
|
| 172 |
+
self.ppv = self.tp / np.maximum(self.tp + self.fp, self.min_weight)
|
| 173 |
+
self.npv = self.tn / np.maximum(self.tn + self.fn, self.min_weight)
|
| 174 |
+
total = self.indet + self.fn + self.fp + self.tn + self.tp
|
| 175 |
+
self.indet_frac = self.indet / np.maximum(total, self.min_weight)
|
| 176 |
+
for attr in ("sn", "sp", "ppv", "npv", "indet_frac"):
|
| 177 |
+
value = getattr(self, attr)
|
| 178 |
+
if value.size:
|
| 179 |
+
if value.max() > 1 + self.eps:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
f"Numerical precision issues produced invalid value {attr} = {value.max()}."
|
| 182 |
+
)
|
| 183 |
+
if value.min() < -self.eps:
|
| 184 |
+
raise ValueError(
|
| 185 |
+
f"Numerical precision issues produced invalid value {attr} = {value.min()}."
|
| 186 |
+
)
|
| 187 |
+
setattr(self, attr, np.clip(value, 0.0, 1.0))
|
| 188 |
+
|
| 189 |
+
@property
|
| 190 |
+
def min_sn_sp(self):
|
| 191 |
+
return np.minimum(self.sn, self.sp)
|
| 192 |
+
|
| 193 |
+
@classmethod
|
| 194 |
+
def build(
|
| 195 |
+
cls,
|
| 196 |
+
lower_thresh: Optional[np.ndarray] = None,
|
| 197 |
+
upper_thresh: Optional[np.ndarray] = None,
|
| 198 |
+
*,
|
| 199 |
+
y_true: np.ndarray,
|
| 200 |
+
y_score: np.ndarray,
|
| 201 |
+
weights: Optional[np.ndarray] = None,
|
| 202 |
+
eps: float = 1e-8,
|
| 203 |
+
) -> Self:
|
| 204 |
+
"""Find `IndetSnSpArray` values for given truth and scores as thresholds vary (à la sklearn.metrics.roc_curve).
|
| 205 |
+
|
| 206 |
+
The output object contains arrays for `sn`, `sp`, `indet_frac`, `lower_thresh`, and `upper_thresh`, all with the
|
| 207 |
+
same shape. What these arrays contain and what their common shape is depends on the input as follows.
|
| 208 |
+
|
| 209 |
+
If both lower_thresh and upper_thresh are provided, they must have the same shape and this method computes
|
| 210 |
+
metrics at the pairs given by corresponding entries in these arrays. The common output shape will be the same as
|
| 211 |
+
this common input shape.
|
| 212 |
+
|
| 213 |
+
If only one set of thresholds is provided, this method computes metrics at all sorted pairs of these thresholds
|
| 214 |
+
(along with +/- inf). If neither is provided, sort scores and allow thresholds between each pair (along with
|
| 215 |
+
+/- inf). In both of these cases, the common output shape is a 1-d vector of length equal to the number of such
|
| 216 |
+
pairs.
|
| 217 |
+
|
| 218 |
+
"""
|
| 219 |
+
weights = weights if weights is not None else np.ones_like(y_true)
|
| 220 |
+
y_true = y_true[weights > 0]
|
| 221 |
+
y_score = y_score[weights > 0]
|
| 222 |
+
weights = weights[weights > 0]
|
| 223 |
+
|
| 224 |
+
# Find all threshes and include +/- inf so np.histogram does the right thing
|
| 225 |
+
if lower_thresh is not None and upper_thresh is not None:
|
| 226 |
+
threshes = np.unique(np.r_[-np.inf, lower_thresh, upper_thresh, np.inf])
|
| 227 |
+
lower_indices = np.searchsorted(threshes, lower_thresh)
|
| 228 |
+
upper_indices = np.searchsorted(threshes, upper_thresh)
|
| 229 |
+
else:
|
| 230 |
+
if lower_thresh is not None:
|
| 231 |
+
threshes = np.unique(np.r_[-np.inf, lower_thresh, np.inf])
|
| 232 |
+
elif upper_thresh is not None:
|
| 233 |
+
threshes = np.unique(np.r_[-np.inf, upper_thresh, np.inf])
|
| 234 |
+
else:
|
| 235 |
+
unique_scores = np.unique(np.r_[-np.inf, y_score, np.inf])
|
| 236 |
+
threshes = (unique_scores[1:] + unique_scores[:-1]) / 2
|
| 237 |
+
threshes = np.unique(threshes)
|
| 238 |
+
lower_indices, upper_indices = np.triu_indices(len(threshes))
|
| 239 |
+
|
| 240 |
+
count_by_bin = np.histogram(y_score, bins=threshes, weights=weights)[0]
|
| 241 |
+
pos_by_bin = np.histogram(y_score, bins=threshes, weights=y_true * weights)[0]
|
| 242 |
+
count_by_thresh = np.pad(np.cumsum(count_by_bin), (1, 0))
|
| 243 |
+
pos_by_thresh = np.pad(np.cumsum(pos_by_bin), (1, 0))
|
| 244 |
+
tn_plus_fn = count_by_thresh[lower_indices]
|
| 245 |
+
total_minus_tp_minus_fp = count_by_thresh[upper_indices]
|
| 246 |
+
tp_plus_fp = count_by_thresh[-1] - total_minus_tp_minus_fp
|
| 247 |
+
fn = pos_by_thresh[lower_indices]
|
| 248 |
+
total_pos = pos_by_thresh[-1] # last thresh is +inf
|
| 249 |
+
tp = total_pos - pos_by_thresh[upper_indices]
|
| 250 |
+
fp = tp_plus_fp - tp
|
| 251 |
+
tn = tn_plus_fn - fn
|
| 252 |
+
min_weight = weights.min()
|
| 253 |
+
indet = total_minus_tp_minus_fp - tn_plus_fn
|
| 254 |
+
return cls(
|
| 255 |
+
lower_thresh=threshes[lower_indices],
|
| 256 |
+
upper_thresh=threshes[upper_indices],
|
| 257 |
+
tp=tp,
|
| 258 |
+
fp=fp,
|
| 259 |
+
tn=tn,
|
| 260 |
+
fn=fn,
|
| 261 |
+
indet=indet,
|
| 262 |
+
weighted=not all(weights == 1.0),
|
| 263 |
+
min_weight=min_weight,
|
| 264 |
+
eps=eps,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
def eval(
|
| 268 |
+
self,
|
| 269 |
+
*,
|
| 270 |
+
y_true,
|
| 271 |
+
y_score,
|
| 272 |
+
weights: Optional[np.ndarray] = None,
|
| 273 |
+
) -> "IndetSnSpArray":
|
| 274 |
+
"""Evaluate the given data on the thresholds of `self`."""
|
| 275 |
+
return IndetSnSpArray.build(
|
| 276 |
+
lower_thresh=self.lower_thresh,
|
| 277 |
+
upper_thresh=self.upper_thresh,
|
| 278 |
+
y_true=y_true,
|
| 279 |
+
y_score=y_score,
|
| 280 |
+
weights=weights,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
def __getitem__(self, item) -> "IndetSnSpArray":
|
| 284 |
+
"""Extract a subarray with numpy-style indexing."""
|
| 285 |
+
return IndetSnSpArray(
|
| 286 |
+
lower_thresh=self.lower_thresh[item],
|
| 287 |
+
upper_thresh=self.upper_thresh[item],
|
| 288 |
+
tp=self.tp[item],
|
| 289 |
+
fp=self.fp[item],
|
| 290 |
+
fn=self.fn[item],
|
| 291 |
+
tn=self.tn[item],
|
| 292 |
+
indet=self.indet[item],
|
| 293 |
+
weighted=self.weighted,
|
| 294 |
+
min_weight=self.min_weight,
|
| 295 |
+
eps=self.eps,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
def __add__(self, other: "IndetSnSpArray") -> "IndetSnSpArray":
|
| 299 |
+
if not isinstance(other, IndetSnSpArray):
|
| 300 |
+
raise TypeError(f"Cannot add {type(other)} to IndetSnSpArray.")
|
| 301 |
+
tp = self.tp + other.tp
|
| 302 |
+
if np.array_equal(self.lower_thresh, other.lower_thresh):
|
| 303 |
+
lower_thresh = self.lower_thresh
|
| 304 |
+
else:
|
| 305 |
+
lower_thresh = np.nan * np.ones_like(tp)
|
| 306 |
+
if np.array_equal(self.upper_thresh, other.upper_thresh):
|
| 307 |
+
upper_thresh = self.upper_thresh
|
| 308 |
+
else:
|
| 309 |
+
upper_thresh = np.nan * np.ones_like(tp)
|
| 310 |
+
return IndetSnSpArray(
|
| 311 |
+
lower_thresh=lower_thresh,
|
| 312 |
+
upper_thresh=upper_thresh,
|
| 313 |
+
tp=tp,
|
| 314 |
+
fp=self.fp + other.fp,
|
| 315 |
+
fn=self.fn + other.fn,
|
| 316 |
+
tn=self.tn + other.tn,
|
| 317 |
+
indet=self.indet + other.indet,
|
| 318 |
+
weighted=self.weighted or other.weighted,
|
| 319 |
+
min_weight=min(self.min_weight, other.min_weight),
|
| 320 |
+
eps=self.eps,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
def confidence_interval_bound(self, p: float = 0.95) -> "IndetSnSpArray":
|
| 324 |
+
"""Compute two-sided confidence interval bounds: upper for indet_frac and lower for other metrics."""
|
| 325 |
+
if self.weighted:
|
| 326 |
+
raise NotImplementedError(
|
| 327 |
+
"Confidence intervals only implemented for unweighted confusion matrices."
|
| 328 |
+
)
|
| 329 |
+
copy = IndetSnSpArray(**asdict(self))
|
| 330 |
+
copy.sn, _ = fake_vectorized_binom_ci(self.tp, self.tp + self.fn, p=p)
|
| 331 |
+
copy.sp, _ = fake_vectorized_binom_ci(self.tn, self.tn + self.fp, p=p)
|
| 332 |
+
copy.ppv, _ = fake_vectorized_binom_ci(self.tp, self.fp + self.tp, p=p)
|
| 333 |
+
copy.npv, _ = fake_vectorized_binom_ci(self.tn, self.tn + self.fn, p=p)
|
| 334 |
+
_, copy.indet_frac = fake_vectorized_binom_ci(
|
| 335 |
+
self.indet, self.tp + self.fn + self.fp + self.tn + self.indet, p=p
|
| 336 |
+
)
|
| 337 |
+
return copy
|
| 338 |
+
|
| 339 |
+
def roc_curve(self, indet_budget=0.0) -> "IndetRocCurve":
|
| 340 |
+
"""Compute ROC curve with indeterminate budget, sorted by increasing sn and decreasing sp.
|
| 341 |
+
|
| 342 |
+
Restrict `self` to Pareto-optimal pairs (sn, sp) for which `indet_frac <= indet_budget`. Other points are worse
|
| 343 |
+
than the points on the curve in the sense of having worse Sn, worse Sp, or not meeting the indeterminate budget.
|
| 344 |
+
|
| 345 |
+
"""
|
| 346 |
+
within_budget = self[self.indet_frac <= indet_budget]
|
| 347 |
+
frontier = pareto_2d_indices(within_budget.sn, within_budget.sp)
|
| 348 |
+
return IndetRocCurve(**asdict(within_budget[frontier]))
|
| 349 |
+
|
| 350 |
+
def sn_eq_sp_graph(self) -> "IndetSnEqSpGraph":
|
| 351 |
+
"""Compute sn=sp as a function of indet_frac, returning both sorted in increasing order.
|
| 352 |
+
|
| 353 |
+
Method: restrict to Pareto-optimal pairs (s, indet_frac) where s = min(sn, sp).
|
| 354 |
+
|
| 355 |
+
Pareto-optimality means that if (s, indet_frac) is in the output, there is no point (sn', sp', indet_frac') in
|
| 356 |
+
the input with indet_frac' <= indet_frac and sn', sp' > s. In other words, s is the maximum value such that
|
| 357 |
+
the quadrant { sn, sp >= s} intersects `self.roc_curve(indet_frac)`. This maximum occurs where the ROC curve
|
| 358 |
+
intersects the diagonal, up to an error bounded by the distance between points on the ROC curve.
|
| 359 |
+
|
| 360 |
+
"""
|
| 361 |
+
frontier = pareto_2d_indices(self.min_sn_sp, -self.indet_frac)
|
| 362 |
+
return IndetSnEqSpGraph(**asdict(self[frontier]))
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class IndetRocCurve(IndetSnSpArray):
|
| 366 |
+
"""Sn, Sp achievable within some indeterminate budget and associated lower and upper thresholds.
|
| 367 |
+
|
| 368 |
+
`sn` is assumed to be sorted in increasing order and `sp` decreasing.
|
| 369 |
+
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
def sn_eq_sp(self) -> IndetSnSpArray:
|
| 373 |
+
"""Locate the point on the ROC curve closest to the diagonal"""
|
| 374 |
+
return self[np.argmax(self.min_sn_sp)]
|
| 375 |
+
|
| 376 |
+
def auc(self) -> float:
|
| 377 |
+
"""Compute the area under the ROC curve."""
|
| 378 |
+
# `auc` does not automatically include the trivial points (0, 1) and (1, 0)
|
| 379 |
+
# and will underestimate the AUC if these are not explicitly added
|
| 380 |
+
from sklearn.metrics import auc
|
| 381 |
+
return auc(1 - np.r_[1.0, self.sp, 0.0], np.r_[0.0, self.sn, 1.0])
|
| 382 |
+
|
| 383 |
+
@classmethod
|
| 384 |
+
def build(
|
| 385 |
+
cls,
|
| 386 |
+
thresh=None,
|
| 387 |
+
*,
|
| 388 |
+
y_true: np.ndarray,
|
| 389 |
+
y_score: np.ndarray,
|
| 390 |
+
weights: Optional[np.ndarray] = None,
|
| 391 |
+
) -> Self:
|
| 392 |
+
"""Build an indeterminate=0 ROC curve in n log n time (vs n**2 for IndetSnSpArray.build().roc_curve())."""
|
| 393 |
+
if thresh is None:
|
| 394 |
+
thresh = midpoints_with_infs(y_score)[
|
| 395 |
+
::-1
|
| 396 |
+
] # reverse for proper output sorting
|
| 397 |
+
issa = IndetSnSpArray.build(
|
| 398 |
+
lower_thresh=thresh,
|
| 399 |
+
upper_thresh=thresh,
|
| 400 |
+
y_true=y_true,
|
| 401 |
+
y_score=y_score,
|
| 402 |
+
weights=weights,
|
| 403 |
+
)
|
| 404 |
+
return cls(**asdict(issa))
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class IndetSnEqSpGraph(IndetSnSpArray):
|
| 408 |
+
"""Sn=Sp achievable as a function of indeterminate budget and associated lower and upper thresholds.
|
| 409 |
+
|
| 410 |
+
Both min(self.sn, self.sp) and self.indet_frac are assumed to be sorted in non-decreasing order.
|
| 411 |
+
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
def at_budget(self, indet_budget: float = 0.0) -> IndetSnSpArray:
|
| 415 |
+
"""Locate the best point on the graph within the given budget."""
|
| 416 |
+
return self[np.searchsorted(self.indet_frac, indet_budget, side="right") - 1]
|
tuning/optimal_ordinal.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tools for choosing multiple thresholds optimally under various decision criteria via dynamic programming.
|
| 2 |
+
|
| 3 |
+
The abstract machinery is contained in the classes:
|
| 4 |
+
- `OrdinalThresholding`
|
| 5 |
+
- `OptimalOrdinalThresholdingViaDynamicProgramming`
|
| 6 |
+
- `OptimalCostPerSampleOrdinalThresholding`
|
| 7 |
+
- `ClassWeightedOptimalCostPerSampleOrdinalThresholding`
|
| 8 |
+
- `OptimalCostPerClassOrdinalThresholding`
|
| 9 |
+
These can be subclassed to efficiently implement new decision criteria, depending on their structure.
|
| 10 |
+
|
| 11 |
+
The main intended user-facing classes are the subclasses implementing different decision criteria:
|
| 12 |
+
- `MaxAccuracyOrdinalThresholding`
|
| 13 |
+
- `MaxMacroRecallOrdinalThresholding`
|
| 14 |
+
- `MinAbsoluteErrorOrdinalThresholding`
|
| 15 |
+
- `MaxMacroPrecisionOrdinalThresholding`
|
| 16 |
+
- `MaxMacroF1OrdinalThresholding`
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from abc import ABC, abstractmethod
|
| 21 |
+
from typing import Literal, Optional, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class OrdinalThresholding(torch.nn.Module):
|
| 27 |
+
"""Basic 1d thresholding logic."""
|
| 28 |
+
|
| 29 |
+
def __init__(self, num_classes: int):
|
| 30 |
+
"""Init thresholding module with the specified number of classes (one more than the number of thresholds)."""
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.num_classes = num_classes
|
| 33 |
+
self.register_buffer("thresholds", torch.zeros(num_classes - 1))
|
| 34 |
+
self.thresholds: torch.Tensor
|
| 35 |
+
|
| 36 |
+
def is_valid(self) -> bool:
|
| 37 |
+
"""Check whether the thresholds are monotone non-decreasing."""
|
| 38 |
+
return all(torch.greater_equal(self.thresholds[1:], self.thresholds[:-1]))
|
| 39 |
+
|
| 40 |
+
def forward(self, scores) -> torch.Tensor:
|
| 41 |
+
"""Find which thresholds each score lies between."""
|
| 42 |
+
return torch.searchsorted(self.thresholds, scores)
|
| 43 |
+
|
| 44 |
+
def tune_thresholds(
|
| 45 |
+
self,
|
| 46 |
+
*,
|
| 47 |
+
scores: torch.Tensor,
|
| 48 |
+
labels: torch.Tensor,
|
| 49 |
+
available_thresholds: Optional[torch.Tensor] = None,
|
| 50 |
+
) -> torch.Tensor:
|
| 51 |
+
"""Adapt the thresholds to the given data.
|
| 52 |
+
|
| 53 |
+
This is essentially an abstract method, but for testing purposes it's helpful to be able to instantiate the
|
| 54 |
+
class with a no-op version.
|
| 55 |
+
|
| 56 |
+
Parameters
|
| 57 |
+
----------
|
| 58 |
+
scores : a vector of `float` scores for each example in the validation set
|
| 59 |
+
labels : a vector of `int` labels having the same shape as `scores` containing the corresponding labels
|
| 60 |
+
available_thresholds : a vector of `float` score values over which to optimize choice of thresholds;
|
| 61 |
+
`None`, then thresholds between every score in the validation set are allowed. +/- inf are always allowed.
|
| 62 |
+
|
| 63 |
+
Returns
|
| 64 |
+
-------
|
| 65 |
+
scalar `float` mean cost on the validation set using optimal thresholds
|
| 66 |
+
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class OptimalOrdinalThresholdingViaDynamicProgramming(OrdinalThresholding, ABC):
|
| 71 |
+
"""Super-class for general dynamic programming implementations of ordinal threshold tuning.
|
| 72 |
+
|
| 73 |
+
Subclasses implement different ways of computing the mean cost and corresponding DP step.
|
| 74 |
+
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
direction: Literal["min", "max"] # provided by subclasses
|
| 78 |
+
|
| 79 |
+
def __init__(self, num_classes: int):
|
| 80 |
+
super().__init__(num_classes=num_classes)
|
| 81 |
+
if self.direction not in ("min", "max"):
|
| 82 |
+
raise ValueError(
|
| 83 |
+
f"Got direction {self.direction!r}, expected 'min' or 'max'."
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
@abstractmethod
|
| 87 |
+
def mean_cost(
|
| 88 |
+
self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]
|
| 89 |
+
) -> torch.Tensor:
|
| 90 |
+
"""Compute the mean cost of assigning label(s) `preds` when the ground truth is `labels`."""
|
| 91 |
+
|
| 92 |
+
def best_constant_output_classifier(self, labels: torch.Tensor):
|
| 93 |
+
"""Find the optimal mean cost of a constant-output classifier for given `labels` and the associated constant."""
|
| 94 |
+
if self.direction == "min":
|
| 95 |
+
optimize = torch.min
|
| 96 |
+
else:
|
| 97 |
+
optimize = torch.max
|
| 98 |
+
optimum = optimize(
|
| 99 |
+
torch.tensor(
|
| 100 |
+
[
|
| 101 |
+
self.mean_cost(labels=labels, preds=c)
|
| 102 |
+
for c in range(self.num_classes)
|
| 103 |
+
],
|
| 104 |
+
device=labels.device,
|
| 105 |
+
),
|
| 106 |
+
0,
|
| 107 |
+
)
|
| 108 |
+
return optimum.values, optimum.indices
|
| 109 |
+
|
| 110 |
+
@abstractmethod
|
| 111 |
+
def dp_step(
|
| 112 |
+
self,
|
| 113 |
+
c_idx: int,
|
| 114 |
+
*,
|
| 115 |
+
scores: torch.Tensor,
|
| 116 |
+
labels: torch.Tensor,
|
| 117 |
+
available_thresholds: torch.Tensor,
|
| 118 |
+
prev_cost: Optional[torch.Tensor] = None,
|
| 119 |
+
) -> (torch.Tensor, Optional[torch.Tensor]):
|
| 120 |
+
"""Given optimal cost `prev_cost` of classes < `c_idx`, optimize cost of `c_idx` as a function of threshold.
|
| 121 |
+
|
| 122 |
+
Arguments
|
| 123 |
+
---------
|
| 124 |
+
c_idx : current class index
|
| 125 |
+
scores, labels, available_thresholds : see `tune_thresholds`
|
| 126 |
+
prev_cost (optional float tensor) : optimal cost of classes < `c_idx` as a function of upper threshold
|
| 127 |
+
for class `c_idx - 1`; ignored if `c_idx == 0`
|
| 128 |
+
|
| 129 |
+
Returns
|
| 130 |
+
-------
|
| 131 |
+
cost: `cost[i]` is for choosing upper threshold of class `c_idx` equal to `available_thresholds[i]`
|
| 132 |
+
when thresholds for lower classes are chosen optimally
|
| 133 |
+
indices : to achieve `cost[i]`, optimal upper threshold for class `c_idx - 1` is
|
| 134 |
+
`available_thresholds[indices[i]]`; `None` if `c_idx == 0`
|
| 135 |
+
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
def tune_thresholds(
|
| 139 |
+
self,
|
| 140 |
+
*,
|
| 141 |
+
scores: torch.Tensor,
|
| 142 |
+
labels: torch.Tensor,
|
| 143 |
+
available_thresholds: Optional[torch.Tensor] = None,
|
| 144 |
+
) -> torch.Tensor:
|
| 145 |
+
"""Set `self.thresholds` to optimize mean cost of given `scores` and `labels`.
|
| 146 |
+
|
| 147 |
+
Arguments
|
| 148 |
+
---------
|
| 149 |
+
scores (1d float tensor) : scores of examples on tuning dataset
|
| 150 |
+
labels (1d int tensor) : labels in {0, ..., self.num_classes - 1} of same shape as scores
|
| 151 |
+
available_thresholds (optional 1d float tensor) : thresholds which will be considered when tuning.
|
| 152 |
+
+/-inf will be added automatically to ensure all examples are classified. If omitted, will
|
| 153 |
+
insert thresholds between each element of sorted(unique(scores)).
|
| 154 |
+
|
| 155 |
+
Returns
|
| 156 |
+
-------
|
| 157 |
+
float tensor : optimal mean cost achieved on the provided dataset at the tuned `self.thresholds`
|
| 158 |
+
|
| 159 |
+
"""
|
| 160 |
+
inf = torch.tensor([torch.inf], device=scores.device)
|
| 161 |
+
if available_thresholds is None: # use all possible thresholds
|
| 162 |
+
unique_scores = torch.unique(scores)
|
| 163 |
+
available_thresholds = (unique_scores[:-1] + unique_scores[1:]) / 2.0
|
| 164 |
+
# Always allow some classes to be omitted entirely by setting thresholds to +/- inf.
|
| 165 |
+
# This simplifies the algorithm and also guarantees that the baseline constant-output
|
| 166 |
+
# classifiers are feasible choices for tuning, which is needed to assure that the
|
| 167 |
+
# optimum is at least as good as a constant-output classifier.
|
| 168 |
+
available_thresholds = torch.concatenate(
|
| 169 |
+
[
|
| 170 |
+
-inf,
|
| 171 |
+
available_thresholds,
|
| 172 |
+
inf,
|
| 173 |
+
]
|
| 174 |
+
)
|
| 175 |
+
indices = torch.empty(
|
| 176 |
+
(self.num_classes - 2, len(available_thresholds)),
|
| 177 |
+
dtype=torch.int,
|
| 178 |
+
device=scores.device,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# cost[j] = optimal total cost of items assigned pred <= c if the
|
| 182 |
+
# threshold between class c and c+1 is available_thresholds[j] (by appropriate choice of lower thresholds).
|
| 183 |
+
cost, _ = self.dp_step(
|
| 184 |
+
c_idx=0,
|
| 185 |
+
scores=scores,
|
| 186 |
+
labels=labels,
|
| 187 |
+
available_thresholds=available_thresholds,
|
| 188 |
+
)
|
| 189 |
+
for c in range(1, self.num_classes - 1):
|
| 190 |
+
cost, indices[c - 1, :] = self.dp_step(
|
| 191 |
+
c_idx=c,
|
| 192 |
+
scores=scores,
|
| 193 |
+
labels=labels,
|
| 194 |
+
available_thresholds=available_thresholds,
|
| 195 |
+
prev_cost=cost,
|
| 196 |
+
)
|
| 197 |
+
cost, best_index = self.dp_step(
|
| 198 |
+
c_idx=self.num_classes - 1,
|
| 199 |
+
scores=scores,
|
| 200 |
+
labels=labels,
|
| 201 |
+
available_thresholds=available_thresholds,
|
| 202 |
+
prev_cost=cost,
|
| 203 |
+
)
|
| 204 |
+
if self.direction == "min":
|
| 205 |
+
cost *= -1
|
| 206 |
+
|
| 207 |
+
# Follow DP path backwards to find thresholds which optimized cost
|
| 208 |
+
self.thresholds[self.num_classes - 2] = available_thresholds[
|
| 209 |
+
best_index
|
| 210 |
+
] # final threshold
|
| 211 |
+
for c in range(self.num_classes - 3, -1, -1): # counting down to zero
|
| 212 |
+
best_index = indices[c, best_index.long()]
|
| 213 |
+
self.thresholds[c] = available_thresholds[best_index.long()]
|
| 214 |
+
|
| 215 |
+
return cost
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def cumsum_with_0(t: torch.Tensor):
|
| 219 |
+
return torch.nn.functional.pad(torch.cumsum(t, dim=0), (1, 0))
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class OptimalCostPerSampleOrdinalThresholding(
|
| 223 |
+
OptimalOrdinalThresholdingViaDynamicProgramming, ABC
|
| 224 |
+
):
|
| 225 |
+
"""Optimal 1d thresholding based on tuning thresholds to optimize the mean of a sample-wise cost function."""
|
| 226 |
+
|
| 227 |
+
@abstractmethod
|
| 228 |
+
def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
|
| 229 |
+
"""Compute the sample-wise cost of assigning label(s) `preds` when the ground truth is `labels`."""
|
| 230 |
+
|
| 231 |
+
def mean_cost(
|
| 232 |
+
self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]
|
| 233 |
+
) -> torch.Tensor:
|
| 234 |
+
"""Compute the mean cost of assigning label(s) `preds` when the ground truth is `labels`."""
|
| 235 |
+
return torch.mean(self.cost(labels=labels, preds=preds))
|
| 236 |
+
|
| 237 |
+
def dp_step(
|
| 238 |
+
self,
|
| 239 |
+
c_idx: int,
|
| 240 |
+
*,
|
| 241 |
+
scores: torch.Tensor,
|
| 242 |
+
labels: torch.Tensor,
|
| 243 |
+
available_thresholds: torch.Tensor,
|
| 244 |
+
prev_cost: Optional[torch.Tensor] = None,
|
| 245 |
+
) -> (torch.Tensor, Optional[torch.Tensor]):
|
| 246 |
+
"""O(len(scores)) implementation for per-sample cost."""
|
| 247 |
+
# Compute running_cost[i] = sum of costs of elements with score less than available_thresholds[i] if assigned label c
|
| 248 |
+
item_costs = self.cost(labels=labels, preds=c_idx) / len(scores)
|
| 249 |
+
if self.direction == "min":
|
| 250 |
+
item_costs *= -1
|
| 251 |
+
# move tensors to and from CPU because histogram has no CUDA implementation
|
| 252 |
+
cost_new_class_by_thresh, _ = torch.histogram(
|
| 253 |
+
scores.cpu().float(),
|
| 254 |
+
weight=item_costs.cpu().float(),
|
| 255 |
+
bins=available_thresholds.cpu().float(),
|
| 256 |
+
)
|
| 257 |
+
running_cost = cumsum_with_0(cost_new_class_by_thresh.to(labels.device))
|
| 258 |
+
|
| 259 |
+
# Combine with running_cost with prev_cost
|
| 260 |
+
if c_idx == 0:
|
| 261 |
+
return running_cost, None
|
| 262 |
+
diff = prev_cost - running_cost
|
| 263 |
+
cummax = torch.cummax(diff, dim=0)
|
| 264 |
+
cost = running_cost + cummax.values
|
| 265 |
+
if c_idx == self.num_classes - 1:
|
| 266 |
+
# -1 to always set the *upper* threshold for class `num_classes - 1` to include the rest of the data
|
| 267 |
+
return cost[-1], cummax.indices[-1]
|
| 268 |
+
return cost, cummax.indices
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class MaxAccuracyOrdinalThresholding(OptimalCostPerSampleOrdinalThresholding):
|
| 272 |
+
"""Threshold to maximize accuracy."""
|
| 273 |
+
|
| 274 |
+
direction = "max"
|
| 275 |
+
|
| 276 |
+
def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
|
| 277 |
+
return torch.eq(labels, preds).float()
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class MaxMacroRecallOrdinalThresholding(OptimalCostPerSampleOrdinalThresholding):
|
| 281 |
+
"""Threshold to maximize macro-averaged recall."""
|
| 282 |
+
|
| 283 |
+
direction = "max"
|
| 284 |
+
|
| 285 |
+
def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
|
| 286 |
+
counts = torch.bincount(labels, minlength=self.num_classes).float()
|
| 287 |
+
ratios = counts.sum() / (self.num_classes * counts)
|
| 288 |
+
return torch.eq(labels, preds).float() * torch.gather(
|
| 289 |
+
ratios, 0, labels.type(torch.int64)
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class MinAbsoluteErrorOrdinalThresholding(OptimalCostPerSampleOrdinalThresholding):
|
| 294 |
+
"""Threshold to minimize mean absolute error."""
|
| 295 |
+
|
| 296 |
+
direction = "min"
|
| 297 |
+
|
| 298 |
+
def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
|
| 299 |
+
return torch.abs(preds - labels).float()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class ClassWeightedOptimalCostPerSampleOrdinalThresholding(
|
| 303 |
+
OptimalCostPerSampleOrdinalThresholding
|
| 304 |
+
):
|
| 305 |
+
"""Compute cost weighted equally over classes instead of equally over samples.
|
| 306 |
+
|
| 307 |
+
This class takes another instance of OptimalCostPerSampleOrdinalThresholding
|
| 308 |
+
which computes its cost independently for each sample and reweights the cost
|
| 309 |
+
based on label frequencies.
|
| 310 |
+
|
| 311 |
+
Note: this class depends on an implementation detail of its superclass:
|
| 312 |
+
namely calling `self.cost` with the full tuning or eval set of labels,
|
| 313 |
+
rather than a single label. This is required to do the re-weighting properly.
|
| 314 |
+
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
def __init__(self, unweighted_instance: OptimalCostPerSampleOrdinalThresholding):
|
| 318 |
+
self.direction = unweighted_instance.direction
|
| 319 |
+
super().__init__(unweighted_instance.num_classes)
|
| 320 |
+
self.unweighted_instance = unweighted_instance
|
| 321 |
+
|
| 322 |
+
def cost(self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]):
|
| 323 |
+
counts = torch.bincount(labels, minlength=self.num_classes)
|
| 324 |
+
(indices,) = torch.where(counts == 0)
|
| 325 |
+
if len(indices) > 0:
|
| 326 |
+
raise ValueError(
|
| 327 |
+
f"Cannot compute class-weighted cost because classes {set(indices.tolist())} are missing."
|
| 328 |
+
)
|
| 329 |
+
unweighted_cost = self.unweighted_instance.cost(labels=labels, preds=preds)
|
| 330 |
+
weights = len(labels) / (self.num_classes * counts[labels].float())
|
| 331 |
+
return weights * unweighted_cost
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class OptimalCostPerClassOrdinalThresholding(
|
| 335 |
+
OptimalOrdinalThresholdingViaDynamicProgramming, ABC
|
| 336 |
+
):
|
| 337 |
+
"""General DP case for when the linear algorithm for per-sample costs is not applicable.
|
| 338 |
+
|
| 339 |
+
Complexity depends on the implementation of `cost_matrix`.
|
| 340 |
+
|
| 341 |
+
"""
|
| 342 |
+
|
| 343 |
+
@abstractmethod
|
| 344 |
+
def cost_matrix(
|
| 345 |
+
self,
|
| 346 |
+
c_idx: int,
|
| 347 |
+
*,
|
| 348 |
+
scores: torch.Tensor,
|
| 349 |
+
labels: torch.Tensor,
|
| 350 |
+
available_thresholds: torch.Tensor,
|
| 351 |
+
start: bool,
|
| 352 |
+
end: bool,
|
| 353 |
+
) -> torch.Tensor:
|
| 354 |
+
"""Each output[i, j] = cost for when scores in range `available_thresholds[i:j]` are assigned label `c_idx`."""
|
| 355 |
+
|
| 356 |
+
def mean_cost(
|
| 357 |
+
self, *, labels: torch.Tensor, preds: Union[int, torch.Tensor]
|
| 358 |
+
) -> torch.Tensor:
|
| 359 |
+
"""Compute the mean cost of assigning label(s) `preds` when the ground truth is `labels`."""
|
| 360 |
+
|
| 361 |
+
if isinstance(preds, int) or preds.numel() == 1:
|
| 362 |
+
preds = preds * torch.ones_like(labels, dtype=torch.int)
|
| 363 |
+
|
| 364 |
+
total_cost = torch.tensor(0.0, device=labels.device)
|
| 365 |
+
for c_idx in range(self.num_classes):
|
| 366 |
+
thresholds = torch.tensor([c_idx - 0.5, c_idx + 0.5], device=labels.device)
|
| 367 |
+
total_cost += self.cost_matrix(
|
| 368 |
+
c_idx, preds.float(), labels, thresholds, start=True, end=True
|
| 369 |
+
)[0, 0]
|
| 370 |
+
return total_cost / self.num_classes
|
| 371 |
+
|
| 372 |
+
def dp_step(
|
| 373 |
+
self,
|
| 374 |
+
c_idx: int,
|
| 375 |
+
*,
|
| 376 |
+
scores: torch.Tensor,
|
| 377 |
+
labels: torch.Tensor,
|
| 378 |
+
available_thresholds: torch.Tensor,
|
| 379 |
+
prev_cost: Optional[torch.Tensor] = None,
|
| 380 |
+
) -> (torch.Tensor, Optional[torch.Tensor]):
|
| 381 |
+
cost_matrix = (
|
| 382 |
+
self.cost_matrix(
|
| 383 |
+
c_idx,
|
| 384 |
+
scores=scores,
|
| 385 |
+
labels=labels,
|
| 386 |
+
available_thresholds=available_thresholds,
|
| 387 |
+
start=c_idx == 0,
|
| 388 |
+
end=c_idx == self.num_classes - 1,
|
| 389 |
+
)
|
| 390 |
+
/ self.num_classes
|
| 391 |
+
)
|
| 392 |
+
if self.direction == "min":
|
| 393 |
+
cost_matrix *= -1
|
| 394 |
+
if prev_cost is not None:
|
| 395 |
+
cost_matrix += prev_cost[:, None]
|
| 396 |
+
max_ = torch.max(cost_matrix, dim=0)
|
| 397 |
+
return max_.values, max_.indices
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def _compute_metrics_matrices(
|
| 401 |
+
scores: torch.Tensor,
|
| 402 |
+
binary_labels: torch.Tensor,
|
| 403 |
+
thresholds: torch.Tensor,
|
| 404 |
+
start: bool = False,
|
| 405 |
+
end: bool = False,
|
| 406 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 407 |
+
"""Each output[i, j] = stats for when scores between thresholds[i] and thresolds[j] are assigned `True`.
|
| 408 |
+
|
| 409 |
+
Helper function for `MaxMacroPrecisionOrdinalThresholding` and `MaxMacroF1OrdinalThresholding`
|
| 410 |
+
|
| 411 |
+
Computed in O(len(thresholds)**2 + len(scores)*log(len(thresholds))) operations instead of the naive
|
| 412 |
+
O(len(scores)*len(thresholds)**2) operations to compute each element of the output independently.
|
| 413 |
+
|
| 414 |
+
Arguments
|
| 415 |
+
---------
|
| 416 |
+
scores (float Tensor) : scores of labeled examples for which to compute metrics
|
| 417 |
+
binary_labels (bool Tensor) : corresponding binary labels of same shape as `scores`
|
| 418 |
+
thresholds (float Tensor) : thresholds between which to compute metrics
|
| 419 |
+
start : compute only the first row of the output (lower threshold at its minimum value)
|
| 420 |
+
end : compute only the last column of the output (upper threshold at its maximum value)
|
| 421 |
+
|
| 422 |
+
Returns
|
| 423 |
+
-------
|
| 424 |
+
tp : tp[i, j] = number of true positives if scores between thresholds[i:j] are classified as positive
|
| 425 |
+
tp_plus_fp: tp_plus_fp[i, j] = number of scores between thresholds[i:j]
|
| 426 |
+
|
| 427 |
+
"""
|
| 428 |
+
# move tensors to and from CPU because histogram has no CUDA implementation
|
| 429 |
+
scores = scores.float().cpu()
|
| 430 |
+
thresholds = thresholds.float().cpu()
|
| 431 |
+
labeled_true_by_thresh, _ = torch.histogram(
|
| 432 |
+
scores,
|
| 433 |
+
weight=binary_labels.float().cpu(),
|
| 434 |
+
bins=thresholds,
|
| 435 |
+
)
|
| 436 |
+
count_by_thresh, _ = torch.histogram(
|
| 437 |
+
scores,
|
| 438 |
+
bins=thresholds,
|
| 439 |
+
)
|
| 440 |
+
running_labeled_true_by_thresh = cumsum_with_0(
|
| 441 |
+
labeled_true_by_thresh.to(binary_labels.device)
|
| 442 |
+
)
|
| 443 |
+
running_count_by_thresh = cumsum_with_0(
|
| 444 |
+
count_by_thresh.to(binary_labels.device).float()
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
def start_slice(t):
|
| 448 |
+
return t[: (1 if start else None), None]
|
| 449 |
+
|
| 450 |
+
def end_slice(t):
|
| 451 |
+
return t[None, (-1 if end else None) :]
|
| 452 |
+
|
| 453 |
+
tp = end_slice(running_labeled_true_by_thresh) - start_slice(
|
| 454 |
+
running_labeled_true_by_thresh
|
| 455 |
+
)
|
| 456 |
+
tp_plus_fp = end_slice(running_count_by_thresh) - start_slice(
|
| 457 |
+
running_count_by_thresh
|
| 458 |
+
)
|
| 459 |
+
return tp, tp_plus_fp
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class MaxMacroPrecisionOrdinalThresholding(OptimalCostPerClassOrdinalThresholding):
|
| 463 |
+
"""Threshold to maximize macro-averaged precision."""
|
| 464 |
+
|
| 465 |
+
direction = "max"
|
| 466 |
+
|
| 467 |
+
def cost_matrix(
|
| 468 |
+
self,
|
| 469 |
+
c_idx: int,
|
| 470 |
+
scores: torch.Tensor,
|
| 471 |
+
labels: torch.Tensor,
|
| 472 |
+
available_thresholds: torch.Tensor,
|
| 473 |
+
start: bool,
|
| 474 |
+
end: bool,
|
| 475 |
+
) -> torch.Tensor:
|
| 476 |
+
tp, tp_plus_fp = _compute_metrics_matrices(
|
| 477 |
+
scores, torch.eq(labels, c_idx), available_thresholds, start=start, end=end
|
| 478 |
+
)
|
| 479 |
+
safe_tp_plus_fp = torch.maximum(
|
| 480 |
+
tp_plus_fp, torch.ones(1, device=tp_plus_fp.device)
|
| 481 |
+
)
|
| 482 |
+
return torch.where(torch.ge(tp_plus_fp, 0.0), tp / safe_tp_plus_fp, -torch.inf)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class MaxMacroF1OrdinalThresholding(OptimalCostPerClassOrdinalThresholding):
|
| 486 |
+
"""Threshold to maximize macro-averaged F1 score."""
|
| 487 |
+
|
| 488 |
+
direction = "max"
|
| 489 |
+
|
| 490 |
+
def cost_matrix(
|
| 491 |
+
self,
|
| 492 |
+
c_idx: int,
|
| 493 |
+
scores: torch.Tensor,
|
| 494 |
+
labels: torch.Tensor,
|
| 495 |
+
available_thresholds: torch.Tensor,
|
| 496 |
+
start: bool,
|
| 497 |
+
end: bool,
|
| 498 |
+
) -> torch.Tensor:
|
| 499 |
+
tp, tp_plus_fp = _compute_metrics_matrices(
|
| 500 |
+
scores, torch.eq(labels, c_idx), available_thresholds, start=start, end=end
|
| 501 |
+
)
|
| 502 |
+
tp_plus_fn = torch.eq(labels, c_idx).float().sum() # scalar
|
| 503 |
+
safe_tp_plus_fp = torch.maximum(
|
| 504 |
+
tp_plus_fp, torch.ones(1, device=tp_plus_fp.device)
|
| 505 |
+
)
|
| 506 |
+
return torch.where(
|
| 507 |
+
torch.ge(tp_plus_fp, 0.0),
|
| 508 |
+
2 * tp / (safe_tp_plus_fp + tp_plus_fn),
|
| 509 |
+
-torch.inf,
|
| 510 |
+
)
|