# sql-error-classifier / src/codebert_labels.py
# Author: nishu08 — commit 8a3099e "Deploy CodeBERT inference Space" (verified)
"""Label utilities for CodeBERT multi-label classification."""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Sequence, Union
import numpy as np
import yaml
# Default label-config location: config/codebert_labels.yaml inside the
# parent of the directory that contains this (symlink-resolved) file.
CONFIG_PATH = Path(__file__).resolve().parent.parent.joinpath(
    "config", "codebert_labels.yaml"
)
def load_codebert_labels(config_path: Path = CONFIG_PATH) -> List[str]:
    """Load the ordered CodeBERT label names from the YAML config.

    Args:
        config_path: Path to the YAML config file. Defaults to ``CONFIG_PATH``.

    Returns:
        The ``labels`` entry of the config as a list, in file order.

    Raises:
        KeyError: If the config mapping has no ``labels`` key.
    """
    # Explicit UTF-8 so the result does not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    raw = data["labels"]
    # A scalar YAML value must become a one-item list; list("X_ERROR") would
    # split it into single characters.
    return [raw] if isinstance(raw, str) else list(raw)
def load_alias_map(config_path: Path = CONFIG_PATH) -> Dict[str, List[str]]:
    """Load the alias map (alias name -> target label names) from the YAML config.

    Args:
        config_path: Path to the YAML config file. Defaults to ``CONFIG_PATH``.

    Returns:
        Dict mapping each key of the config's ``alias_map`` section to a list
        of target label names.

    Raises:
        KeyError: If the config mapping has no ``alias_map`` key.
    """
    # Explicit UTF-8 so the result does not depend on the locale encoding.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    aliases: Dict[str, List[str]] = {}
    for alias, targets in data["alias_map"].items():
        # ``alias: JOIN_ERROR`` (a scalar) must become ["JOIN_ERROR"];
        # list("JOIN_ERROR") would yield individual characters.
        aliases[alias] = [targets] if isinstance(targets, str) else list(targets)
    return aliases
def label_to_multihot(
    error_labels: Union[str, Sequence[str]],
    label_list: List[str] | None = None,
    alias_map: Dict[str, List[str]] | None = None,
) -> np.ndarray:
    """
    Convert error label(s) to a float32 multi-hot vector.

    Accepts:
    - comma-separated string: "JOIN_ERROR,AGGREGATION_ERROR"
    - list of label strings
    - single dataset label_name (resolved via alias_map)

    Every item, in either input form, is resolved the same way. A name found
    in ``alias_map`` expands to its target labels. Otherwise, a name found in
    ``label_list`` is used directly. Items that are neither raise for string
    input but are skipped for sequence input.

    Args:
        error_labels: Comma-separated string or sequence of names.
        label_list: Label names in vector order. ``None`` loads them from the
            config file; an explicitly passed list is used as-is.
        alias_map: Alias -> target labels. ``None`` loads it from the config
            file; an explicit ``{}`` disables alias expansion.

    Returns:
        ``np.float32`` array of length ``len(label_list)`` with 1.0 at each
        resolved label and 0.0 elsewhere.

    Raises:
        ValueError: If an item of a string input is unknown, if an alias
            targets a name not in ``label_list``, or if nothing resolves.
    """
    # ``is None`` instead of ``or``: an explicit empty mapping/list must not
    # silently fall back to the config file.
    labels = load_codebert_labels() if label_list is None else label_list
    aliases = load_alias_map() if alias_map is None else alias_map
    index = {name: i for i, name in enumerate(labels)}
    vec = np.zeros(len(labels), dtype=np.float32)

    if isinstance(error_labels, str):
        items = [s.strip() for s in error_labels.split(",") if s.strip()]
        strict = True
    else:
        items = list(error_labels)
        # For sequence input, unknown names are skipped rather than rejected.
        strict = False

    resolved: List[str] = []
    for item in items:
        if item in aliases:
            resolved.extend(aliases[item])
        elif item in index:
            resolved.append(item)
        elif strict:
            raise ValueError(f"Unknown label '{item}'. Expected one of {labels}")

    for name in resolved:
        # Alias targets come from config data: each must name a real label.
        if name not in index:
            raise ValueError(f"Unknown label '{name}'. Expected one of {labels}")
        vec[index[name]] = 1.0

    if not vec.any():
        raise ValueError(f"No valid labels found in {error_labels}")
    return vec
def multihot_to_label_names(
    vec: np.ndarray,
    label_list: List[str] | None = None,
    threshold: float = 0.5,
) -> List[str]:
    """Decode a multi-hot / score vector back into label names.

    Args:
        vec: One score per label, aligned with ``label_list``. Any 1-D
            array-like (NumPy array, list, tuple) is accepted.
        label_list: Label names in vector order. ``None`` loads them from the
            config file; an explicitly passed list is used as-is.
        threshold: A label is selected when its score is ``>= threshold``.

    Returns:
        Names of the selected labels, in vector order.

    Raises:
        ValueError: If ``vec`` is not one-dimensional.
    """
    labels = load_codebert_labels() if label_list is None else label_list
    # asarray lets plain Python sequences be compared element-wise.
    scores = np.asarray(vec)
    if scores.ndim != 1:
        # On a 2-D batch, np.where(...)[0] yields row indices, which would
        # silently decode to the wrong label names.
        raise ValueError(f"Expected a 1-D vector, got shape {scores.shape}")
    return [labels[i] for i in np.flatnonzero(scores >= threshold)]