danieldux committed on
Commit
e4caed1
·
1 Parent(s): 08d668a

Refactor ISCO_Hierarchical_Accuracy fix recall

Browse files
Files changed (1) hide show
  1. isco_hierarchical_accuracy.py +28 -30
isco_hierarchical_accuracy.py CHANGED
@@ -154,16 +154,16 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
154
 
155
  return isco_hierarchy
156
 
157
- def find_ancestors(self, node: str, hierarchy: dict) -> set:
158
  """
159
  Find the ancestors of a given node in a hierarchy.
160
 
161
  Args:
162
  node (str): The node for which to find ancestors.
163
- hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
164
 
165
  Returns:
166
- set: A set of ancestors of the given node.
167
  """
168
  ancestors = set()
169
  nodes_to_visit = [node]
@@ -204,45 +204,43 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
204
  Args:
205
  reference_codes (List[str]): The list of reference codes.
206
  predicted_codes (List[str]): The list of predicted codes.
207
- hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
208
 
209
  Returns:
210
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
211
  """
212
- extended_real = {}
 
213
 
214
  # Extend the sets of reference codes with their ancestors
215
  for code in reference_codes:
216
- weight = 1.0 # Full weight for exact match
217
- extended_real[code] = weight
218
- for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
219
- extended_real[ancestor] = max(
220
- extended_real.get(ancestor, 0), ancestor_weight
221
- )
222
-
223
- extended_predicted = {}
224
 
225
  # Extend the sets of predicted codes with their ancestors
226
  for code in predicted_codes:
227
- weight = 1.0
228
- extended_predicted[code] = weight
229
- for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
230
- extended_predicted[ancestor] = max(
231
- extended_predicted.get(ancestor, 0), ancestor_weight
232
- )
233
-
234
- # Calculate weighted correct predictions
235
- correct_weights = 0
236
- for code, weight in extended_predicted.items():
237
- if code in extended_real:
238
- correct_weights += min(weight, extended_real[code])
239
 
240
- total_predicted_weights = sum(extended_predicted.values())
241
- total_real_weights = sum(extended_real.values())
242
 
243
- # Calculate hierarchical precision and recall using weighted sums
244
- hP = correct_weights / total_predicted_weights if total_predicted_weights else 0
245
- hR = correct_weights / total_real_weights if total_real_weights else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  return hP, hR
248
 
 
154
 
155
  return isco_hierarchy
156
 
157
+ def find_ancestors(self, node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]:
158
  """
159
  Find the ancestors of a given node in a hierarchy.
160
 
161
  Args:
162
  node (str): The node for which to find ancestors.
163
+ hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
164
 
165
  Returns:
166
+ Set[str]: A set of ancestors of the given node.
167
  """
168
  ancestors = set()
169
  nodes_to_visit = [node]
 
204
  Args:
205
  reference_codes (List[str]): The list of reference codes.
206
  predicted_codes (List[str]): The list of predicted codes.
207
+ hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries of parent nodes with distances.
208
 
209
  Returns:
210
  Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
211
  """
212
+ extended_real = set()
213
+ extended_predicted = set()
214
 
215
  # Extend the sets of reference codes with their ancestors
216
  for code in reference_codes:
217
+ extended_real.add(code)
218
+ extended_real.update(self.find_ancestors(code, hierarchy))
 
 
 
 
 
 
219
 
220
  # Extend the sets of predicted codes with their ancestors
221
  for code in predicted_codes:
222
+ extended_predicted.add(code)
223
+ extended_predicted.update(self.find_ancestors(code, hierarchy))
 
 
 
 
 
 
 
 
 
 
224
 
225
+ # Calculate the intersection for recall
226
+ correct_recall = extended_real.intersection(extended_predicted)
227
 
228
+ # Calculate the intersection for precision
229
+ correct_precision = set()
230
+ for code in predicted_codes:
231
+ if code in extended_real:
232
+ correct_precision.add(code)
233
+ correct_precision.update(
234
+ self.find_ancestors(code, hierarchy).intersection(extended_real)
235
+ )
236
+
237
+ # Calculate hierarchical precision and recall using the size of intersections
238
+ hP = (
239
+ len(correct_precision) / len(extended_predicted)
240
+ if extended_predicted
241
+ else 0
242
+ )
243
+ hR = len(correct_recall) / len(extended_real) if extended_real else 0
244
 
245
  return hP, hR
246