| | |
| | """Evaluate a saved `.pt` checkpoint on the validation split (optionally using a frozen protocol). |
| | |
| | Outputs: |
| | - Predictions CSV (for Streamlit Evaluation dashboard): columns `sample_id`, `class_0..`, `target_class_0..` |
| | - Metrics JSON (for model zoo + dashboards), including optional optimized global threshold. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import hashlib |
| | import json |
| | import logging |
| | import sys |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | import pandas as pd |
| | import torch |
| |
|
| | PROJECT_ROOT = Path(__file__).resolve().parent.parent |
| | sys.path.insert(0, str(PROJECT_ROOT)) |
| |
|
| | from data.data_loader import load_data, split_data |
| | from data.transformer_dataset import TransformerNewsDataset |
| | from models.transformer_model import RussianNewsClassifier |
| | from utils.data_processing import create_target_encoding, process_tags |
| | from utils.text_processing import normalise_text |
| | from utils.tokenization import create_tokenizer |
| |
|
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | def _file_sha256(path: str | Path, chunk_size: int = 1024 * 1024) -> str: |
| | p = Path(path) |
| | h = hashlib.sha256() |
| | with p.open("rb") as f: |
| | while True: |
| | chunk = f.read(chunk_size) |
| | if not chunk: |
| | break |
| | h.update(chunk) |
| | return h.hexdigest() |
| |
|
| |
|
| | def _pick_device() -> torch.device: |
| | if torch.cuda.is_available(): |
| | return torch.device("cuda") |
| | if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available(): |
| | return torch.device("mps") |
| | return torch.device("cpu") |
| |
|
| |
|
| | def _metrics_from_binary(target: torch.Tensor, pred: torch.Tensor) -> dict[str, float]: |
| | """ |
| | Compute the same family of metrics used in existing `experiments/results/*.json`. |
| | - precision/recall/f1 are averaged per-sample (like `evaluation.metrics`) |
| | - exact_match is elementwise accuracy across all labels |
| | - subset_accuracy is strict set match per sample |
| | - micro_* are computed globally across all labels |
| | """ |
| | target = target.float() |
| | pred = pred.float() |
| |
|
| | |
| | tp_per = ((pred == 1) & (target == 1)).sum(dim=1).float() |
| | pred_pos_per = (pred == 1).sum(dim=1).float() |
| | true_pos_per = (target == 1).sum(dim=1).float() |
| |
|
| | precision = (tp_per / (pred_pos_per + 1e-5)).mean().item() |
| | recall = (tp_per / (true_pos_per + 1e-5)).mean().item() |
| | f1 = (2 * precision * recall) / (precision + recall + 1e-5) |
| |
|
| | exact_match = (pred == target).float().mean().item() |
| | subset_accuracy = (pred == target).all(dim=1).float().mean().item() |
| |
|
| | tp = ((pred == 1) & (target == 1)).sum().float() |
| | fp = ((pred == 1) & (target == 0)).sum().float() |
| | fn = ((pred == 0) & (target == 1)).sum().float() |
| |
|
| | micro_precision = (tp / (tp + fp + 1e-5)).item() |
| | micro_recall = (tp / (tp + fn + 1e-5)).item() |
| | micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5) |
| |
|
| | return { |
| | "precision": float(precision), |
| | "recall": float(recall), |
| | "f1": float(f1), |
| | "exact_match": float(exact_match), |
| | "subset_accuracy": float(subset_accuracy), |
| | "micro_precision": float(micro_precision), |
| | "micro_recall": float(micro_recall), |
| | "micro_f1": float(micro_f1), |
| | } |
| |
|
| |
|
@torch.inference_mode()
def _predict_probs(
    *,
    model: RussianNewsClassifier,
    dataset: TransformerNewsDataset,
    batch_size: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, list[str]]:
    """Return (probs, targets, sample_ids)."""
    model.eval()
    model.to(device)

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Stable per-row identifiers: prefer the `href` column, else the frame index.
    frame = dataset.df
    if "href" in frame.columns:
        all_ids = frame["href"].astype(str).tolist()
    else:
        all_ids = frame.index.astype(str).tolist()

    collected_probs: list[torch.Tensor] = []
    collected_targets: list[torch.Tensor] = []
    sample_ids: list[str] = []

    # shuffle=False keeps loader order aligned with `all_ids`; `cursor`
    # walks the id list in lockstep with the batches.
    cursor = 0
    for batch in loader:
        n_rows = batch["labels"].shape[0]
        sample_ids.extend(all_ids[cursor : cursor + n_rows])
        cursor += n_rows

        # Move only tensor entries to the eval device.
        on_device = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        logits = model(
            title_input_ids=on_device["title_input_ids"],
            title_attention_mask=on_device["title_attention_mask"],
            snippet_input_ids=on_device.get("snippet_input_ids"),
            snippet_attention_mask=on_device.get("snippet_attention_mask"),
        )
        # Sigmoid because this is multi-label classification; gather on CPU.
        collected_probs.append(torch.sigmoid(logits).detach().cpu())
        collected_targets.append(batch["labels"].detach().cpu())

    if collected_probs:
        return torch.cat(collected_probs, dim=0), torch.cat(collected_targets, dim=0), sample_ids
    # Empty dataset: keep the 2-D shape contract with (0, 0) tensors.
    return torch.empty((0, 0)), torch.empty((0, 0)), sample_ids
| |
|
| |
|
def _optimize_threshold(
    *,
    probs: torch.Tensor,
    target: torch.Tensor,
    metric: str,
    min_t: float = 0.01,
    max_t: float = 0.99,
    step: float = 0.01,
) -> tuple[float, dict[str, float]]:
    """Grid-search a single global threshold over [min_t, max_t] maximizing `metric`.

    Args:
        probs: Predicted probabilities, shape (samples, labels).
        target: Binary ground-truth labels, same shape as `probs`.
        metric: One of "precision", "recall", "f1".
        min_t / max_t / step: Inclusive search grid for the threshold.

    Returns:
        (best_threshold, metrics_dict_at_best_threshold).

    Raises:
        ValueError: If `metric` is not a supported name.
    """
    if probs.numel() == 0:
        # Nothing to optimize; report metrics on the empty tensors as-is.
        return 0.5, _metrics_from_binary(target, probs)

    if metric not in {"precision", "recall", "f1"}:
        raise ValueError(f"Unknown optimize metric: {metric}")

    best_t = 0.5
    best_val = -1.0
    best_metrics: dict[str, float] = {}

    t = min_t
    # +1e-9 guards against float drift excluding max_t itself.
    while t <= max_t + 1e-9:
        pred = (probs >= t).float()
        m = _metrics_from_binary(target, pred)
        score = m[metric]
        if score > best_val:
            best_val = score
            # BUG FIX: previously `float(round(t, 2))`, which silently corrupted
            # the reported threshold whenever `step` was finer than 0.01.
            # `t` is already kept clean by the round(..., 10) below.
            best_t = float(t)
            best_metrics = m
        t = round(t + step, 10)

    return best_t, best_metrics
| |
|
| |
|
def main() -> int:
    """CLI entry point: evaluate a saved checkpoint on the validation split.

    Pipeline: parse args -> load checkpoint metadata -> load/clean/split data
    (optionally restricted to a frozen protocol bundle) -> build dataset/model ->
    predict probabilities -> write predictions CSV and metrics JSON.

    Returns:
        0 on success (used as the process exit code).

    Raises:
        FileNotFoundError: If the checkpoint or a required protocol file is missing.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained model checkpoint")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to saved `.pt` checkpoint")
    parser.add_argument("--data-path", type=str, default="data/news_data/ria_news.tsv", help="Path to RIA TSV")
    parser.add_argument("--protocol-dir", type=str, default=None, help="Frozen protocol directory (splits.json + tag_to_idx.json)")
    parser.add_argument("--max-val-samples", type=int, default=None, help="Limit validation samples (ignored if protocol-dir is set)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Default global threshold for reporting `metrics`")

    parser.add_argument("--optimize-threshold", action="store_true", help="Search for best global threshold on val set")
    parser.add_argument(
        "--optimize-metric",
        type=str,
        default="f1",
        choices=["precision", "recall", "f1"],
        help="Metric to optimize when --optimize-threshold is set",
    )

    parser.add_argument("--batch-size", type=int, default=16, help="Eval batch size")
    parser.add_argument("--model-id", type=str, default=None, help="Optional model identifier (defaults to checkpoint stem)")
    parser.add_argument("--output-csv", type=str, default=None, help="Write predictions CSV to this path")
    parser.add_argument("--metrics-json", type=str, default=None, help="Write metrics JSON to this path")
    args = parser.parse_args()

    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Checkpoint carries both weights and the metadata needed to rebuild the
    # model/label space; fall backs cover older checkpoints missing keys.
    # NOTE(review): torch.load on an untrusted checkpoint unpickles arbitrary
    # code — only evaluate checkpoints from trusted sources.
    checkpoint: dict[str, Any] = torch.load(ckpt_path, map_location="cpu")
    tag_to_idx = checkpoint.get("tag_to_idx") or {}
    num_labels = int(checkpoint.get("num_labels") or len(tag_to_idx))
    model_name = checkpoint.get("model_name") or "DeepPavlov/rubert-base-cased"
    use_snippet = bool(checkpoint.get("use_snippet", False))

    model_id = args.model_id or ckpt_path.stem

    logger.info(f"Loading data from {args.data_path}...")
    df_ria, _, _ = load_data(args.data_path)

    logger.info("Processing text...")
    df_ria["title_clean"] = df_ria["title"].apply(normalise_text)
    if "snippet" in df_ria.columns:
        df_ria["snippet_clean"] = df_ria["snippet"].fillna("").apply(normalise_text)

    logger.info("Processing tags...")
    df_ria["tags"] = process_tags(df_ria["tags"])

    logger.info("Splitting data...")
    # Temporal split; these date boundaries must match the training protocol.
    df_train, df_val, df_test = split_data(
        df_ria,
        train_date_end="2018-10-01",
        val_date_start="2018-10-01",
        val_date_end="2018-12-01",
        test_date_start="2018-12-01",
    )

    protocol_meta: dict[str, Any] | None = None
    if args.protocol_dir:
        # Frozen-protocol mode: restrict each split to exactly the sample ids
        # recorded in splits.json and replace the label mapping with the
        # bundle's tag_to_idx.json so results are comparable across runs.
        protocol_path = Path(args.protocol_dir)
        splits_path = protocol_path / "splits.json"
        mapping_path = protocol_path / "tag_to_idx.json"
        if not splits_path.exists() or not mapping_path.exists():
            raise FileNotFoundError(f"protocol-dir must contain splits.json and tag_to_idx.json: {protocol_path}")

        splits = json.loads(splits_path.read_text(encoding="utf-8"))
        id_col = splits.get("id_column", "href")
        # Filter by `href` when available; otherwise fall back to index ids.
        if id_col == "href" and "href" in df_val.columns:
            df_train = df_train[df_train["href"].astype(str).isin(set(splits["train_ids"]))].copy()
            df_val = df_val[df_val["href"].astype(str).isin(set(splits["val_ids"]))].copy()
            df_test = df_test[df_test["href"].astype(str).isin(set(splits["test_ids"]))].copy()
        else:
            train_ids = set(splits["train_ids"])
            val_ids = set(splits["val_ids"])
            test_ids = set(splits["test_ids"])
            df_train = df_train[df_train.index.astype(str).isin(train_ids)].copy()
            df_val = df_val[df_val.index.astype(str).isin(val_ids)].copy()
            df_test = df_test[df_test.index.astype(str).isin(test_ids)].copy()

        # Protocol label mapping overrides the checkpoint's own tag_to_idx.
        tag_to_idx = json.loads(mapping_path.read_text(encoding="utf-8"))
        num_labels = len(tag_to_idx)
        logger.info(
            f"Loaded protocol bundle from {protocol_path} "
            f"(train={len(df_train)}, val={len(df_val)}, test={len(df_test)}, labels={num_labels})"
        )

        # Provenance record embedded in the metrics JSON, including a data
        # hash so the exact input file can be verified later.
        protocol_meta = {
            "data_path": args.data_path,
            "data_sha256": _file_sha256(args.data_path),
            "split": {
                "train_date_end": "2018-10-01",
                "val_date_start": "2018-10-01",
                "val_date_end": "2018-12-01",
                "test_date_start": "2018-12-01",
            },
            "limits": {
                "max_train_samples": len(df_train),
                "max_val_samples": len(df_val),
            },
            "label_space": {
                "min_tag_frequency": None,
                "num_labels": num_labels,
            },
        }

    else:
        # Ad-hoc mode: optionally truncate the val split for quick runs.
        if args.max_val_samples is not None:
            df_val = df_val.head(args.max_val_samples).copy()

    logger.info(f"Val samples: {len(df_val)}")

    df_val = df_val.copy()
    # Multi-hot target vectors aligned with tag_to_idx.
    df_val["target_tags"] = create_target_encoding(df_val, tag_to_idx)

    tokenizer = create_tokenizer(model_name, max_length=128)
    val_dataset = TransformerNewsDataset(
        df=df_val,
        tokenizer=tokenizer,
        max_title_len=128,
        max_snippet_len=256 if use_snippet else None,
        label_to_idx=tag_to_idx,
    )

    # Rebuild the model from checkpoint metadata; strict=True ensures the
    # state dict matches the reconstructed architecture exactly.
    model = RussianNewsClassifier(
        model_name=model_name,
        num_labels=num_labels,
        dropout=float(checkpoint.get("dropout", 0.3)),
        use_snippet=use_snippet,
        freeze_bert=bool(checkpoint.get("freeze_backbone", False)),
    )
    model.load_state_dict(checkpoint["state_dict"], strict=True)

    device = _pick_device()
    logger.info(f"Evaluating on device: {device}")

    probs, target, sample_ids = _predict_probs(model=model, dataset=val_dataset, batch_size=args.batch_size, device=device)

    # Predictions CSV: one probability column and one target column per class,
    # the layout the Streamlit Evaluation dashboard expects.
    if args.output_csv:
        out_csv = Path(args.output_csv)
        out_csv.parent.mkdir(parents=True, exist_ok=True)
        data: dict[str, Any] = {"sample_id": sample_ids}
        for j in range(probs.shape[1]):
            data[f"class_{j}"] = probs[:, j].numpy()
        for j in range(target.shape[1]):
            data[f"target_class_{j}"] = target[:, j].numpy()
        pd.DataFrame(data).to_csv(out_csv, index=False)
        logger.info(f"Wrote predictions CSV: {out_csv}")

    # Headline metrics at the user-supplied (default 0.5) global threshold.
    pred_default = (probs >= float(args.threshold)).float()
    metrics_default = _metrics_from_binary(target, pred_default)

    # Sanity diagnostics to catch degenerate runs (e.g. all-zero predictions).
    sanity = {
        "avg_true_labels_per_sample": float(target.sum(dim=1).float().mean().item()),
        "avg_pred_labels_per_sample": float(pred_default.sum(dim=1).float().mean().item()),
        "pct_samples_with_any_true_label": float((target.sum(dim=1) > 0).float().mean().item()),
        "pct_samples_with_any_pred_label": float((pred_default.sum(dim=1) > 0).float().mean().item()),
        "prob_min": float(probs.min().item()) if probs.numel() else 0.0,
        "prob_mean": float(probs.mean().item()) if probs.numel() else 0.0,
        "prob_max": float(probs.max().item()) if probs.numel() else 0.0,
    }

    # Full metrics payload consumed by the model zoo + dashboards.
    payload: dict[str, Any] = {
        "experiment_id": model_id,
        "checkpoint_path": str(args.checkpoint),
        "data_path": args.data_path,
        "protocol_dir": args.protocol_dir,
        "protocol": protocol_meta,
        "threshold": float(args.threshold),
        "max_val_samples": args.max_val_samples,
        "val_samples": int(target.shape[0]),
        "num_labels": int(target.shape[1]),
        "model_name": model_name,
        "use_snippet": bool(use_snippet),
        "metrics": metrics_default,
        "sanity": sanity,
    }

    if args.optimize_threshold:
        # Grid-search a better global threshold on the val set and record it
        # alongside (not instead of) the default-threshold metrics.
        best_t, best_metrics = _optimize_threshold(
            probs=probs,
            target=target,
            metric=args.optimize_metric,
            min_t=0.01,
            max_t=0.99,
            step=0.01,
        )
        payload["optimized_threshold"] = {
            "threshold": float(best_t),
            "metric": args.optimize_metric,
            "metric_value": float(best_metrics[args.optimize_metric]),
            **best_metrics,
        }

    if args.metrics_json:
        out_json = Path(args.metrics_json)
        out_json.parent.mkdir(parents=True, exist_ok=True)
        out_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        logger.info(f"Wrote metrics JSON: {out_json}")

    logger.info(f"Metrics @ threshold={args.threshold}: f1={metrics_default['f1']:.4f}, p={metrics_default['precision']:.4f}, r={metrics_default['recall']:.4f}")
    if args.optimize_threshold:
        opt = payload["optimized_threshold"]
        logger.info(
            f"Optimized threshold={opt['threshold']:.2f} ({opt['metric']}={opt['metric_value']:.4f}) "
            f"f1={opt['f1']:.4f}, p={opt['precision']:.4f}, r={opt['recall']:.4f}"
        )

    return 0
| |
|
| |
|
| | if __name__ == "__main__": |
| | raise SystemExit(main()) |
| |
|
| |
|
| |
|
| |
|
| |
|