LiveFaceID / tools /pca.py
Martlgap's picture
testing on hugging
7c2803a
raw
history blame
1.72 kB
from sklearn.decomposition import PCA
import numpy as np
import plotly.express as px
def pca(matches, identities, gallery, dim=3):
    """
    Project matched gallery/detection embeddings with PCA and plot them.

    For every match, the embedding of the matched gallery entry and the
    embedding of the matched identity are placed in consecutive rows, so each
    match contributes two points. Both points are labelled with the gallery
    entry's name, so the two members of a match share one colour in the plot.

    Args:
        matches: iterable of objects exposing ``gallery_idx`` and
            ``identity_idx`` attributes.
        identities: collection indexed by ``identity_idx`` whose items expose
            an ``embedding`` attribute.
        gallery: collection indexed by ``gallery_idx`` whose items expose
            ``embedding`` and ``name`` attributes.
        dim: number of principal components to plot; must be 2 or 3.

    Returns:
        A Plotly Express figure: a 3D scatter plot for ``dim == 3`` and a 2D
        scatter plot for ``dim == 2``.

    Raises:
        ValueError: if ``dim`` is neither 2 nor 3, or if ``matches`` is empty.
            NOTE(review): scikit-learn may also raise ``ValueError`` when
            there are fewer samples than ``dim`` (e.g. a single match with
            ``dim=3``) -- confirm against the callers.
    """
    # Validate before doing any work so a bad `dim` fails fast with a clear
    # message instead of after (or inside) the PCA fit.
    if dim not in (2, 3):
        raise ValueError("dim must be either 2 or 3")

    # Materialize once: `matches` is traversed twice below, which would
    # silently yield nothing on the second pass for a generator.
    matches = list(matches)
    if not matches:
        raise ValueError("matches must contain at least one match")

    # Get gallery and detection embeddings and stitch them together in pairs
    embeddings = np.concatenate(
        [
            [gallery[match.gallery_idx].embedding, identities[match.identity_idx].embedding]
            for match in matches
        ],
        axis=0,
    )

    # Get identity names, one per embedding row; the gallery name is used for
    # both members of a pair so they are drawn in the same colour
    identity_names = np.concatenate(
        [[gallery[match.gallery_idx].name, gallery[match.gallery_idx].name] for match in matches],
        axis=0,
    )

    # Reduce the embeddings to `dim` principal components
    reducer = PCA(n_components=dim)
    embeddings_pca = reducer.fit_transform(embeddings)

    # Settings shared by the 2D and 3D scatter plots
    scatter_options = dict(
        opacity=0.7,
        color=identity_names,
        color_discrete_sequence=px.colors.qualitative.Vivid,
    )

    if dim == 3:
        fig = px.scatter_3d(embeddings_pca, x=0, y=1, z=2, **scatter_options)
    else:
        fig = px.scatter(embeddings_pca, x=0, y=1, **scatter_options)
        fig.update_xaxes(showgrid=True)
        fig.update_yaxes(showgrid=True)

    fig.update_traces(marker=dict(size=4))
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    return fig