File size: 1,949 Bytes
fe84f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import treetable as tt

# Runs whose names differ only in this key are aggregated into one row
# (unless --individual is given).
STD_KEY = "seed"
# Entry read from the last record of each JSON log.
METRIC = "best"

parser = argparse.ArgumentParser("result_table.py")
parser.add_argument(
    "-p", "--paper",
    action="store_true",
    help="show results from the paper experiment",
)
parser.add_argument(
    "-i", "--individual",
    action="store_true",
    help="no aggregation by seed",
)
args = parser.parse_args()

# Directory scanned for JSON logs; --paper switches to the paper results.
LOGS = Path("results/logs") if args.paper else Path("logs")

# Map each run configuration (model name + run-name tokens, minus the
# STD_KEY token unless --individual) to the final metric values of its runs.
all_stats = defaultdict(list)

for path in LOGS.iterdir():
    if path.suffix != ".json":
        continue
    # Outside --paper mode, only logs with a sibling ".done" file are used
    # (presumably written once a run completes).
    if not (args.paper or path.with_suffix(".done").exists()):
        continue
    # Context manager so the log file is closed immediately.
    with open(path, encoding="utf-8") as log_file:
        history = json.load(log_file)
    if not history:
        # Nothing recorded yet for this run: there is no last entry to read.
        continue
    metric = history[-1][METRIC]

    name = path.stem
    model = "Tasnet" if "tasnet" in name else "Demucs"
    # The run name is either the literal "default" or space-separated tokens.
    # Tokens are kept verbatim rather than split and re-joined, so values
    # containing "=" and tokens without "=" no longer break unpacking.
    # NOTE(review): only tokens containing the literal "--tasnet" are dropped;
    # confirm this matches how run names spell the tasnet option.
    if name == "default":
        parts = []
    else:
        parts = [p for p in name.split(" ") if p and "--tasnet" not in p]
    if not args.individual:
        # Drop the STD_KEY token so repeated runs of one config share a row.
        parts = [p for p in parts if p.split("=", 1)[0] != STD_KEY]
    # Joining a list avoids a trailing space when no tokens remain.
    name = " ".join([model, *parts])
    all_stats[name].append(metric)

# Columns of the "valid" group: mean metric, its spread, and number of runs.
metrics = [
    tt.leaf("score", ".4f"),
    tt.leaf("std", ".3f"),
    tt.leaf("count", ".2f"),
]
mytable = tt.table([tt.leaf("name"), tt.group("valid", metrics)])


def _summary_row(run_name, values):
    """Return one table row summarizing the metric *values* of *run_name*."""
    arr = np.array(values)
    n_runs = arr.shape[0]
    return {
        "name": run_name,
        "valid": {
            "score": arr.mean(),
            # Despite the column name, this is the standard error of the mean:
            # numpy's default (ddof=0) std divided by sqrt(number of runs).
            "std": arr.std() / n_runs**0.5,
            "count": n_runs,
        },
    }


# Rows ordered by ascending mean score (stable for ties).
lines = sorted(
    (_summary_row(run_name, values) for run_name, values in all_stats.items()),
    key=lambda row: row["valid"]["score"],
)
print(tt.treetable(lines, mytable, colors=['33', '0']))