File size: 2,370 Bytes
03f6091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import json
import argparse
import json
from utils import *
from dataset import *
from metrics import *
from compute_correlations import compute_flickr
from compute_pascal50s import compute_pascal50S
from compute_foil import compute_foil

def _to_float(value):
    """Convert a scalar metric value to a built-in ``float``.

    Values exposing ``.item()`` (NumPy scalars / 0-d arrays, PyTorch tensors)
    are unwrapped that way. This also covers ``np.float32``, which is not a
    ``float`` subclass and has no ``.numpy()``, and it covers CUDA tensors and
    tensors that require grad, where ``.numpy()`` raises. Objects that only
    offer ``.numpy()`` still go through that path, and plain Python numbers
    go straight through ``float()``.
    """
    if isinstance(value, float):
        return float(value)
    item = getattr(value, "item", None)
    if callable(item):
        return float(item())
    to_numpy = getattr(value, "numpy", None)
    if callable(to_numpy):
        return float(to_numpy())
    return float(value)


def collect_coef(memory, dataset_name, method, coef_tensor):
    """Record one method's coefficients for a dataset and log them.

    Args:
        memory: results dict, mutated in place so that
            ``memory[dataset_name][method] == {name: rounded_value, ...}``.
            The ``dataset_name`` entry is created if it is missing.
        dataset_name: key of the dataset entry inside ``memory``.
        method: key under which this method's coefficients are stored.
        coef_tensor: mapping of coefficient name -> scalar value (Python
            number, NumPy scalar / 0-d array, or single-element tensor).

    Every value is converted to a built-in ``float`` and rounded to 4
    decimals, so ``memory`` stays serialisable with ``json.dump``.
    """
    memory.setdefault(dataset_name, {})
    coef = {k: round(_to_float(v), 4) for k, v in coef_tensor.items()}
    memory[dataset_name][method] = coef
    gprint(f"[{dataset_name}]", method, coef)

def compute_coef(args, memory, tops):
    """Evaluate correlation coefficients on the Polaris test split.

    The split's CSV is always loaded and announced through ``yprint``. The
    Polos coefficients are computed only when ``args.polos`` is set.

    Args:
        args: parsed command-line namespace. ``args.polos`` is read here, and
            the whole namespace is forwarded to ``compute_polos_coef``.
        memory: results dict, updated through ``collect_coef``.
        tops: passed through unchanged.

    Returns:
        The ``(memory, tops)`` pair, so benchmark stages can be chained.
    """
    split = "test"
    csv_path = f"data_en/polaris/polaris_{split}.csv"
    yprint(f"Processing {split} ... (path: {csv_path})")
    dataset = get_dataset(csv_path)

    if not args.polos:
        return memory, tops

    # kendall_type='c' is forwarded as-is; presumably this selects Kendall's
    # tau-c inside compute_polos_coef. Confirm against metrics.
    scores = compute_polos_coef(args, dataset, split, kendall_type='c')
    collect_coef(memory, split, "Polos", scores)
    return memory, tops


def main(args):
    """Run every benchmark enabled on the command line, then save and report.

    Stages run in a fixed order: flickr, coef, pascal, foil. Each enabled
    stage receives the accumulated ``(memory, tops)`` pair and returns the
    updated pair. ``memory`` is then written to ``zeroshot_test_results.json``
    and printed. Each ``tops`` entry is then reported in one of two forms:

    * a dict (coefficient benchmarks): each ``kind -> coef`` is printed as
      ``"kind: coef[0] (coef[1])"``;
    * anything else (accuracy benchmarks): unpacked as ``(method, acc)``.
    """
    memory = {}
    tops = {}

    # Flags are looked up lazily, one stage at a time, in this exact order.
    stage_runners = (
        ("flickr", lambda mem, top: compute_flickr(args, args.model, mem, top)),
        ("coef", lambda mem, top: compute_coef(args, mem, top)),
        ("pascal", lambda mem, top: compute_pascal50S(args, mem, top)),
        ("foil", lambda mem, top: compute_foil(args, mem, top)),
    )
    for flag_name, run_stage in stage_runners:
        if getattr(args, flag_name):
            memory, tops = run_stage(memory, tops)

    with open("zeroshot_test_results.json", "w") as out_file:
        json.dump(memory, out_file, indent=4)

    yprint("[RESULTS]")
    gprint(json.dumps(memory, indent=4))

    rprint("[TOP]")
    for bench_name, entry in tops.items():
        rprint(f"> {bench_name}")
        if not isinstance(entry, dict):
            # Accuracy-style entry: a (method, accuracy) pair.
            best_method, best_acc = entry
            rprint(f"{best_method} ({best_acc})")
            continue
        # Coefficient-style entry: kind -> indexable pair.
        for kind, coef in entry.items():
            rprint(f"{kind}: {coef[0]} ({coef[1]})")
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Model selection / configuration options.
    parser.add_argument('--model', default=None)
    parser.add_argument('--hparams', default=None)
    parser.add_argument('--polos', action='store_true')

    # Benchmark switches, registered in the same order so --help is unchanged.
    for benchmark_flag in ('--coef', '--flickr', '--pascal', '--foil'):
        parser.add_argument(benchmark_flag, action='store_true')

    args = parser.parse_args()
    main(args)