unpairedelectron07
commited on
Commit
•
b0f16a9
1
Parent(s):
442698d
Upload _base_explorers.py
Browse files
audiocraft/grids/_base_explorers.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import time
|
9 |
+
import typing as tp
|
10 |
+
from dora import Explorer
|
11 |
+
import treetable as tt
|
12 |
+
|
13 |
+
|
14 |
+
def get_sheep_ping(sheep) -> tp.Optional[str]:
|
15 |
+
"""Return the amount of time since the Sheep made some update
|
16 |
+
to its log. Returns a str using the relevant time unit."""
|
17 |
+
ping = None
|
18 |
+
if sheep.log is not None and sheep.log.exists():
|
19 |
+
delta = time.time() - sheep.log.stat().st_mtime
|
20 |
+
if delta > 3600 * 24:
|
21 |
+
ping = f'{delta / (3600 * 24):.1f}d'
|
22 |
+
elif delta > 3600:
|
23 |
+
ping = f'{delta / (3600):.1f}h'
|
24 |
+
elif delta > 60:
|
25 |
+
ping = f'{delta / 60:.1f}m'
|
26 |
+
else:
|
27 |
+
ping = f'{delta:.1f}s'
|
28 |
+
return ping
|
29 |
+
|
30 |
+
|
31 |
+
class BaseExplorer(ABC, Explorer):
|
32 |
+
"""Base explorer for AudioCraft grids.
|
33 |
+
|
34 |
+
All task specific solvers are expected to implement the `get_grid_metrics`
|
35 |
+
method to specify logic about metrics to display for a given task.
|
36 |
+
|
37 |
+
If additional stages are used, the child explorer must define how to handle
|
38 |
+
these new stages in the `process_history` and `process_sheep` methods.
|
39 |
+
"""
|
40 |
+
def stages(self):
|
41 |
+
return ["train", "valid", "evaluate"]
|
42 |
+
|
43 |
+
def get_grid_meta(self):
|
44 |
+
"""Returns the list of Meta information to display for each XP/job.
|
45 |
+
"""
|
46 |
+
return [
|
47 |
+
tt.leaf("index", align=">"),
|
48 |
+
tt.leaf("name", wrap=140),
|
49 |
+
tt.leaf("state"),
|
50 |
+
tt.leaf("sig", align=">"),
|
51 |
+
tt.leaf("sid", align="<"),
|
52 |
+
]
|
53 |
+
|
54 |
+
@abstractmethod
|
55 |
+
def get_grid_metrics(self):
|
56 |
+
"""Return the metrics that should be displayed in the tracking table.
|
57 |
+
"""
|
58 |
+
...
|
59 |
+
|
60 |
+
def process_sheep(self, sheep, history):
|
61 |
+
train = {
|
62 |
+
"epoch": len(history),
|
63 |
+
}
|
64 |
+
parts = {"train": train}
|
65 |
+
for metrics in history:
|
66 |
+
for key, sub in metrics.items():
|
67 |
+
part = parts.get(key, {})
|
68 |
+
if 'duration' in sub:
|
69 |
+
# Convert to minutes for readability.
|
70 |
+
sub['duration'] = sub['duration'] / 60.
|
71 |
+
part.update(sub)
|
72 |
+
parts[key] = part
|
73 |
+
ping = get_sheep_ping(sheep)
|
74 |
+
if ping is not None:
|
75 |
+
for name in self.stages():
|
76 |
+
if name not in parts:
|
77 |
+
parts[name] = {}
|
78 |
+
# Add the ping to each part for convenience.
|
79 |
+
parts[name]['ping'] = ping
|
80 |
+
return parts
|