"""This module provides functions for calculating hierarchical variants of precicion, recall and F1.""" from typing import List, Dict, Tuple, Set def find_ancestors(node: str, hierarchy: Dict[str, Set[str]]) -> Set[str]: """ Find the ancestors of a given node in a hierarchy. Args: node (str): The node for which to find ancestors. hierarchy (Dict[str, Set[str]]): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents. Returns: Set[str]: A set of ancestors of the given node. """ ancestors = set() nodes_to_visit = [node] while nodes_to_visit: current_node = nodes_to_visit.pop() if current_node in hierarchy: parents = hierarchy[current_node] ancestors.update(parents) nodes_to_visit.extend(parents) return ancestors def extend_with_ancestors(classes: set, hierarchy: dict) -> set: """ Extend the given set of classes with their ancestors from the hierarchy. Args: classes (set): The set of classes to extend. hierarchy (dict): The hierarchy of classes. Returns: set: The extended set of classes including their ancestors. """ extended_classes = set(classes) for cls in classes: ancestors = find_ancestors(cls, hierarchy) extended_classes.update(ancestors) return extended_classes def calculate_hierarchical_precision_recall( reference_codes: List[str], predicted_codes: List[str], hierarchy: Dict[str, Dict[str, float]], ) -> Tuple[float, float]: """ Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition. Args: reference_codes (List[str]): The list of reference codes. predicted_codes (List[str]): The list of predicted codes. hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries of parent nodes with distances. Returns: Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values. """ extended_real = {} extended_predicted = {} # Extend the sets of reference codes with their ancestors for code in reference_codes: extended_real[code] = 1.0 # Full weight for exact match for ancestor, ancestor_weight in hierarchy.get(code, {}).items(): extended_real[ancestor] = max( extended_real.get(ancestor, 0), ancestor_weight ) # Extend the sets of predicted codes with their ancestors for code in predicted_codes: extended_predicted[code] = 1.0 for ancestor, ancestor_weight in hierarchy.get(code, {}).items(): extended_predicted[ancestor] = max( extended_predicted.get(ancestor, 0), ancestor_weight ) # Calculate weighted correct predictions for precision correct_weights_precision = 0 for code, weight in extended_predicted.items(): if code in extended_real: correct_weights_precision += min(weight, extended_real[code]) # Calculate weighted correct predictions for recall correct_weights_recall = 0 for code, weight in extended_real.items(): if code in extended_predicted: correct_weights_recall += min(weight, extended_predicted[code]) total_predicted_weights = sum(extended_predicted.values()) total_real_weights = sum(extended_real.values()) # Calculate hierarchical precision and recall using weighted sums hP = ( correct_weights_precision / total_predicted_weights if total_predicted_weights else 0 ) hR = correct_weights_recall / total_real_weights if total_real_weights else 0 return hP, hR def hierarchical_f_measure(hP, hR, beta=1.0): """ Calculate the hierarchical F-measure. Parameters: hP (float): The hierarchical precision. hR (float): The hierarchical recall. beta (float, optional): The beta value for F-measure calculation. Default is 1.0. Returns: float: The hierarchical F-measure. """ if hP + hR == 0: return 0 return (beta**2 + 1) * hP * hR / (beta**2 * hP + hR) # Example list usage: # reference_codes = ["1111", "1112", "1113", "1114"] # predicted_codes = ["1111", "1113", "1120", "1211"] # hierarchy_dict = {'1111': {'111', '1', '11'}, '1112': {'111', '1', '11'}, '1113': {'111', '1', '11'}, '1114': {'111', '1', '11'} ...} # result = calculate_hierarchical_precision_recall(real_codes, predicted_codes, hierarchy_dict) # print(result)