# semantrix_cond3 / display.py
# Last commit: e2b757a ("Add files") by Javierss · 5.08 kB
# %%
import asyncio
import pickle as pk
import time
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.art3d as art3d
import numpy as np
import torch
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Circle, PathPatch
from mpl_toolkits.mplot3d import Axes3D, axes3d
from sklearn.decomposition import PCA
# Silence every UserWarning for the rest of the process.
# NOTE(review): this is global, not scoped to this module, so it also hides
# Matplotlib's UserWarnings raised from display_words — confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)
# file_path = "word_embeddings_mpnet.pth"
# embeddings_dict = torch.load(file_path)
# # %%
# words = list(embeddings_dict.keys())
# sentences = [[word] for word in words]
# vectors = list(embeddings_dict.values())
# vectors_list = []
# for item in vectors:
# vectors_list.append(item.tolist())
# vector_list = vectors_list[:10]
# # %%
# # pca = PCA(n_components=3)
# # pca = pca.fit(vectors_list)
# # pk.dump(pca, open("pca_mpnet.pkl", "wb"))
# score = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
# %%
def _style_axes(ax):
    """Hide the x/y pane fills, axis lines, ticks and grid; tint the z pane."""
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    # `Axes3D.w_zaxis` no longer exists in recent Matplotlib; `zaxis` is the same axis.
    ax.zaxis.set_pane_color((0.87, 0.91, 0.94, 0.8))
    transparent = (1.0, 1.0, 1.0, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.line.set_color(transparent)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.grid(False)


def _layout_coordinates(vector_list, score):
    """Return the (x, y, z) plot coordinates; the first vector is the anchor."""
    points = np.asarray(vector_list)
    if len(points) == 0:
        raise ValueError("display_words needs at least one vector to plot")
    if len(points) == 1:
        # A lone anchor is plotted where it is (first three coordinates only,
        # matching what norm_distance_v keeps for the multi-point case).
        return points[:, :3].T
    # norm_distance_v expects one copy of the anchor per remaining point.
    anchors = np.repeat(points[:1], len(points) - 1, axis=0)
    return norm_distance_v(anchors, points[1:], score[1:])


def display_words(words, vector_list, score, bold, *, output_path="3d_rotation.gif"):
    """Save a rotating 3-D scatter plot of ``words`` as an animated GIF.

    The first vector is the anchor. Every other vector is moved along its own
    direction from the anchor to a distance of ``10 - score`` (via
    ``norm_distance_v``), so higher scores end up closer to the anchor. Each
    point is coloured on a fixed 0-10 "magma" scale and labelled with its word.

    Args:
        words: Labels, one per vector, in the same order as ``vector_list``.
        vector_list: Array-like of shape (n, d) with d >= 3; only the first
            three coordinates are plotted.
        score: Array-like of n scores; the colour scale spans 0 to 10.
        bold: Index of the word drawn in a larger, bolder font; a value that
            matches no index disables highlighting.
        output_path: File the GIF is written to (Pillow writer).

    Raises:
        ValueError: If ``vector_list`` is empty.

    Side effects: turns pyplot interactive mode off, sets the global
    ``image.cmap`` rcParam to "magma", and writes ``output_path``.
    """
    score = np.array(score)
    x, y, z = _layout_coordinates(vector_list, score)
    center_x, center_y, center_z = x[0], y[0], z[0]

    plt.ioff()
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection="3d")
        plt.rcParams["image.cmap"] = "magma"
        # `cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` remains.
        colormap = plt.get_cmap("magma")
        norm = plt.Normalize(0, 10)  # type: ignore
        colors = colormap(norm(score))

        _style_axes(ax)
        ax.autoscale(enable=True, axis="both", tight=True)

        for i, word in enumerate(words):
            highlighted = i == bold
            ax.text(
                x[i],
                y[i],
                z[i] + 0.05,
                word,
                fontsize="large" if highlighted else "medium",
                fontweight="demibold" if highlighted else "normal",
                alpha=1,
            )

        # A slightly larger black marker underneath outlines each coloured one.
        # (cmap/vmin/vmax are omitted: with explicit colours Matplotlib ignores them.)
        ax.scatter(x, y, z, c="black", marker="o", s=75)
        ax.scatter(x, y, z, c=colors, marker="o", s=70)
        fig.colorbar(
            cm.ScalarMappable(norm=norm, cmap=colormap),
            ax=ax,
            orientation="horizontal",
        )

        # NOTE(review): the half-width is half the *score* range, not the
        # spatial extent of the points (which lie up to 10 - min(score) from
        # the anchor); confirm this zoom is intended.
        half_width = 0.5 * (score.max() - score.min())

        def update(frame):
            ax.set_xlim(center_x - half_width, center_x + half_width)
            ax.set_ylim(center_y - half_width, center_y + half_width)
            ax.set_zlim(center_z - half_width, center_z + half_width)
            ax.view_init(elev=20, azim=frame)

        animation = FuncAnimation(fig, update, frames=np.arange(0, 360, 5), interval=120)
        animation.save(output_path, writer="pillow", dpi=140)
    finally:
        # Release the figure even when layout or saving fails.
        plt.close(fig)
# %%
def norm_distance_v(orig, points, distances, max_distance=10):
    """Re-place each point along its direction from its anchor at a score-based distance.

    Row ``i`` of ``points`` is moved to ``max_distance - distances[i]`` away
    from row ``i`` of ``orig``, along the unit vector pointing from
    ``orig[i]`` towards ``points[i]``. A larger value in ``distances`` therefore
    places the point *closer* to its anchor. A point that coincides with its
    anchor has no direction and is placed on the anchor (instead of NaN).

    Args:
        orig: Anchor rows, shape (n, d) (or broadcastable against ``points``).
        points: Points to re-place, shape (n, d).
        distances: n values (array-like) subtracted from ``max_distance``.
        max_distance: Distance given to a value of 0 (default 10).

    Returns:
        Array of shape (3, n + 1): the x, y and z rows of ``orig[0]`` followed
        by the re-placed points. Only the first three coordinates are kept.
    """
    orig = np.asarray(orig)
    offsets = np.asarray(points) - orig
    lengths = np.linalg.norm(offsets, axis=1, keepdims=True)
    # Dividing an all-zero offset by 1 instead of 0 keeps its direction at
    # zero rather than producing 0/0 = NaN coordinates.
    directions = offsets / np.where(lengths > 0, lengths, 1)
    radii = max_distance - np.asarray(distances).reshape(-1, 1)
    placed = orig + directions * radii
    # Prepend the first anchor so callers get it as point 0.
    placed = np.vstack((orig[:1], placed))
    return placed[:, :3].T.copy()