| """ |
| Evaluate NisabaRelief on the validation set, optionally sweeping over step counts. |
| |
| Usage: |
| python evaluation.py # full dataset, num_steps=2 |
| python evaluation.py --sweep # subset, steps=[1,2,4,8] |
| """ |
|
|
| import argparse |
| import time |
| from datetime import timedelta |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image |
| from rich.console import Console, Group |
| from rich.live import Live |
| from rich.progress import ( |
| BarColumn, |
| MofNCompleteColumn, |
| Progress, |
| TextColumn, |
| TimeElapsedColumn, |
| ) |
| from rich.table import Table |
|
|
| from nisaba_relief import NisabaRelief |
| from util.metrics import compute_metrics, METRIC_NAMES, LABELS |
| from util.load_val_dataset import load_val_dataset |
|
|
|
|
# Step counts compared side by side in --sweep mode.
SWEEP_STEPS = [1, 2, 4, 8]
# Step count used for a normal (non-sweep) run over the full validation set.
DEFAULT_STEPS = 2
# In --sweep mode, every SWEEP_STRIDE-th validation row is evaluated...
SWEEP_STRIDE = 4
# ...up to at most SWEEP_MAX rows.
SWEEP_MAX = 175
# Cache directory for predicted PNGs: <parent of this script's directory>/data/evals.
EVALS_DIR = Path(__file__).parent.parent.joinpath("data", "evals")
|
|
|
|
| def _eta(n_done: int, n_total: int, elapsed: float) -> str: |
| if n_done >= n_total > 0: |
| return "Done" |
| if n_done > 0: |
| return str(timedelta(seconds=int(elapsed / n_done * (n_total - n_done)))) |
| return "?" |
|
|
|
|
def build_table(
    results: dict,
    n_done: int = 0,
    n_total: int = 0,
    elapsed: float = 0.0,
) -> Table:
    """Render the per-step metric summaries as a rich ``Table``.

    Args:
        results: mapping of step count -> {metric name -> list of per-image
            values}; columns appear in the dict's iteration order.
        n_done, n_total, elapsed: progress figures forwarded to ``_eta`` to
            build the "ETA" shown in the table title.

    Returns:
        A table with one row per entry of ``METRIC_NAMES`` (labelled via
        ``LABELS``) and one column per step count. Each cell is
        "mean ± std" (NumPy's default population std), with two decimals and
        a " dB" suffix for psnr / psnr_hvsm / sre, four decimals otherwise,
        or "—" when no values have been collected yet.
    """
    db_metrics = ("psnr", "psnr_hvsm", "sre")

    def summarize(metric, samples) -> str:
        values = np.array(samples)
        if len(values) == 0:
            return "—"
        if metric in db_metrics:
            return f"{values.mean():.2f} ± {values.std():.2f} dB"
        return f"{values.mean():.4f} ± {values.std():.4f}"

    step_columns = list(results)
    table = Table(title=f"Results — ETA: {_eta(n_done, n_total, elapsed)}")
    table.add_column("Metric", style="bold")
    for step in step_columns:
        table.add_column(f"Steps={step}", justify="right")

    for metric in METRIC_NAMES:
        table.add_row(
            LABELS[metric],
            *(summarize(metric, results[step][metric]) for step in step_columns),
        )
    return table
|
|
|
|
def load_grayscale(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to single-channel "L" mode and return it as a new NumPy array."""
    gray = img.convert("L")
    return np.array(gray)
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate NisabaRelief model")
    parser.add_argument(
        "--weights-dir",
        default=".",
        metavar="PATH",
        help="path to weights directory (default: .)",
    )
    parser.add_argument(
        "--sweep",
        action="store_true",
        # Derived from SWEEP_STEPS so the help text cannot drift from the code.
        help=f"sweep over steps=[{','.join(map(str, SWEEP_STEPS))}] on a dataset subset",
    )
    return parser.parse_args()


def _predict_cached(
    model: NisabaRelief, photo, save_path: Path, num_steps: int
) -> np.ndarray:
    """Return the grayscale prediction for ``photo``, reusing ``save_path`` if present.

    On a cache miss the model is run with ``num_steps`` and its output is
    written atomically (temp file, then rename). An interrupted run therefore
    cannot leave a truncated PNG behind that a later run would treat as a
    valid cache entry.
    """
    if save_path.exists():
        # Context manager guarantees the file handle is released even if
        # conversion fails.
        with Image.open(save_path) as cached:
            return load_grayscale(cached)

    model.num_steps = num_steps
    pred_img = model.process(photo, show_pbar=False)
    try:
        tmp_path = save_path.with_name(save_path.name + ".tmp")
        # Explicit format: the ".tmp" suffix would defeat PIL's extension sniffing.
        pred_img.save(tmp_path, format="PNG")
        tmp_path.replace(save_path)
        return load_grayscale(pred_img)
    finally:
        pred_img.close()


def _match_shape(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Resize ``pred`` (Lanczos) to the height/width of ``gt`` when their shapes differ."""
    if pred.shape == gt.shape:
        return pred
    height, width = gt.shape[0], gt.shape[1]
    resized = Image.fromarray(pred).resize((width, height), Image.LANCZOS)
    return np.array(resized)


def main():
    """Evaluate NisabaRelief and print a per-step table of metric mean ± std.

    By default every validation row is evaluated at ``DEFAULT_STEPS``. With
    ``--sweep``, every ``SWEEP_STRIDE``-th row (at most ``SWEEP_MAX`` rows) is
    evaluated at each step count in ``SWEEP_STEPS``. Predictions are cached
    as PNGs under ``EVALS_DIR`` and reused on later runs. Progress and the
    running results table are shown in a transient rich ``Live`` display. The
    final table is printed to the console.
    """
    args = _parse_args()

    rows = load_val_dataset()
    if args.sweep:
        rows = rows.select(
            range(0, min(len(rows), SWEEP_MAX * SWEEP_STRIDE), SWEEP_STRIDE)
        )
        steps_to_run = SWEEP_STEPS
    else:
        steps_to_run = [DEFAULT_STEPS]
    results = {s: {m: [] for m in METRIC_NAMES} for s in steps_to_run}
    n_rows = len(rows)

    model = NisabaRelief(seed=42, batch_size=4, weights_dir=Path(args.weights_dir))

    # All cached predictions live directly in EVALS_DIR; create it once up front.
    EVALS_DIR.mkdir(parents=True, exist_ok=True)

    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("[cyan]{task.fields[hs_number]}"),
    )
    task_desc = "Step Sweep" if args.sweep else "Evaluating"
    task = progress.add_task(task_desc, total=n_rows, hs_number="")

    start_time = time.monotonic()
    with Live(
        Group(progress, build_table(results)),
        refresh_per_second=4,
        transient=True,
    ) as live:
        for idx, row in enumerate(rows):
            progress.update(task, hs_number=row["hs_number"])
            gt = load_grayscale(row["msii"])
            # NOTE(review): the cache key encodes only hs_number, variation and
            # step count, not weights_dir or seed. Predictions from a different
            # checkpoint are silently reused; clear EVALS_DIR when switching.
            stem = f"{row['hs_number']}_photo_fullview_{int(row['variation']):02d}"

            for num_steps in steps_to_run:
                save_path = EVALS_DIR / f"{stem}-step{num_steps}.png"
                pred = _predict_cached(model, row["photo"], save_path, num_steps)
                pred = _match_shape(pred, gt)

                for name, val in compute_metrics(pred, gt).items():
                    results[num_steps][name].append(val)

                # Refresh after every step so the table shows partial results.
                elapsed = time.monotonic() - start_time
                live.update(
                    Group(progress, build_table(results, idx + 1, n_rows, elapsed))
                )

            progress.advance(task)

    final_elapsed = time.monotonic() - start_time
    Console().print(build_table(results, n_rows, n_rows, final_elapsed))
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|