keithhon committed
Commit 799d16e · 1 Parent(s): e9e7667

Upload encoder/visualizations.py with huggingface_hub

Files changed (1)
  1. encoder/visualizations.py +178 -0
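For context, uploads like this one are usually made with the huggingface_hub client rather than git. A minimal sketch of the kind of call that would produce such a commit (the repo id below is a hypothetical placeholder, not taken from this page):

    from huggingface_hub import HfApi

    api = HfApi()  # picks up the access token from `huggingface-cli login`
    api.upload_file(
        path_or_fileobj="encoder/visualizations.py",   # local file to upload
        path_in_repo="encoder/visualizations.py",      # destination path inside the repo
        repo_id="keithhon/some-repo",                  # hypothetical repo id
        commit_message="Upload encoder/visualizations.py with huggingface_hub",
    )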
encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+ from datetime import datetime
+ from time import perf_counter as timer
+ import matplotlib.pyplot as plt
+ import numpy as np
+ # import webbrowser
+ import visdom
+ import umap
+
+ colormap = np.array([
+     [76, 255, 0],
+     [0, 127, 70],
+     [255, 0, 0],
+     [255, 217, 38],
+     [0, 135, 255],
+     [165, 0, 165],
+     [255, 167, 255],
+     [0, 255, 255],
+     [255, 96, 38],
+     [142, 76, 0],
+     [33, 0, 127],
+     [0, 0, 0],
+     [183, 183, 183],
+ ], dtype=float) / 255  # np.float was removed in NumPy 1.24+; the builtin float is equivalent here
+
+
+ class Visualizations:
+     def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
+         # Tracking data
+         self.last_update_timestamp = timer()
+         self.update_every = update_every
+         self.step_times = []
+         self.losses = []
+         self.eers = []
+         print("Updating the visualizations every %d steps." % update_every)
+
+         # If visdom is disabled TODO: use a better paradigm for that
+         self.disabled = disabled
+         if self.disabled:
+             return
+
+         # Set the environment name
+         now = str(datetime.now().strftime("%d-%m %Hh%M"))
+         if env_name is None:
+             self.env_name = now
+         else:
+             self.env_name = "%s (%s)" % (env_name, now)
+
+         # Connect to visdom and open the corresponding window in the browser
+         try:
+             self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
+         except ConnectionError:
+             raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
+                             "start it.")
+         # webbrowser.open("http://localhost:8097/env/" + self.env_name)
+
+         # Create the windows
+         self.loss_win = None
+         self.eer_win = None
+         # self.lr_win = None
+         self.implementation_win = None
+         self.projection_win = None
+         self.implementation_string = ""
+
+     def log_params(self):
+         if self.disabled:
+             return
+         from encoder import params_data
+         from encoder import params_model
+         param_string = "<b>Model parameters</b>:<br>"
+         for param_name in (p for p in dir(params_model) if not p.startswith("__")):
+             value = getattr(params_model, param_name)
+             param_string += "\t%s: %s<br>" % (param_name, value)
+         param_string += "<b>Data parameters</b>:<br>"
+         for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+             value = getattr(params_data, param_name)
+             param_string += "\t%s: %s<br>" % (param_name, value)
+         self.vis.text(param_string, opts={"title": "Parameters"})
+
+     def log_dataset(self, dataset: SpeakerVerificationDataset):
+         if self.disabled:
+             return
+         dataset_string = ""
+         dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
+         dataset_string += "\n" + dataset.get_logs()
+         dataset_string = dataset_string.replace("\n", "<br>")
+         self.vis.text(dataset_string, opts={"title": "Dataset"})
+
+     def log_implementation(self, params):
+         if self.disabled:
+             return
+         implementation_string = ""
+         for param, value in params.items():
+             implementation_string += "<b>%s</b>: %s\n" % (param, value)
+             implementation_string = implementation_string.replace("\n", "<br>")
+         self.implementation_string = implementation_string
+         self.implementation_win = self.vis.text(
+             implementation_string,
+             opts={"title": "Training implementation"}
+         )
+
+     def update(self, loss, eer, step):
+         # Update the tracking data
+         now = timer()
+         self.step_times.append(1000 * (now - self.last_update_timestamp))
+         self.last_update_timestamp = now
+         self.losses.append(loss)
+         self.eers.append(eer)
+         print(".", end="")
+
+         # Update the plots every <update_every> steps
+         if step % self.update_every != 0:
+             return
+         time_string = "Step time:  mean: %5dms  std: %5dms" % \
+                       (int(np.mean(self.step_times)), int(np.std(self.step_times)))
+         print("\nStep %6d   Loss: %.4f   EER: %.4f   %s" %
+               (step, np.mean(self.losses), np.mean(self.eers), time_string))
+         if not self.disabled:
+             self.loss_win = self.vis.line(
+                 [np.mean(self.losses)],
+                 [step],
+                 win=self.loss_win,
+                 update="append" if self.loss_win else None,
+                 opts=dict(
+                     legend=["Avg. loss"],
+                     xlabel="Step",
+                     ylabel="Loss",
+                     title="Loss",
+                 )
+             )
+             self.eer_win = self.vis.line(
+                 [np.mean(self.eers)],
+                 [step],
+                 win=self.eer_win,
+                 update="append" if self.eer_win else None,
+                 opts=dict(
+                     legend=["Avg. EER"],
+                     xlabel="Step",
+                     ylabel="EER",
+                     title="Equal error rate"
+                 )
+             )
+             if self.implementation_win is not None:
+                 self.vis.text(
+                     self.implementation_string + ("<b>%s</b>" % time_string),
+                     win=self.implementation_win,
+                     opts={"title": "Training implementation"},
+                 )
+
+         # Reset the tracking
+         self.losses.clear()
+         self.eers.clear()
+         self.step_times.clear()
+
+     def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
+                          max_speakers=10):
+         max_speakers = min(max_speakers, len(colormap))
+         embeds = embeds[:max_speakers * utterances_per_speaker]
+
+         n_speakers = len(embeds) // utterances_per_speaker
+         ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
+         colors = [colormap[i] for i in ground_truth]
+
+         reducer = umap.UMAP()
+         projected = reducer.fit_transform(embeds)
+         plt.scatter(projected[:, 0], projected[:, 1], c=colors)
+         plt.gca().set_aspect("equal", "datalim")
+         plt.title("UMAP projection (step %d)" % step)
+         if not self.disabled:
+             self.projection_win = self.vis.matplot(plt, win=self.projection_win)
+         if out_fpath is not None:
+             plt.savefig(out_fpath)
+         plt.clf()
+
+     def save(self):
+         if not self.disabled:
+             self.vis.save([self.env_name])
+
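For orientation, the Visualizations class above is meant to be driven from the encoder's training loop: construct it once, call update() every step, and periodically call draw_projections() with a batch of utterance embeddings. A minimal usage sketch with placeholder values (the metrics, step counts, and embedding shape are illustrative only, and a visdom server is assumed to be running):

    import numpy as np
    from encoder.visualizations import Visualizations

    # Assumes `visdom` has already been started in another terminal.
    vis = Visualizations(env_name="encoder_training", update_every=10)
    vis.log_implementation({"Device": "cuda", "Batch size": 64})  # illustrative values

    for step in range(1, 101):
        loss, eer = float(np.random.rand()), float(np.random.rand())  # placeholders for real metrics
        vis.update(loss, eer, step)

    # 10 speakers x 10 utterances each, 256-dimensional embeddings (random, for illustration only)
    embeds = np.random.rand(10 * 10, 256)
    vis.draw_projections(embeds, utterances_per_speaker=10, step=100, out_fpath="umap_000100.png")
    vis.save()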