Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp | |
import treetable as tt | |
from .._base_explorers import BaseExplorer | |
class LMExplorer(BaseExplorer):
    """Explorer whose tracking table reports the ``train`` and ``valid`` stages.

    On top of what the base class computes, the ``valid`` entry is extended
    with the best perplexity found across the whole history (``best_ppl``).
    """

    # Declared empty here; nothing inside this class reads it.
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages this explorer reports on."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # Each spec is the positional argument tuple forwarded to ``tt.leaf``:
        # (name,) or (name, format).
        train_specs = [
            ('epoch',),
            ('duration', '.1f'),  # duration in minutes
            ('ping',),
            ('ce', '.4f'),  # cross entropy
            ('ppl', '.3f'),  # perplexity
        ]
        train_group = tt.group(
            'train',
            [tt.leaf(*spec) for spec in train_specs],
            align='>',
        )

        valid_specs = [
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('best_ppl', '.3f'),
        ]
        valid_group = tt.group(
            'valid',
            [tt.leaf(*spec) for spec in valid_specs],
            align='>',
        )
        return [train_group, valid_group]

    def process_sheep(self, sheep, history):
        """Add the best validation metrics seen in ``history`` to the base result.

        Args:
            sheep: Forwarded untouched to the parent ``process_sheep``.
            history: Iterable of mappings from stage name to a mapping of
                metric name to value; only the ``'valid'`` stage is inspected.

        Returns:
            The mapping produced by the parent implementation. When it holds a
            ``'valid'`` entry, that entry gains one ``best_<metric>`` key per
            tracked metric. A tracked metric that never appears in ``history``
            keeps its initial value (``inf`` for ``'lower'``, ``-inf`` otherwise).
        """
        parts = super().process_sheep(sheep, history)

        # Tracked metric -> direction; values should be in ['lower', 'higher'].
        track_by = {'ppl': 'lower'}

        # Start from the worst possible value so any observed value wins.
        best_metrics = {}
        for name, mode in track_by.items():
            sign = 1 if mode == 'lower' else -1
            best_metrics[name] = sign * float('inf')

        def improves(mode, candidate, incumbent):
            """Tell whether ``candidate`` is strictly better than ``incumbent``."""
            if mode == 'lower':
                return candidate < incumbent
            return candidate > incumbent

        for epoch_metrics in history:
            for stage, stage_metrics in epoch_metrics.items():
                # Only the validation set feeds the best-metric tracking (ppl in
                # this example), so runs can be compared conveniently in the grid.
                if stage != 'valid':
                    continue
                for name in track_by:
                    if name not in stage_metrics:
                        continue
                    if improves(track_by[name], stage_metrics[name], best_metrics[name]):
                        best_metrics[name] = stage_metrics[name]

        if 'valid' in parts:
            best_summary = {f'best_{name}': value for name, value in best_metrics.items()}
            parts['valid'].update(best_summary)
        return parts
class GenerationEvalExplorer(BaseExplorer):
    """Explorer whose tracking table reports a single ``evaluate`` stage."""

    # Declared empty here; nothing inside this class reads it.
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages this explorer reports on."""
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # Each spec is the positional argument tuple forwarded to ``tt.leaf``:
        # (name,) or (name, format).
        evaluate_specs = [
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping',),
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        ]
        evaluate_group = tt.group(
            'evaluate',
            [tt.leaf(*spec) for spec in evaluate_specs],
            align='>',
        )
        return [evaluate_group]