Spaces:
Build error
Build error
Refactor ISCO_Hierarchical_Accuracy fix recall
Browse files- isco_hierarchical_accuracy.py +28 -30
isco_hierarchical_accuracy.py
CHANGED
@@ -154,16 +154,16 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
154 |
|
155 |
return isco_hierarchy
|
156 |
|
157 |
-
def find_ancestors(self, node: str, hierarchy:
|
158 |
"""
|
159 |
Find the ancestors of a given node in a hierarchy.
|
160 |
|
161 |
Args:
|
162 |
node (str): The node for which to find ancestors.
|
163 |
-
hierarchy (
|
164 |
|
165 |
Returns:
|
166 |
-
|
167 |
"""
|
168 |
ancestors = set()
|
169 |
nodes_to_visit = [node]
|
@@ -204,45 +204,43 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
204 |
Args:
|
205 |
reference_codes (List[str]): The list of reference codes.
|
206 |
predicted_codes (List[str]): The list of predicted codes.
|
207 |
-
hierarchy (Dict[str,
|
208 |
|
209 |
Returns:
|
210 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
211 |
"""
|
212 |
-
extended_real =
|
|
|
213 |
|
214 |
# Extend the sets of reference codes with their ancestors
|
215 |
for code in reference_codes:
|
216 |
-
|
217 |
-
extended_real
|
218 |
-
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
219 |
-
extended_real[ancestor] = max(
|
220 |
-
extended_real.get(ancestor, 0), ancestor_weight
|
221 |
-
)
|
222 |
-
|
223 |
-
extended_predicted = {}
|
224 |
|
225 |
# Extend the sets of predicted codes with their ancestors
|
226 |
for code in predicted_codes:
|
227 |
-
|
228 |
-
extended_predicted
|
229 |
-
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
230 |
-
extended_predicted[ancestor] = max(
|
231 |
-
extended_predicted.get(ancestor, 0), ancestor_weight
|
232 |
-
)
|
233 |
-
|
234 |
-
# Calculate weighted correct predictions
|
235 |
-
correct_weights = 0
|
236 |
-
for code, weight in extended_predicted.items():
|
237 |
-
if code in extended_real:
|
238 |
-
correct_weights += min(weight, extended_real[code])
|
239 |
|
240 |
-
|
241 |
-
|
242 |
|
243 |
-
# Calculate
|
244 |
-
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
|
247 |
return hP, hR
|
248 |
|
|
|
154 |
|
155 |
return isco_hierarchy
|
156 |
|
157 |
+
def find_ancestors(self, node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
|
158 |
"""
|
159 |
Find the ancestors of a given node in a hierarchy.
|
160 |
|
161 |
Args:
|
162 |
node (str): The node for which to find ancestors.
|
163 |
+
hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
|
164 |
|
165 |
Returns:
|
166 |
+
Set[str]: A set of ancestors of the given node.
|
167 |
"""
|
168 |
ancestors = set()
|
169 |
nodes_to_visit = [node]
|
|
|
204 |
Args:
|
205 |
reference_codes (List[str]): The list of reference codes.
|
206 |
predicted_codes (List[str]): The list of predicted codes.
|
207 |
+
hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries of parent nodes with distances.
|
208 |
|
209 |
Returns:
|
210 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
211 |
"""
|
212 |
+
extended_real = set()
|
213 |
+
extended_predicted = set()
|
214 |
|
215 |
# Extend the sets of reference codes with their ancestors
|
216 |
for code in reference_codes:
|
217 |
+
extended_real.add(code)
|
218 |
+
extended_real.update(self.find_ancestors(code, hierarchy))
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
# Extend the sets of predicted codes with their ancestors
|
221 |
for code in predicted_codes:
|
222 |
+
extended_predicted.add(code)
|
223 |
+
extended_predicted.update(self.find_ancestors(code, hierarchy))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
+
# Calculate the intersection for recall
|
226 |
+
correct_recall = extended_real.intersection(extended_predicted)
|
227 |
|
228 |
+
# Calculate the intersection for precision
|
229 |
+
correct_precision = set()
|
230 |
+
for code in predicted_codes:
|
231 |
+
if code in extended_real:
|
232 |
+
correct_precision.add(code)
|
233 |
+
correct_precision.update(
|
234 |
+
self.find_ancestors(code, hierarchy).intersection(extended_real)
|
235 |
+
)
|
236 |
+
|
237 |
+
# Calculate hierarchical precision and recall using the size of intersections
|
238 |
+
hP = (
|
239 |
+
len(correct_precision) / len(extended_predicted)
|
240 |
+
if extended_predicted
|
241 |
+
else 0
|
242 |
+
)
|
243 |
+
hR = len(correct_recall) / len(extended_real) if extended_real else 0
|
244 |
|
245 |
return hP, hR
|
246 |
|