# Provenance (scraped page header): author Sky-Blue-da-ba-dee, "added files", ac9ddbb
"""Preprocessing helpers for transformer training.
This module provides utilities to parse multi-label strings, ensure the
`combo` column exists, perform label-aware supersampling of a training
DataFrame, and a light-weight `load_or_prepare_data` entrypoint that loads
raw CSVs, optionally applies preprocessing, and writes processed CSVs.
"""
import logging
import os
from typing import Tuple
import numpy as np
import pandas as pd
# Module-level logger named after this module. No handler or level
# configuration is done here; log output depends on how the caller sets up logging.
logger = logging.getLogger(__name__)
def parse_label_str(s: str) -> np.ndarray:
    """Convert a multi-hot label string into a 1-D float32 numpy array.

    Accepts the space-separated numpy repr, e.g. ``'[0 0 1 0 0 0 0]'`` (which
    may wrap onto several lines for long vectors), as well as the
    comma-separated Python list repr, e.g. ``'[0, 0, 1]'``. Surrounding
    whitespace and brackets are ignored, and ``'[]'`` yields an empty array.

    Parameters
    ----------
    s : str
        Label string. Non-string values are converted with ``str()`` first.

    Returns
    -------
    np.ndarray
        1-D array of dtype float32, one entry per token.

    Raises
    ------
    ValueError
        If any token is not a valid number. (The previous ``np.fromstring``
        implementation stopped at the first bad token and returned a silently
        truncated vector instead.)
    """
    # Strip whitespace before brackets so " [0 1] " is handled, and treat
    # commas as separators so list-style reprs parse the same way.
    body = str(s).strip().strip("[]").replace(",", " ")
    return np.asarray([float(tok) for tok in body.split()], dtype=np.float32)
def ensure_combo_column(df: pd.DataFrame) -> pd.DataFrame:
    """Return a frame that is guaranteed to have a 'combo' column.

    If 'combo' already exists, ``df`` itself (not a copy) is returned.
    Otherwise a copy of ``df`` is returned with
    ``combo = str(comment_sentence) + " | " + str(class)``, and the caller's
    frame is left untouched.
    """
    if "combo" in df.columns:
        logger.info("Column 'combo' already present, reusing it.")
        return df

    logger.info("Column 'combo' not found, creating it from 'comment_sentence' and 'class'.")
    out = df.copy()
    sentence_text = out["comment_sentence"].astype(str)
    class_text = out["class"].astype(str)
    out["combo"] = sentence_text + " | " + class_text
    return out
def supersample_dataframe(
    df: pd.DataFrame,
    factor: float,
    random_state: int = 42,
) -> pd.DataFrame:
    """Offline label-aware supersampling of the training DataFrame.

    - Keeps all original rows.
    - For each label j, appends ``target_j - freq_j`` extra rows. They are
      drawn uniformly with replacement from the rows that contain label j,
      where::

          target_j = min(max_freq, int(freq_j * factor))

      freq_j is the original count for label j, and max_freq is the maximum
      frequency across labels.
    - Shuffles the resulting rows and resets the index.

    A duplicated row carries *all* of its labels, so labels that co-occur
    with a boosted label are boosted as well. The final per-label counts can
    therefore exceed ``target_j`` (and ``max_freq``). The actual final counts
    are logged.

    Parameters
    ----------
    df : pd.DataFrame
        Training data. ``df['labels']`` must hold string representations of
        multi-hot vectors (parsed with ``parse_label_str``), all of the same
        length.
    factor : float
        Per-label growth factor. Values <= 1.0 disable supersampling.
    random_state : int
        Seed for ``np.random.default_rng``, so the same inputs and seed
        produce the same output.

    Returns
    -------
    pd.DataFrame
        A new DataFrame; ``df`` itself is never modified. A plain copy of
        ``df`` is returned when ``factor <= 1.0``, when ``df`` is empty, or
        when no label occurs in any row.
    """
    if factor <= 1.0:
        logger.info(
            "Supersampling factor <= 1.0 (%.2f), returning original DataFrame.",
            factor,
        )
        return df.copy()
    if len(df) == 0:
        # np.stack below raises on an empty sequence, so bail out early.
        logger.warning("Empty DataFrame, skipping supersampling.")
        return df.copy()

    rng = np.random.default_rng(random_state)
    labels_array = np.stack(df["labels"].map(parse_label_str).values)
    if labels_array.ndim == 1:
        labels_array = labels_array[:, None]
    num_samples, num_labels = labels_array.shape

    freq = labels_array.sum(axis=0).astype(int)
    # Empty label vectors ("[]") give zero label columns. Treat that as
    # "no label occurs" instead of letting freq.max() raise on an empty array.
    max_freq = int(freq.max()) if freq.size else 0
    logger.info("Original label frequencies: %s", freq.tolist())
    logger.info("Max label frequency: %d", max_freq)
    if max_freq == 0:
        logger.warning("All label frequencies are zero, skipping supersampling.")
        return df.copy()

    target = np.minimum(max_freq, (freq * factor).astype(int))
    logger.info(
        "Target label frequencies after supersampling (capped by max_freq): %s",
        target.tolist(),
    )

    # Positional row indices: the original rows first, then per-label extras.
    new_indices = list(range(num_samples))
    for j in range(num_labels):
        current = int(freq[j])
        desired = int(target[j])
        if desired <= current:
            continue
        candidate_indices = np.flatnonzero(labels_array[:, j] == 1)
        if candidate_indices.size == 0:
            continue
        needed = desired - current
        extra = rng.choice(candidate_indices, size=needed, replace=True)
        new_indices.extend(extra.tolist())
        logger.info(
            "Label %d: current=%d, target=%d, added=%d samples.",
            j,
            current,
            desired,
            needed,
        )

    rng.shuffle(new_indices)
    df_sup = df.iloc[new_indices].reset_index(drop=True)

    # Index the already-parsed matrix instead of re-parsing every label string.
    freq_after = labels_array[new_indices].sum(axis=0).astype(int)
    logger.info("Final label frequencies after supersampling: %s", freq_after.tolist())
    logger.info("Training rows before: %d, after: %d", num_samples, len(df_sup))
    return df_sup
def load_or_prepare_data(
    lang: str,
    raw_data_dir: str,
    processed_data_dir: str,
    preprocessing_enabled: bool,
    preprocessing_factor: float,
    random_state: int = 42,
) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """Load raw CSVs for ``lang``, optionally supersample the train split, and save processed CSVs.

    - The test split is NEVER supersampled or augmented.
    - Both splits get a 'combo' column (created if missing) and a
      'labels_array' column (parsed from 'labels' if missing).
    - The train split is supersampled only if ``preprocessing_enabled`` is
      True and ``preprocessing_factor > 1.0``.
    - Processed CSVs are written before 'labels_array' is attached. They
      therefore contain 'combo' but not the parsed arrays.
    - If ``processed_data_dir`` resolves to the same directory as
      ``raw_data_dir``, the processed CSVs are NOT written and a warning is
      logged instead. They would otherwise overwrite the raw inputs, which
      have the same file names.

    Parameters
    ----------
    lang : str
        Language key (e.g., 'java', 'python', 'pharo').
    raw_data_dir : str
        Directory containing {lang}_train.csv and {lang}_test.csv.
    processed_data_dir : str
        Directory where processed CSVs will be saved (created if missing).
    preprocessing_enabled : bool
        Whether to apply supersampling on the training split.
    preprocessing_factor : float
        Supersampling factor (ignored if preprocessing_enabled=False).
    random_state : int
        RNG seed.

    Returns
    -------
    train_df : pd.DataFrame
    eval_df : pd.DataFrame
    preprocessing_used : str
        One of: 'none', 'supersampling'.

    Raises
    ------
    FileNotFoundError
        If either raw CSV does not exist.
    """
    logger.info("Loading raw CSVs for language '%s' from '%s'.", lang, raw_data_dir)
    raw_train_path = os.path.join(raw_data_dir, f"{lang}_train.csv")
    raw_eval_path = os.path.join(raw_data_dir, f"{lang}_test.csv")
    if not os.path.exists(raw_train_path):
        raise FileNotFoundError(f"Raw train CSV not found: {raw_train_path}")
    if not os.path.exists(raw_eval_path):
        raise FileNotFoundError(f"Raw test CSV not found: {raw_eval_path}")
    train_df = pd.read_csv(raw_train_path)
    eval_df = pd.read_csv(raw_eval_path)
    train_df = ensure_combo_column(train_df)
    eval_df = ensure_combo_column(eval_df)

    if preprocessing_enabled and preprocessing_factor > 1.0:
        logger.info(
            "Preprocessing enabled: applying supersampling with factor=%.2f.",
            preprocessing_factor,
        )
        train_df = supersample_dataframe(
            train_df,
            factor=preprocessing_factor,
            random_state=random_state,
        )
        preprocessing_used = "supersampling"
    else:
        logger.info(
            "Preprocessing disabled or factor <= 1.0 (%.2f). Using original training data.",
            preprocessing_factor,
        )
        preprocessing_used = "none"

    # Save processed CSVs (for inspection / reproducibility). They reuse the
    # raw file names, so writing into the raw directory would clobber the
    # inputs, and a later run would then supersample already-supersampled data.
    same_dir = os.path.normcase(os.path.realpath(processed_data_dir)) == os.path.normcase(
        os.path.realpath(raw_data_dir)
    )
    if same_dir:
        logger.warning(
            "processed_data_dir '%s' is the same as raw_data_dir; "
            "not writing processed CSVs to avoid overwriting the raw inputs.",
            processed_data_dir,
        )
    else:
        os.makedirs(processed_data_dir, exist_ok=True)
        processed_train_path = os.path.join(processed_data_dir, f"{lang}_train.csv")
        processed_eval_path = os.path.join(processed_data_dir, f"{lang}_test.csv")
        train_df.to_csv(processed_train_path, index=False)
        eval_df.to_csv(processed_eval_path, index=False)
        logger.info("Saved processed train/test CSVs to '%s'.", processed_data_dir)

    # Ensure 'labels_array' exists for both splits (added after saving, so
    # the CSVs keep only the original 'labels' strings).
    for df, split_name in ((train_df, "train"), (eval_df, "test")):
        if "labels_array" not in df.columns:
            logger.info("Parsing label strings into arrays for split '%s'.", split_name)
            df["labels_array"] = df["labels"].apply(parse_label_str)
    return train_df, eval_df, preprocessing_used