"""Label utilities for CodeBERT multi-label classification.""" from __future__ import annotations from pathlib import Path from typing import Dict, List, Sequence, Union import numpy as np import yaml CONFIG_PATH = ( Path(__file__).resolve().parent.parent / "config" / "codebert_labels.yaml" ) def load_codebert_labels(config_path: Path = CONFIG_PATH) -> List[str]: with open(config_path) as f: data = yaml.safe_load(f) return list(data["labels"]) def load_alias_map(config_path: Path = CONFIG_PATH) -> Dict[str, List[str]]: with open(config_path) as f: data = yaml.safe_load(f) return {k: list(v) for k, v in data["alias_map"].items()} def label_to_multihot( error_labels: Union[str, Sequence[str]], label_list: List[str] | None = None, alias_map: Dict[str, List[str]] | None = None, ) -> np.ndarray: """ Convert error label(s) to multi-hot vector. Accepts: - comma-separated string: "JOIN_ERROR,AGGREGATION_ERROR" - list of label strings - single dataset label_name (resolved via alias_map) """ labels = label_list or load_codebert_labels() aliases = alias_map or load_alias_map() index = {name: i for i, name in enumerate(labels)} vec = np.zeros(len(labels), dtype=np.float32) if isinstance(error_labels, str): raw = [s.strip() for s in error_labels.split(",") if s.strip()] if len(raw) == 1 and raw[0] in aliases: raw = aliases[raw[0]] elif len(raw) == 1 and raw[0] in index: raw = [raw[0]] elif len(raw) == 1 and raw[0] not in index: mapped = aliases.get(raw[0], []) raw = mapped else: raw = list(error_labels) expanded: List[str] = [] for item in raw: if item in aliases: expanded.extend(aliases[item]) elif item in index: expanded.append(item) raw = expanded for name in raw: if name not in index: raise ValueError(f"Unknown label '{name}'. Expected one of {labels}") vec[index[name]] = 1.0 if vec.sum() == 0: raise ValueError(f"No valid labels found in {error_labels}") return vec def multihot_to_label_names( vec: np.ndarray, label_list: List[str] | None = None, threshold: float = 0.5, ) -> List[str]: labels = label_list or load_codebert_labels() indices = np.where(vec >= threshold)[0] return [labels[i] for i in indices]