Spaces:
Paused
Paused
File size: 3,092 Bytes
5325fcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import treetable as tt
from .._base_explorers import BaseExplorer
class LMExplorer(BaseExplorer):
    """Explorer for language-model training grids.

    Shows train/valid metrics in the tracking table and augments the 'valid'
    group with the best value seen over the run history for each metric listed
    in ``tracked_valid_metrics`` (reported under the key ``best_<metric>``).
    """

    eval_metrics: tp.List[str] = []
    # Validation metrics whose best value over the history is reported, mapped
    # to the direction that counts as an improvement: 'lower' or 'higher'.
    # Subclasses may override this; add a matching ``best_<metric>`` leaf in
    # ``get_grid_metrics`` to display any extra tracked metric.
    tracked_valid_metrics: tp.Dict[str, str] = {'ppl': 'lower'}

    def stages(self) -> tp.List[str]:
        """Return the stages whose metrics are tracked."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        return [
            tt.group(
                'train',
                [
                    tt.leaf('epoch'),
                    tt.leaf('duration', '.1f'),  # duration in minutes
                    tt.leaf('ping'),
                    tt.leaf('ce', '.4f'),  # cross entropy
                    tt.leaf('ppl', '.3f'),  # perplexity
                ],
                align='>',
            ),
            tt.group(
                'valid',
                [
                    tt.leaf('ce', '.4f'),
                    tt.leaf('ppl', '.3f'),
                    tt.leaf('best_ppl', '.3f'),
                ],
                align='>',
            ),
        ]

    def process_sheep(self, sheep, history):
        """Return the base table parts, with best validation metrics added.

        Keeping track of the best validation metrics makes it convenient to
        compare runs of the grid against each other.

        Args:
            sheep: experiment handle, forwarded unchanged to the base class.
            history: iterable of metric mappings (presumably one per epoch —
                confirm), each keyed by stage name (e.g. 'valid') with a
                mapping of metric name -> value as its value.

        Returns:
            The parts returned by ``BaseExplorer.process_sheep``. When they
            contain a 'valid' entry, that entry is updated in place with
            ``best_<metric>`` for every tracked metric that was observed with
            a usable value at least once; metrics never observed are omitted
            instead of being reported as +/-inf.

        Raises:
            ValueError: if ``tracked_valid_metrics`` contains a direction other
                than 'lower' or 'higher'.
        """
        parts = super().process_sheep(sheep, history)
        if 'valid' in parts:
            best_metrics = self._best_valid_metrics(history)
            parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
        return parts

    def _best_valid_metrics(self, history) -> tp.Dict[str, float]:
        """Return the best observed 'valid' value for each tracked metric."""
        track_by = self.tracked_valid_metrics
        for metric, mode in track_by.items():
            if mode not in ('lower', 'higher'):
                raise ValueError(
                    f"Invalid tracking mode {mode!r} for metric {metric!r}; "
                    "expected 'lower' or 'higher'."
                )

        # Start from the worst possible value so the first usable observation
        # wins. NaN never compares as "better", so NaN entries are ignored.
        best = {
            metric: (float('inf') if mode == 'lower' else -float('inf'))
            for metric, mode in track_by.items()
        }
        observed: tp.Set[str] = set()
        for metrics in history:
            valid = metrics.get('valid')
            if not valid:
                continue
            for metric, mode in track_by.items():
                if metric not in valid:
                    continue
                value = valid[metric]
                improved = value < best[metric] if mode == 'lower' else value > best[metric]
                if improved:
                    best[metric] = value
                    observed.add(metric)
        # Only report metrics that actually replaced the +/-inf sentinel.
        return {metric: value for metric, value in best.items() if metric in observed}
class GenerationEvalExplorer(BaseExplorer):
    """Explorer for evaluation runs that only have an 'evaluate' stage."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the single stage tracked by this explorer."""
        stage_names = ['evaluate']
        return stage_names

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.

        A single right-aligned 'evaluate' group; each spec below is the
        argument tuple passed to ``tt.leaf`` (name, optionally a format).
        """
        leaf_specs = [
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping',),  # displayed without an explicit format
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        ]
        evaluate_leaves = [tt.leaf(*spec) for spec in leaf_specs]
        evaluate_group = tt.group('evaluate', evaluate_leaves, align='>')
        return [evaluate_group]
|