maldv's picture
Upload folder using huggingface_hub
b59223f verified
raw
history blame
No virus
976 Bytes
# ztrain/stats.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
import os
import torch
from typing import Optional
def gen_stats(delta: torch.Tensor, base: Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Summarize a delta tensor, optionally relative to a base tensor.

    Args:
        delta: Tensor to summarize. When ``base`` is given it is added to
            ``base`` to rebuild the full tensor.
        base: Tensor that ``delta`` is applied to, or ``None`` to treat
            ``delta`` itself as the full tensor.

    Returns:
        ``(norm, cosine, min, max)`` as Python floats:

        * ``norm`` -- ``Tensor.norm()`` (default Frobenius / L2 over all
          elements) of ``base + delta``, or of ``delta`` when ``base`` is None.
        * ``cosine`` -- cosine similarity between ``base + delta`` and
          ``base`` computed along dim 0, then averaged; ``0.0`` when ``base``
          is None.
        * ``min`` / ``max`` -- smallest and largest element of ``delta``
          (not of the rebuilt tensor).
    """
    if base is None:
        rebuilt = delta
        # Float (not int 0) so the result matches the annotated return type.
        cosine = 0.0
    else:
        rebuilt = base + delta
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()

    norm = rebuilt.norm().item()
    # Extremes are taken from the delta itself, not from the rebuilt tensor.
    delta_min = delta.min().item()
    delta_max = delta.max().item()
    return norm, cosine, delta_min, delta_max
def get_report(m0: torch.Tensor, stack: torch.Tensor, model_list: list[str]):
    """Print ``gen_stats`` summaries for a base tensor and each stacked entry.

    The first printed line reports ``m0`` on its own: norm, min, max.
    Then, for each entry yielded by iterating ``stack``, one line is printed.
    That line is labelled with the basename of the same-index path in
    ``model_list``. It reports ``gen_stats(entry, m0)``, i.e. the norm of
    ``m0 + entry``, its mean cosine similarity to ``m0``, and the min/max of
    ``entry``.

    ``model_list`` is indexed by position, so it must have at least as many
    items as ``stack`` yields (otherwise ``IndexError``). Any extra items
    are ignored.
    """
    base_norm, _, base_lo, base_hi = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_lo} {base_hi}")

    for idx, entry in enumerate(stack):
        label = os.path.basename(model_list[idx])
        entry_norm, entry_cos, entry_lo, entry_hi = gen_stats(entry, m0)
        print(f"{label} {entry_norm} {entry_cos} {entry_lo} {entry_hi}")