File size: 2,075 Bytes
d380b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python3


import os
from argparse import ArgumentParser


def ssim_fid100_f1(metrics, fid_scale=100):
    """Combine mean SSIM and mean FID into one F1-style (harmonic-mean) score.

    FID is first mapped to a "higher is better" value:
    ``max(0, fid_scale - fid) / fid_scale``, so any FID at or above
    ``fid_scale`` contributes 0.

    Args:
        metrics: Object read as ``metrics.loc['total', 'ssim']['mean']`` and
            ``metrics.loc['total', 'fid']['mean']`` (presumably a pandas
            DataFrame of evaluation results -- confirm against the caller).
        fid_scale: FID value treated as the worst case when rescaling.

    Returns:
        ``2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)``.
    """
    total_ssim = metrics.loc['total', 'ssim']['mean']
    total_fid = metrics.loc['total', 'fid']['mean']

    # Clip at zero so a FID worse than fid_scale cannot yield a negative score.
    headroom = fid_scale - total_fid
    fid_rel = (headroom if headroom > 0 else 0) / fid_scale

    numerator = 2 * total_ssim * fid_rel
    # The small constant keeps the denominator non-zero when both terms are 0.
    denominator = total_ssim + fid_rel + 1e-3
    return numerator / denominator


def _parse_validation_header(line):
    """Return ``(epoch, step)`` parsed from a validation header log line.

    The epoch is the integer between the first ``#`` and the following comma;
    the step is the whitespace-separated token right after ``total``.
    """
    after_sharp = line[line.index('#') + 1:]
    epoch = int(after_sharp[:after_sharp.index(',')])
    step = int(line[line.index('total '):].split()[1].strip())
    return epoch, step


def _best_validation_entry(lines, total_offset=5):
    """Scan log lines and return ``(epoch, step, f1)`` of the highest f1.

    For every validation header, the row ``total_offset`` lines below it is
    expected to start with ``total`` and to carry the f1 value as its last
    column. Returns ``(0, 0, 0)`` when no row has an f1 above 0.
    """
    best_f1 = 0
    best_epoch = 0
    best_step = 0
    for index, line in enumerate(lines):
        if 'Validation metrics after epoch' not in line:
            continue
        epoch, step = _parse_validation_header(line)
        total_index = index + total_offset
        if total_index >= len(lines):
            # The log ends before the metrics table is complete (e.g. the
            # run is still going or was interrupted); skip this entry
            # instead of failing with IndexError.
            continue
        total_line = lines[total_index]
        if not total_line.startswith('total'):
            continue
        f1 = float(total_line.strip().split()[-1])
        print(f'\tEpoch: {epoch}, f1={f1}')
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            best_step = step
    return best_epoch, best_step, best_f1


def find_best_checkpoint(model_list, models_dir):
    """Find the best validation epoch of each model and write a summary file.

    For every model name in ``model_list`` the file
    ``<models_dir>/<model>/train.log`` is scanned for validation reports, and
    the epoch with the highest f1 is selected. One tab-separated line
    ``model, best_epoch, best_step, best_f1`` per model is written to
    ``<model_list>_best``. Progress is printed to stdout.

    Args:
        model_list: Path to a text file with one model name per line. Blank
            lines are ignored.
        models_dir: Directory containing one sub-directory per model.

    Raises:
        OSError: If the list, a ``train.log`` or the output cannot be opened.
        ValueError: If a validation header or f1 column cannot be parsed.
    """
    with open(model_list) as f:
        # Drop blank lines: an empty name would resolve to
        # <models_dir>/train.log, which is not a model log.
        models = [name for name in (raw.strip() for raw in f) if name]
    with open(f'{model_list}_best', 'w') as f:
        for model in models:
            print(model)
            with open(os.path.join(models_dir, model, 'train.log')) as fm:
                lines = fm.readlines()
            best_epoch, best_step, best_f1 = _best_validation_entry(lines)
            f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')


if __name__ == '__main__':
    # Command-line entry point: two required positional arguments, the model
    # list file and the directory holding the per-model sub-directories.
    cli = ArgumentParser()
    for positional in ('model_list', 'models_dir'):
        cli.add_argument(positional)
    options = cli.parse_args()
    find_best_checkpoint(options.model_list, options.models_dir)