# ztrain/stats.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import os
from typing import Optional

import torch


def gen_stats(delta: torch.Tensor, base: Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Compute summary statistics for ``delta``, optionally applied on top of ``base``.

    Args:
        delta: Tensor to summarise. When ``base`` is given, the "rebuilt"
            tensor is ``base + delta``.
        base: Optional reference tensor. When ``None``, ``delta`` itself is
            used as the rebuilt tensor.

    Returns:
        ``(norm, cosine, min_value, max_value)``:
        - ``norm``: ``Tensor.norm()`` (default norm) of the rebuilt tensor.
        - ``cosine``: cosine similarity between the rebuilt tensor and
          ``base`` computed along dim 0, then averaged; ``0.0`` when
          ``base`` is ``None``.
        - ``min_value`` / ``max_value``: extremes of ``delta`` itself (not of
          the rebuilt tensor).
    """
    rebuilt = delta if base is None else base + delta
    norm = rebuilt.norm().item()

    if base is None:
        # Nothing to compare against; use 0.0 so the result is always a float,
        # matching the annotated return type.
        cosine = 0.0
    else:
        # NOTE(review): similarity is taken along dim 0 and then averaged over
        # whatever dimensions remain; for 1-D tensors this is the single
        # whole-vector cosine. Confirm dim=0 is the intended axis for 2-D weights.
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()

    # Extremes describe the delta, not the rebuilt tensor.
    min_value = delta.min().item()
    max_value = delta.max().item()
    return norm, cosine, min_value, max_value


def get_report(m0: torch.Tensor, stack: torch.Tensor, model_list: list[str]) -> None:
    """Print statistics for a base tensor and for each delta in ``stack``.

    The first line reports ``m0`` on its own as
    ``Base Model <norm> <min> <max>``. Then every slice of ``stack`` along its
    first dimension is treated as a delta applied to ``m0`` (via
    ``gen_stats``), and the result is printed as
    ``<basename> <norm> <cosine> <min> <max>``.

    Args:
        m0: Base tensor.
        stack: Tensor whose slices along dim 0 are deltas relative to ``m0``
            (presumably each slice has ``m0``'s shape — TODO confirm).
        model_list: Paths labelling the slices; entry ``i`` labels
            ``stack[i]`` and only its basename is printed. Must contain at
            least as many entries as ``stack`` has slices.
    """
    norm, _, min_value, max_value = gen_stats(m0, None)
    print(f"Base Model {norm} {min_value} {max_value}")
    for i, s in enumerate(stack):
        # Index (rather than zip) so a too-short model_list raises an
        # IndexError instead of silently dropping rows.
        model_name = os.path.basename(model_list[i])
        norm, cosine, min_value, max_value = gen_stats(s, m0)
        print(f"{model_name} {norm} {cosine} {min_value} {max_value}")