import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import umap | |
# Select the non-interactive Agg backend so figures can be rendered and saved
# to files without a display.
# NOTE(review): this call runs after `matplotlib.pyplot` has already been
# imported above. Recent matplotlib releases switch the backend here, but older
# releases may warn and ignore a late `use()` call. Consider moving this line
# to just before `import matplotlib.pyplot as plt`.
matplotlib.use("Agg")
# Thirteen visually distinct colours, written as 0-255 RGB integer triples and
# rescaled to float64 values in [0, 1], the range matplotlib expects for RGB
# colours. Row ``i`` is the colour used for class index ``i``.
colormap = np.divide(
    np.array(
        (
            (76, 255, 0),  # bright green
            (0, 127, 70),  # dark green
            (255, 0, 0),  # red
            (255, 217, 38),  # yellow
            (0, 135, 255),  # blue
            (165, 0, 165),  # purple
            (255, 167, 255),  # pink
            (0, 255, 255),  # cyan
            (255, 96, 38),  # orange
            (142, 76, 0),  # brown
            (33, 0, 127),  # indigo
            (0, 0, 0),  # black
            (183, 183, 183),  # grey
        ),
        dtype=float,
    ),
    255,
)
def plot_embeddings(embeddings, num_classes_in_batch, max_classes=10):
    """Project a batch of embeddings to 2-D with UMAP and scatter-plot them by class.

    Rows are labelled by position: the batch is split into ``num_classes_in_batch``
    consecutive groups of ``embeddings.shape[0] // num_classes_in_batch`` rows, and
    group ``i`` is drawn in ``colormap[i]``. Any leftover rows (when the row count
    is not a multiple of ``num_classes_in_batch``) are dropped. Callers presumably
    order the batch class-by-class; this function cannot verify that.

    Args:
        embeddings: array-like whose first axis indexes samples; it is row-sliced
            and passed to ``umap.UMAP().fit_transform`` (which expects
            ``(n_samples, n_features)``).
        num_classes_in_batch: number of classes the rows are grouped into; must be
            at least 1.
        max_classes: plot only the first ``max_classes`` classes. Defaults to 10,
            the previously hard-coded limit. Must be between 1 and
            ``len(colormap)`` so every plotted class gets its own colour.

    Returns:
        The matplotlib ``Figure`` containing the scatter plot. As a side effect the
        figure is also saved as ``"umap"`` in the current working directory.
        Matplotlib appends its default ``savefig.format`` extension, normally
        ``.png``.

    Raises:
        ValueError: if ``num_classes_in_batch`` < 1, if ``max_classes`` is out of
            range, or if there are fewer rows than classes.
    """
    if num_classes_in_batch < 1:
        raise ValueError(
            f"num_classes_in_batch must be >= 1, got {num_classes_in_batch}"
        )
    if not 1 <= max_classes <= len(colormap):
        raise ValueError(
            f"max_classes must be between 1 and {len(colormap)}, got {max_classes}"
        )

    num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
    if num_utter_per_class == 0:
        raise ValueError(
            f"got {embeddings.shape[0]} embeddings for {num_classes_in_batch} "
            "classes; need at least one embedding per class"
        )

    # If necessary, keep only the first `max_classes` classes.
    num_classes = min(num_classes_in_batch, max_classes)
    # Always truncate, not only when classes are dropped. Otherwise unlabelled
    # leftover rows get projected, the colour list ends up shorter than the
    # projected points, and scatter rejects the mismatched `c`.
    num_points = num_classes * num_utter_per_class
    embeddings = embeddings[:num_points]

    projection = umap.UMAP().fit_transform(embeddings)
    ground_truth = np.repeat(np.arange(num_classes), num_utter_per_class)
    colors = colormap[ground_truth]  # one RGB row per projected point

    fig, ax = plt.subplots(figsize=(16, 10))
    ax.scatter(projection[:, 0], projection[:, 1], c=colors)
    ax.set_aspect("equal", "datalim")
    ax.set_title("UMAP projection")
    fig.tight_layout()
    fig.savefig("umap")
    return fig