xshubhamx committed
Commit
e9f5f2c
·
verified ·
1 Parent(s): 7529c15

Update multiclass_sensitivity_macro.py

Files changed (1)
  1. multiclass_sensitivity_macro.py +2 -40
multiclass_sensitivity_macro.py CHANGED
@@ -90,45 +90,7 @@ class multiclass_sensitivity_macro(evaluate.Metric):
         """Returns the scores"""
         # TODO: Compute the different scores of the module
 
-        from collections import defaultdict
-        """
-        Calculate multiclass sensitivity (recall) for each class,
-        as well as weighted and macro averages.
-
-        Args:
-            references (list): List of true class labels.
-            predictions (list): List of predicted class labels.
-
-        Returns:
-            tuple: Class-wise sensitivity, weighted average sensitivity, macro average sensitivity.
-        """
-        # Count true positives, false negatives, and true instance counts for each class
-        tp_counts = defaultdict(int)
-        fn_counts = defaultdict(int)
-        true_counts = defaultdict(int)
-
-        for true_label, pred_label in zip(references, predictions):
-            true_counts[true_label] += 1
-            if true_label == pred_label:
-                tp_counts[true_label] += 1
-            else:
-                fn_counts[true_label] += 1
-
-        # Calculate class-wise sensitivity
-        class_sensitivities = {}
-        total_weight = sum(true_counts.values())
-        weighted_sum = 0.0
-
-        for class_label in set(references):
-            tp = tp_counts[class_label]
-            fn = fn_counts[class_label]
-            true_instances = true_counts[class_label]
-
-            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
-            class_sensitivities[class_label] = sensitivity
-            weighted_sum += sensitivity * true_instances
-
-        macro_avg_sensitivity = sum(class_sensitivities.values()) / len(class_sensitivities) if class_sensitivities else 0
+        from imblearn.metrics import sensitivity_score
         return {
-            "macro_sensitivity": macro_avg_sensitivity,
+            "macro_sensitivity": sensitivity_score(references, predictions, average = "macro"),
         }
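
Note (not part of the commit): with average="macro", imblearn's sensitivity_score computes macro-averaged recall, which is the same quantity the removed hand-rolled loop returned; the old total_weight / weighted_sum accumulators were computed but never used in the result. Below is a minimal sanity check of the swap, assuming imbalanced-learn and scikit-learn are installed; the label lists are made-up illustration data.

from imblearn.metrics import sensitivity_score
from sklearn.metrics import recall_score

# Hypothetical labels, for illustration only.
references = [0, 1, 2, 2, 1, 0]   # true classes
predictions = [0, 2, 2, 2, 1, 1]  # predicted classes

# Per-class recall: class 0 -> 1/2, class 1 -> 1/2, class 2 -> 1, so macro = 2/3.
macro_imblearn = sensitivity_score(references, predictions, average="macro")
macro_sklearn = recall_score(references, predictions, average="macro")

assert abs(macro_imblearn - macro_sklearn) < 1e-12
print(macro_imblearn)  # 0.666...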