unpairedelectron07 commited on
Commit
b0f16a9
·
verified ·
1 Parent(s): 442698d

Upload _base_explorers.py

Browse files
Files changed (1) hide show
  1. audiocraft/grids/_base_explorers.py +80 -0
audiocraft/grids/_base_explorers.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import time
9
+ import typing as tp
10
+ from dora import Explorer
11
+ import treetable as tt
12
+
13
+
14
+ def get_sheep_ping(sheep) -> tp.Optional[str]:
15
+ """Return the amount of time since the Sheep made some update
16
+ to its log. Returns a str using the relevant time unit."""
17
+ ping = None
18
+ if sheep.log is not None and sheep.log.exists():
19
+ delta = time.time() - sheep.log.stat().st_mtime
20
+ if delta > 3600 * 24:
21
+ ping = f'{delta / (3600 * 24):.1f}d'
22
+ elif delta > 3600:
23
+ ping = f'{delta / (3600):.1f}h'
24
+ elif delta > 60:
25
+ ping = f'{delta / 60:.1f}m'
26
+ else:
27
+ ping = f'{delta:.1f}s'
28
+ return ping
29
+
30
+
31
+ class BaseExplorer(ABC, Explorer):
32
+ """Base explorer for AudioCraft grids.
33
+
34
+ All task specific solvers are expected to implement the `get_grid_metrics`
35
+ method to specify logic about metrics to display for a given task.
36
+
37
+ If additional stages are used, the child explorer must define how to handle
38
+ these new stages in the `process_history` and `process_sheep` methods.
39
+ """
40
+ def stages(self):
41
+ return ["train", "valid", "evaluate"]
42
+
43
+ def get_grid_meta(self):
44
+ """Returns the list of Meta information to display for each XP/job.
45
+ """
46
+ return [
47
+ tt.leaf("index", align=">"),
48
+ tt.leaf("name", wrap=140),
49
+ tt.leaf("state"),
50
+ tt.leaf("sig", align=">"),
51
+ tt.leaf("sid", align="<"),
52
+ ]
53
+
54
+ @abstractmethod
55
+ def get_grid_metrics(self):
56
+ """Return the metrics that should be displayed in the tracking table.
57
+ """
58
+ ...
59
+
60
+ def process_sheep(self, sheep, history):
61
+ train = {
62
+ "epoch": len(history),
63
+ }
64
+ parts = {"train": train}
65
+ for metrics in history:
66
+ for key, sub in metrics.items():
67
+ part = parts.get(key, {})
68
+ if 'duration' in sub:
69
+ # Convert to minutes for readability.
70
+ sub['duration'] = sub['duration'] / 60.
71
+ part.update(sub)
72
+ parts[key] = part
73
+ ping = get_sheep_ping(sheep)
74
+ if ping is not None:
75
+ for name in self.stages():
76
+ if name not in parts:
77
+ parts[name] = {}
78
+ # Add the ping to each part for convenience.
79
+ parts[name]['ping'] = ping
80
+ return parts