File size: 3,092 Bytes
5325fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import treetable as tt

from .._base_explorers import BaseExplorer


class LMExplorer(BaseExplorer):
    """Explorer whose tracking table reports the 'train' and 'valid' stages."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages handled by this explorer."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.

        The result is a list of two treetable groups, one per stage, each
        right-aligned.
        """
        train_columns = [
            tt.leaf('epoch'),
            tt.leaf('duration', '.1f'),  # duration in minutes
            tt.leaf('ping'),
            tt.leaf('ce', '.4f'),  # cross entropy
            tt.leaf("ppl", '.3f'),  # perplexity
        ]
        valid_columns = [
            tt.leaf('ce', '.4f'),
            tt.leaf('ppl', '.3f'),
            tt.leaf('best_ppl', '.3f'),
        ]
        return [
            tt.group('train', train_columns, align='>'),
            tt.group('valid', valid_columns, align='>'),
        ]

    def process_sheep(self, sheep, history):
        """Extend the base result with the best validation value of tracked metrics.

        Args:
            sheep: forwarded unchanged to the base implementation.
            history: iterated as a sequence of mappings from stage name to a
                mapping of metric name to value; also forwarded to the base
                implementation.

        Returns:
            The mapping produced by the base implementation. When it contains
            a 'valid' entry, that entry additionally receives one
            ``best_<metric>`` key per tracked metric.
        """
        parts = super().process_sheep(sheep, history)

        # Tracked metric -> direction of improvement, either 'lower' or 'higher'.
        track_by = {'ppl': 'lower'}

        # Start from the worst possible value for each direction so that any
        # observed value replaces it.
        best_metrics = {}
        for name, mode in track_by.items():
            best_metrics[name] = float('inf') if mode == 'lower' else -float('inf')

        # Only the validation stage is considered: keeping its best value makes
        # it convenient to compare runs against each other in the grid.
        for entry in history:
            if 'valid' not in entry:
                continue
            valid_metrics = entry['valid']
            for name, mode in track_by.items():
                if name not in valid_metrics:
                    continue
                candidate = valid_metrics[name]
                current = best_metrics[name]
                improved = candidate < current if mode == 'lower' else candidate > current
                if improved:
                    best_metrics[name] = candidate

        if 'valid' in parts:
            best_columns = {f'best_{name}': value for name, value in best_metrics.items()}
            parts['valid'].update(best_columns)
        return parts


class GenerationEvalExplorer(BaseExplorer):
    """Explorer whose tracking table reports the single 'evaluate' stage."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages handled by this explorer."""
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.

        The result is a one-element list holding the right-aligned 'evaluate'
        treetable group.
        """
        # (metric name, format spec); None means the leaf gets no explicit format.
        column_specs = [
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping', None),
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        ]
        columns = []
        for name, fmt in column_specs:
            if fmt is None:
                columns.append(tt.leaf(name))
            else:
                columns.append(tt.leaf(name, fmt))
        return [tt.group('evaluate', columns, align='>')]