danieldux's picture
Add docstring to hierarchical_f_measure function
d1fbaa3
raw
history blame
3.24 kB
"""This module provides functions for calculating hierarchical precicion, recall and f1."""
from typing import List, Set, Dict, Tuple
def find_ancestors(node: str, hierarchy: dict) -> set:
    """
    Find all ancestors (transitively) of a given node in a hierarchy.

    Args:
        node (str): The node for which to find ancestors.
        hierarchy (dict): A dictionary representing the hierarchy, where the keys are
            nodes and the values are iterables (e.g. sets) of their direct parents.
            Nodes absent from the dictionary are treated as having no parents.

    Returns:
        set: A set of every node reachable by following parent links from ``node``.
            ``node`` itself is included only if it lies on a cycle.
    """
    ancestors = set()
    nodes_to_visit = [node]
    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        for parent in hierarchy.get(current_node, ()):
            # Expand each ancestor only once: this terminates on cyclic
            # hierarchies and avoids re-walking shared ancestors in a DAG.
            if parent not in ancestors:
                ancestors.add(parent)
                nodes_to_visit.append(parent)
    return ancestors
def extend_with_ancestors(classes: set, hierarchy: dict) -> set:
    """
    Return a new set holding every class in ``classes`` together with all of its ancestors.

    Args:
        classes (set): The classes to extend. The argument itself is not modified.
        hierarchy (dict): Mapping of node -> parent nodes, passed through to ``find_ancestors``.

    Returns:
        set: A fresh set containing ``classes`` plus every ancestor found for them.
    """
    ancestor_sets = (find_ancestors(member, hierarchy) for member in classes)
    return set(classes).union(*ancestor_sets)
def calculate_hierarchical_precision_recall(
    reference_codes: List[str],
    predicted_codes: List[str],
    hierarchy: Dict[str, Set[str]],
) -> Tuple[float, float]:
    """
    Calculate hierarchical precision and recall for reference and predicted codes.

    Both code sets are augmented with every ancestor reachable through ``hierarchy``
    (not only the direct parents). Precision and recall are then computed on the
    augmented sets.

    Args:
        reference_codes (List[str]): The list of reference (true) codes.
        predicted_codes (List[str]): The list of predicted codes.
        hierarchy (Dict[str, Set[str]]): The hierarchy definition, where keys are nodes
            and values are sets of parent nodes.

    Returns:
        Tuple[float, float]: ``(hP, hR)``, the hierarchical precision and recall.
            Each is 0.0 when its denominator set (predicted or reference) is empty.
    """
    # Augment both sides with all their ancestors, so that a prediction in the
    # right branch but with the wrong leaf still shares ancestors with the
    # reference and earns partial credit.
    extended_real = extend_with_ancestors(set(reference_codes), hierarchy)
    extended_predicted = extend_with_ancestors(set(predicted_codes), hierarchy)

    correct_predictions = extended_real & extended_predicted

    hP = len(correct_predictions) / len(extended_predicted) if extended_predicted else 0.0
    hR = len(correct_predictions) / len(extended_real) if extended_real else 0.0
    return hP, hR
def hierarchical_f_measure(hP: float, hR: float, beta: float = 1.0) -> float:
    """
    Calculate the hierarchical F-beta measure from hierarchical precision and recall.

    Computes ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``.

    Parameters:
        hP (float): The hierarchical precision.
        hR (float): The hierarchical recall.
        beta (float, optional): Weight of recall relative to precision. Default is 1.0 (F1).

    Returns:
        float: The hierarchical F-measure, or 0.0 when the denominator is zero.
    """
    denominator = beta**2 * hP + hR
    # Guard the actual divisor. Checking only hP + hR misses beta == 0 with
    # hR == 0 and hP > 0, which would divide by zero.
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator