# MesseMMP
# Normalize and trim comments
# 75c3625
import os
import json
import tqdm
import torch
import numpy as np
import click
from datetime import datetime
import lightning.pytorch as pl
import sklearn.metrics as skm
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo
from pytorchvideo.transforms import Normalize
from full_model.rnn_dataset import SyntaxDataset
from full_model.rnn_model import SyntaxLightningModule
from inference.metrics_visualization import visualize_final_syntax_plotly_multi
def _select_device(preferred_index=1):
    """Return the torch device used for inference.

    Prefers ``cuda:<preferred_index>`` (this script targets the second GPU,
    ``cuda:1``, by default). Falls back to ``cuda:0`` when the host has fewer
    GPUs than that, and to the CPU when CUDA is unavailable. Selecting a
    non-existent CUDA index would otherwise make ``model.to(DEVICE)`` fail.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    index = preferred_index if preferred_index < torch.cuda.device_count() else 0
    return torch.device(f"cuda:{index}")


DEVICE = _select_device()
print(f"DEVICE: {DEVICE}")
def safe_sample_std(values):
    """Return the sample standard deviation (``ddof=1``) of *values* as a float.

    With fewer than two values the sample std is undefined, so 0.0 is
    returned instead of NaN.
    """
    data = np.asarray(values, dtype=float)
    return 0.0 if data.size < 2 else float(np.std(data, ddof=1))
def compute_metrics(y_true, y_pred, thr=22.0):
    """Return ``(pearson, mean_recall)`` for true vs. predicted scores.

    Args:
        y_true: Ground-truth scores.
        y_pred: Predicted scores, same length as ``y_true``.
        thr: Binarization threshold; a value is class 1 when ``value >= thr``.

    Returns:
        pearson: Pearson correlation of the raw scores. It is 0.0 when
            undefined, i.e. fewer than two samples or either series constant.
        mean_recall: Unweighted mean of the per-class recalls for labels 0
            and 1 after thresholding (``sklearn.metrics.recall_score`` with
            ``average=None``). It is 0.0 when both series fall into a single
            class.
    """
    y_true_arr = np.array(y_true, dtype=float)
    y_pred_arr = np.array(y_pred, dtype=float)

    # np.corrcoef divides by each series' std. A constant series (e.g. all-zero
    # predictions) yields NaN plus a RuntimeWarning, and that NaN would end up
    # in the metrics JSON. Treat it like the other undefined cases.
    # NaN inputs still propagate: NaN != NaN, so the comparison below is True.
    pearson_defined = (
        len(y_true_arr) > 1
        and y_true_arr.max() != y_true_arr.min()
        and y_pred_arr.max() != y_pred_arr.min()
    )
    pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1]) if pearson_defined else 0.0

    y_true_bin = (y_true_arr >= thr).astype(int)
    y_pred_bin = (y_pred_arr >= thr).astype(int)
    unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
    if len(unique_classes) > 1:
        mean_recall = float(
            np.mean(skm.recall_score(y_true_bin, y_pred_bin, average=None, labels=[0, 1]))
        )
    else:
        mean_recall = 0.0
    return pearson, mean_recall
def _build_model_paths(model_dir, variant, n_folds=5):
    """Return ``{"left": [...], "right": [...]}`` with one weight path per fold.

    NOTE(review): file names always end in ``.pt``, even with
    ``--ckpt-weights-format``; confirm this matches how Lightning ``.ckpt``
    weights are named on disk.
    """
    return {
        artery: [
            os.path.join(
                model_dir,
                f"{prefix}BinSyntax_R3D_fold{fold:02d}_{variant}_post_best.pt",
            )
            for fold in range(n_folds)
        ]
        for artery, prefix in (("left", "Left"), ("right", "Right"))
    }


def _load_scaling_params(dataset_root, scaling_file):
    """Load per-fold ``a*x+b`` coefficients from ``dataset_root/scaling_file``.

    Returns ``{}`` (identity scaling for every fold) after printing a warning
    when the file does not exist.
    """
    scaling_path = os.path.join(dataset_root, scaling_file)
    if not os.path.exists(scaling_path):
        print(f"⚠️ Scaling file not found: {scaling_path}")
        return {}
    with open(scaling_path, "r") as f:
        params = json.load(f)
    print(f"Loaded scaling from {scaling_path}")
    return params


def _scaling_coeffs(scaling_params, fold):
    """Return ``(a, b)`` for fold index *fold*, or ``(1.0, 0.0)`` if absent.

    Entries may be ``{"a": ..., "b": ...}`` dicts or ``[a, b]`` sequences.
    """
    params = scaling_params.get(f"fold{fold}", (1.0, 0.0))
    if isinstance(params, dict):
        return params.get("a", 1.0), params.get("b", 0.0)
    return params[0], params[1]


def _fold_score(l_val, r_val, fold, scaling_params, use_scaling):
    """Return the left+right score of one fold, ``a*s+b``-scaled if requested."""
    s = l_val + r_val
    if use_scaling:
        a, b = _scaling_coeffs(scaling_params, fold)
        s = a * s + b
    return s


def _load_or_compute_predictions(abs_dataset_path, postfix, model_paths, video_size,
                                 frames_per_clip, num_workers, variant,
                                 pt_weights_format):
    """Return ``(syntax_true, left_preds, right_preds)`` for one dataset.

    Predictions are cached in ``results/<postfix>.json``. An existing cache
    file is reused as-is; otherwise both arteries are inferred and saved.
    """
    results_file = os.path.join("results", f"{postfix}.json")
    if os.path.exists(results_file):
        print(f"[{postfix}] Loading from {results_file}")
        with open(results_file, "r") as f:
            data = json.load(f)
        return data["syntax_true"], data["left_preds"], data["right_preds"]

    print(f"[{postfix}] Computing predictions...")
    left_preds, left_sids = run_artery(
        abs_dataset_path, "left", model_paths["left"],
        video_size, frames_per_clip, num_workers,
        variant=variant, pt_weights_format=pt_weights_format,
    )
    right_preds, right_sids = run_artery(
        abs_dataset_path, "right", model_paths["right"],
        video_size, frames_per_clip, num_workers,
        variant=variant, pt_weights_format=pt_weights_format,
    )
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if left_sids != right_sids:
        raise RuntimeError(
            f"[{postfix}] Left and right artery predictions cover different samples"
        )

    with open(abs_dataset_path, "r") as f:
        dataset = json.load(f)
    # Labels are paired with predictions by position. This assumes the loader
    # yields records in metadata order (run_artery uses shuffle=False).
    syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

    os.makedirs(os.path.dirname(results_file), exist_ok=True)
    save_data = {
        "syntax_true": syntax_true,
        "left_preds": left_preds,
        "right_preds": right_preds,
    }
    with open(results_file, "w") as f:
        json.dump(save_data, f)
    print(f"[{postfix}] Saved to {results_file}")
    return syntax_true, left_preds, right_preds


def _check_prediction_shapes(postfix, syntax_true, left_preds, right_preds):
    """Raise ValueError when labels and per-artery fold predictions don't line up."""
    if not len(syntax_true) == len(left_preds) == len(right_preds):
        raise ValueError(
            f"[{postfix}] Mismatched lengths: {len(syntax_true)} labels, "
            f"{len(left_preds)} left / {len(right_preds)} right predictions"
        )
    # run_artery skips missing model files, so the two arteries can end up with
    # different fold counts. Summing them pairwise would then mix folds.
    fold_counts = {len(row) for row in left_preds} | {len(row) for row in right_preds}
    if len(fold_counts) > 1:
        raise ValueError(
            f"[{postfix}] Inconsistent fold counts across left/right predictions "
            f"{sorted(fold_counts)}; check for missing model files"
        )


def _ensemble_scores(left_preds, right_preds, scaling_params, use_scaling):
    """Per sample: mean over folds of the (scaled) left+right score, clipped at 0."""
    return [
        max(0.0, float(np.mean([
            _fold_score(l_val, r_val, fold, scaling_params, use_scaling)
            for fold, (l_val, r_val) in enumerate(zip(l_list, r_list))
        ])))
        for l_list, r_list in zip(left_preds, right_preds)
    ]


def _per_fold_summary(syntax_true, left_preds, right_preds, scaling_params, use_scaling):
    """Score each fold on its own and summarise ``mean``/``std``/``values`` per metric."""
    n_folds = len(left_preds[0]) if left_preds else 0
    fold_metrics = {"Pearson": [], "Mean_Recall": []}
    for fold in range(n_folds):
        pred_fold = [
            max(0.0, float(_fold_score(l_list[fold], r_list[fold], fold,
                                       scaling_params, use_scaling)))
            for l_list, r_list in zip(left_preds, right_preds)
        ]
        fold_pearson, fold_recall = compute_metrics(syntax_true, pred_fold)
        fold_metrics["Pearson"].append(fold_pearson)
        fold_metrics["Mean_Recall"].append(fold_recall)
    return {
        name: {"mean": float(np.mean(vals)), "std": safe_sample_std(vals), "values": vals}
        for name, vals in fold_metrics.items()
    }


def _save_metrics(metrics_path, ensemble_name, ensemble_results):
    """Merge *ensemble_results* under *ensemble_name* into the JSON history file.

    An unreadable (non-JSON) history file is reported and replaced.
    """
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            print("⚠️ Metrics file corrupted. Creating new.")
    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")


@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="Dataset metadata JSON files (relative to dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Dataset names for metrics and plots.")
@click.option("-p", "--postfixes", multiple=True,
              help="Suffixes for prediction files.")
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Dataset root containing the JSON and DICOM files.",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True),
    default="full_model_weights",
    show_default=True,
    help="Directory with .pt/.ckpt full-model weights (RNN head + backbone).",
)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256),
              show_default=True, help="Video size (H, W).")
@click.option("--frames-per-clip", type=int, default=32,
              show_default=True, help="Frames per clip.")
@click.option("--num-workers", type=int, default=8,
              show_default=True, help="Number of DataLoader workers.")
@click.option("--seed", type=int, default=42,
              show_default=True, help="Random seed.")
@click.option(
    # A plain is_flag with default=True could never be switched off; the
    # on/off pair keeps `--pt-weights-format` working and makes .ckpt selectable.
    "--pt-weights-format/--ckpt-weights-format",
    default=True,
    show_default=True,
    help="Full-model weight format: .pt raw state_dict (default) or Lightning .ckpt.",
)
@click.option("--use-scaling", is_flag=True, default=False,
              show_default=True, help="Apply a*x+b scaling from JSON.")
@click.option("--scaling-file",
              help="JSON file with scaling coefficients (relative to dataset_root).")
@click.option(
    "--variant",
    type=str,
    default="lstm_mean",
    show_default=True,
    help="Head-model variant: mean, lstm_mean, lstm_last, gru_mean, gru_last, bert_mean, bert_cls, bert_cls2.",
)
@click.option("-e", "--ensemble-name",
              help="Ensemble name in metrics.json (defaults to --variant).")
@click.option("-m", "--metrics-file", default="metrics.json", show_default=True,
              help="JSON file with experiment metrics.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, model_dir, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, variant, ensemble_name, metrics_file):
    """Evaluate the 5-fold left+right artery ensemble on one or more datasets.

    For each (-d, -n, -p) triple, predictions are loaded from
    results/<postfix>.json or computed and cached there. Ensemble and
    per-fold Pearson / mean-recall metrics are stored in the metrics file under
    the ensemble name, and all datasets are plotted together.
    """
    # Validate arguments before any expensive inference runs.
    if not len(dataset_paths) == len(dataset_names) == len(postfixes):
        raise click.UsageError(
            "-d/--dataset-paths, -n/--dataset-names and -p/--postfixes must be "
            "given the same number of times."
        )
    if use_scaling and not scaling_file:
        raise click.UsageError("--use-scaling requires --scaling-file.")
    if ensemble_name is None:
        ensemble_name = variant

    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"
    model_paths = _build_model_paths(model_dir, variant)

    scaling_params = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        # A missing file only warns: results are still labelled "_scaled" but
        # every fold then uses identity scaling.
        scaling_params = _load_scaling_params(dataset_root, scaling_file)

    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "variant": variant,
        "datasets": {},
    }
    all_datasets, all_pearson, all_recalls = {}, {}, {}

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        syntax_true, left_preds_all, right_preds_all = _load_or_compute_predictions(
            abs_dataset_path, postfix, model_paths, video_size, frames_per_clip,
            num_workers, variant, pt_weights_format,
        )
        _check_prediction_shapes(postfix, syntax_true, left_preds_all, right_preds_all)

        syntax_pred = _ensemble_scores(
            left_preds_all, right_preds_all, scaling_params, use_scaling
        )
        pearson, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: Pearson={pearson:.4f}, Recall={mean_recall:.4f}")

        fold_summary = _per_fold_summary(
            syntax_true, left_preds_all, right_preds_all, scaling_params, use_scaling
        )

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_pearson[dataset_name] = pearson
        all_recalls[dataset_name] = mean_recall
        ensemble_results["datasets"][dataset_name] = {
            "Pearson": round(pearson, 4),
            "Mean_Recall": round(mean_recall, 4),
            "N_samples": len(syntax_true),
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()},
        }

    _save_metrics(metrics_file, ensemble_name, ensemble_results)

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets,
        r2_values=all_pearson,
        gt_row="ENSEMBLE",
        postfix=postfix_plotly,
        recall_values=all_recalls,
    )
def _load_fold_models(model_paths, variant, pt_weights_format):
    """Build one SyntaxLightningModule per existing weight file, moved to DEVICE in eval mode.

    Weight files that do not exist are reported and skipped.
    """
    loaded = []
    for weight_path in model_paths:
        if os.path.exists(weight_path):
            net = SyntaxLightningModule(
                num_classes=2,
                lr=1e-5,
                variant=variant,
                weight_decay=0.001,
                max_epochs=1,
                weight_path=None,
                pl_weight_path=weight_path,
                pt_weights_format=pt_weights_format,
            )
            net.to(DEVICE)
            net.eval()
            loaded.append(net)
        else:
            print(f"⚠️ Model not found: {weight_path}")
    return loaded


def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, variant: str, pt_weights_format: bool):
    """Score every sample of one artery with each available fold model.

    Args:
        dataset_path: Metadata JSON; its directory is used as the dataset root.
        artery: Artery name forwarded to ``SyntaxDataset``.
        model_paths: Candidate per-fold weight files; missing ones are skipped.
        video_size: ``(H, W)`` resize target.
        frames_per_clip: Clip length forwarded to ``SyntaxDataset``.
        num_workers: DataLoader worker count.
        variant: Head-model variant forwarded to ``SyntaxLightningModule``.
        pt_weights_format: Weight-format flag forwarded to ``SyntaxLightningModule``.

    Returns:
        ``(preds_all, sids)``. ``preds_all[i]`` holds one float per loaded model
        for the i-th loader sample, and ``sids[i]`` is that sample's id.

    Raises:
        RuntimeError: If none of ``model_paths`` exists.

    NOTE(review): because missing weight files are skipped, position ``k`` in
    each prediction row is the k-th *loaded* model, not necessarily fold ``k``.
    Confirm callers that pair or scale predictions by fold index tolerate this.
    """
    transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        # ImageNet channel statistics.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=transform,
    )
    # batch_size=1 with shuffle=False keeps predictions in dataset order.
    loader = DataLoader(
        dataset, batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=True,
    )
    print(f"{artery} artery: {len(loader)} samples")

    models = _load_fold_models(model_paths, variant, pt_weights_format)
    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        for clip, [_y], [_t], [sid] in tqdm.tqdm(loader, desc=f"{artery} infer"):
            if clip.ndim == 1:
                # A 1-D batch carries no video; every model is scored as 0.0.
                scores = [0.0] * len(models)
            else:
                clip = clip.to(DEVICE)
                # Output columns from index 1 onward are exponentiated and shifted
                # by -1. NOTE(review): presumably a log1p-encoded regression
                # target; confirm against the training code.
                scores = [
                    float(torch.exp(net(clip)[:, 1:].squeeze(-1)).cpu()) - 1.0
                    for net in models
                ]
            preds_all.append(scores)
            sids.append(sid)
    return preds_all, sids
if __name__ == "__main__":
    # Run as a script: click parses sys.argv and dispatches to the command.
    main()