|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import glob |
|
|
import json |
|
|
import re |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
# Default evaluation directory: main() scans it for valid_score_*.json files
# and also writes the resulting plot (valid_accuracy.png) into it.
BASE_DIR = "/pfs/lichenyi/work/evaluation"
|
|
|
|
|
def collect_accuracies(base_dir: str):
    """Collect per-checkpoint accuracies from the score files in *base_dir*.

    Looks (non-recursively) for files named ``valid_score_in_<step>.json``
    and ``valid_score_ood_<step>.json`` and reads ``summary.accuracy`` from
    each of them.

    A file is skipped when its name does not match one of the two patterns
    exactly, when its JSON root or its ``summary`` entry is not an object,
    or when ``summary.accuracy`` is missing or null.

    Args:
        base_dir: Directory that contains the score files. It is taken
            literally, so glob metacharacters in the path are not expanded.

    Returns:
        A pair ``(in_acc, ood_acc)`` of dicts mapping the integer step parsed
        from the file name to the accuracy value stored in that file.

    Raises:
        OSError: If a matching file cannot be opened.
        json.JSONDecodeError: If a matching file does not contain valid JSON.
    """
    name_re = re.compile(r"valid_score_(in|ood)_(\d+)\.json")

    in_acc = {}
    ood_acc = {}
    by_split = {"in": in_acc, "ood": ood_acc}

    # Escape the directory so characters such as "[" in its name are not
    # interpreted as glob wildcards (which would silently match nothing).
    pattern = os.path.join(glob.escape(base_dir), "valid_score_*.json")

    for path in sorted(glob.glob(pattern)):
        # fullmatch rather than match: the glob also accepts names such as
        # "valid_score_in_10.json.bak.json", which must not overwrite the
        # value read from the real step-10 file.
        m = name_re.fullmatch(os.path.basename(path))
        if m is None:
            continue
        split = m.group(1)
        step = int(m.group(2))

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Tolerate unexpected shapes (list root, "summary": null, ...) the
        # same way a missing accuracy is tolerated: skip the file.
        summary = data.get("summary") if isinstance(data, dict) else None
        if not isinstance(summary, dict):
            continue

        acc = summary.get("accuracy")
        if acc is None:
            continue

        by_split[split][step] = acc

    return in_acc, ood_acc
|
|
|
|
|
|
|
|
def plot_accuracies(in_acc, ood_acc, out_path="valid_accuracy.png"):
    """Plot accuracy against step for both splits and save the figure.

    Each non-empty mapping is drawn as one line with its steps in ascending
    order; an empty (or None) mapping is simply left out of the plot.

    Args:
        in_acc: Mapping of step (int) to accuracy for the "in" split.
        ood_acc: Mapping of step (int) to accuracy for the "ood" split.
        out_path: Destination file passed to ``savefig`` (written at 300 dpi).
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        series = (
            (in_acc, {"marker": "o", "label": "in (ID)"}),
            (ood_acc, {"marker": "s", "linestyle": "--", "label": "ood (OOD)"}),
        )

        has_series = False
        for acc_by_step, style in series:
            if not acc_by_step:
                continue
            steps = sorted(acc_by_step)
            ax.plot(steps, [acc_by_step[s] for s in steps], **style)
            has_series = True

        ax.set_xlabel("checkpoint / step")
        ax.set_ylabel("accuracy")
        ax.set_title("Validation Accuracy (in vs ood)")
        ax.grid(True, linestyle=":")
        # A legend without any labelled line only produces a warning.
        if has_series:
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here even when saving fails.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(base_dir=None):
    """Collect the accuracies under a directory, print them and plot them.

    Args:
        base_dir: Directory to scan for score files; the plot is saved there
            as ``valid_accuracy.png``. Defaults to the module-level
            ``BASE_DIR`` when omitted or None.
    """
    if base_dir is None:
        base_dir = BASE_DIR

    in_acc, ood_acc = collect_accuracies(base_dir)
    print("in-domain checkpoints and accuracies:", in_acc)
    print("ood checkpoints and accuracies:", ood_acc)

    out_path = os.path.join(base_dir, "valid_accuracy.png")
    plot_accuracies(in_acc, ood_acc, out_path=out_path)
|
|
|
|
|
|
|
|
# Run the collection and plotting only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|