Spaces:
Running
Running
| """Label utilities for CodeBERT multi-label classification.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Dict, List, Sequence, Union | |
| import numpy as np | |
| import yaml | |
# Resolved relative to this source file (not the current working directory):
# the config lives in a ``config/`` folder one level above this module's dir.
CONFIG_PATH = Path(__file__).resolve().parent.parent.joinpath(
    "config", "codebert_labels.yaml"
)
def load_codebert_labels(config_path: Path = CONFIG_PATH) -> List[str]:
    """Load the ordered CodeBERT label names from the YAML config.

    Args:
        config_path: YAML file containing a top-level ``labels`` sequence.

    Returns:
        A new list of label names in file order. Callers in this module use
        this order to assign multi-hot vector positions.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        KeyError: If the document has no ``labels`` key.
        TypeError: If the file is empty (``safe_load`` returns ``None``).
    """
    # YAML files are UTF-8 by convention; without an explicit encoding,
    # ``open`` uses the platform locale encoding (e.g. cp1252 on Windows).
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return list(data["labels"])
def load_alias_map(config_path: Path = CONFIG_PATH) -> Dict[str, List[str]]:
    """Load the dataset-label -> canonical-labels alias map from the YAML config.

    Args:
        config_path: YAML file containing a top-level ``alias_map`` mapping.

    Returns:
        A dict mapping each alias to a new list of target label names. A
        scalar string target (``alias: JOIN_ERROR``) becomes a one-element
        list.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        KeyError: If the document has no ``alias_map`` key.
        TypeError: If the file is empty (``safe_load`` returns ``None``).
    """
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # ``list("JOIN_ERROR")`` would split a scalar target into characters,
    # so wrap plain strings instead of iterating them.
    return {
        alias: [targets] if isinstance(targets, str) else list(targets)
        for alias, targets in data["alias_map"].items()
    }
def label_to_multihot(
    error_labels: Union[str, Sequence[str]],
    label_list: List[str] | None = None,
    alias_map: Dict[str, List[str]] | None = None,
) -> np.ndarray:
    """
    Convert error label(s) to a multi-hot vector.

    Accepts:
    - comma-separated string: "JOIN_ERROR,AGGREGATION_ERROR"
    - list of label strings
    - single dataset label_name (resolved via alias_map)

    Each token is first looked up in ``alias_map`` and, if found, expanded to
    its target labels; otherwise it must be a name in ``label_list``. This
    applies to every token of a comma-separated string as well as to every
    item of a sequence.

    Args:
        error_labels: Label string(s) to encode.
        label_list: Ordered label names defining vector positions. ``None``
            loads them from the YAML config.
        alias_map: Alias -> target labels mapping. ``None`` loads it from the
            YAML config; pass ``{}`` to disable alias resolution.

    Returns:
        ``float32`` array of length ``len(label_list)`` with 1.0 at each
        selected label and 0.0 elsewhere.

    Raises:
        ValueError: If a token of a *string* input is neither an alias nor a
            known label, if an alias expands to a label not in
            ``label_list``, or if no label ends up selected. Unknown items of
            a non-string sequence are skipped rather than rejected.
    """
    # ``is None`` rather than ``or``: an explicitly passed empty mapping/list
    # must be honoured, not silently replaced by the config file contents.
    labels = load_codebert_labels() if label_list is None else label_list
    aliases = load_alias_map() if alias_map is None else alias_map
    index = {name: i for i, name in enumerate(labels)}

    if isinstance(error_labels, str):
        tokens = [s.strip() for s in error_labels.split(",") if s.strip()]
        strict = True
    else:
        tokens = list(error_labels)
        # NOTE(review): sequence input has always dropped unknown items
        # instead of raising; kept lenient for backward compatibility.
        strict = False

    names: List[str] = []
    for token in tokens:
        if token in aliases:  # aliases take precedence over label names
            names.extend(aliases[token])
        elif token in index:
            names.append(token)
        elif strict:
            raise ValueError(f"Unknown label '{token}'. Expected one of {labels}")

    vec = np.zeros(len(labels), dtype=np.float32)
    for name in names:
        # Alias targets come from config and are validated here.
        if name not in index:
            raise ValueError(f"Unknown label '{name}'. Expected one of {labels}")
        vec[index[name]] = 1.0
    if not vec.any():
        raise ValueError(f"No valid labels found in {error_labels}")
    return vec
def multihot_to_label_names(
    vec: np.ndarray,
    label_list: List[str] | None = None,
    threshold: float = 0.5,
) -> List[str]:
    """Return the names of labels whose score in ``vec`` is >= ``threshold``.

    Args:
        vec: Per-label scores, positionally aligned with ``label_list``. A
            plain list/tuple of numbers is also accepted. NOTE(review):
            assumes a 1-D vector; for 2-D input ``np.where(...)[0]`` yields
            row indices, not label indices.
        label_list: Ordered label names. ``None`` loads them from the YAML
            config.
        threshold: Inclusive cut-off; scores ``>= threshold`` are selected.

    Returns:
        Selected label names in ``label_list`` order.

    Raises:
        IndexError: If a selected position is beyond ``len(label_list)``.
    """
    # ``is None`` rather than ``or`` so an explicit empty list is honoured.
    labels = load_codebert_labels() if label_list is None else label_list
    # Python sequences don't support elementwise ``>=``; convert them. Other
    # inputs are compared as-is, exactly as before.
    scores = np.asarray(vec) if isinstance(vec, (list, tuple)) else vec
    indices = np.where(scores >= threshold)[0]
    return [labels[i] for i in indices]