# microembeddings / app.py
# Author: shreyask
# Last commit: "fix: robust text8 loading, gensim attribution in UI, training error handling" (43fd8a7, verified)
import gradio as gr
import numpy as np
import json
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from microembeddings import (
load_text8, build_vocab, prepare_corpus, build_neg_table,
train, normalize, most_similar, analogy, describe_text8_source
)
# --- Global state ---
# Process-wide embedding state shared by every tab. It starts empty and is
# filled in by load_pretrained() or run_training():
#   W        -> embedding matrix (one row per vocabulary entry)
#   W_norm   -> normalize(W); this is what most_similar/analogy are given
#   word2idx -> word -> row index
#   idx2word -> row index -> word
#   losses   -> loss values recorded during training
state = dict(
    W=None,
    W_norm=None,
    word2idx=None,
    idx2word=None,
    losses=[],
)
def load_pretrained():
    """Load pre-trained embeddings into ``state`` if they are available.

    Reads ``pretrained_W.npy`` (the embedding matrix) and
    ``pretrained_vocab.json`` (a JSON object with a ``"vocab"`` list and an
    optional ``"losses"`` list).

    Returns:
        A human-readable status message for the UI. Missing, unreadable or
        inconsistent files produce a message instead of an exception, and
        ``state`` is only modified when loading fully succeeds.
    """
    try:
        W = np.load("pretrained_W.npy")
        with open("pretrained_vocab.json", encoding="utf-8") as f:
            meta = json.load(f)
        vocab = meta["vocab"]
    except FileNotFoundError:
        return "No pre-trained embeddings found. Train from scratch!"
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # Corrupt .npy, invalid JSON (JSONDecodeError is a ValueError), or a
        # JSON document without a "vocab" entry: report instead of crashing
        # at import time.
        return f"Could not load pre-trained embeddings ({exc}). Train from scratch!"

    # Every vocab index must address a row of W, otherwise lookups fail later.
    if len(vocab) > W.shape[0]:
        return (
            f"Pre-trained files are inconsistent: {len(vocab)} vocab entries "
            f"but only {W.shape[0]} vectors. Train from scratch!"
        )

    W_norm = normalize(W)
    state["W"] = W
    state["W_norm"] = W_norm
    state["word2idx"] = {w: i for i, w in enumerate(vocab)}
    state["idx2word"] = {i: w for i, w in enumerate(vocab)}
    state["losses"] = meta.get("losses", [])
    return (
        "Loaded pre-trained full-text8 gensim vectors: "
        f"{W.shape[0]} words x {W.shape[1]} dims"
    )
# --- Tab 1: Train ---
def run_training(embed_dim, window_size, learning_rate, num_neg, progress=gr.Progress()):
    """Train skip-gram embeddings on a 500k-word text8 subset and publish them.

    Args:
        embed_dim: Embedding dimension (slider value, cast to int).
        window_size: Context window size (slider value, cast to int).
        learning_rate: Learning rate passed to ``train``.
        num_neg: Number of negative samples (slider value, cast to int).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(figure, status_message)``. On success the figure plots the recorded
        losses. On failure the figure is empty, the message contains the error,
        and whatever embeddings were already in ``state`` remain untouched.
    """
    fig = go.Figure()
    try:
        progress(0, desc="Loading text8...")
        words = load_text8(500000)
        word2idx, idx2word, freqs = build_vocab(words)
        corpus = prepare_corpus(words, word2idx, freqs)
        neg_dist = build_neg_table(freqs)
        losses = []

        def callback(epoch, i, total, loss):
            # Invoked by train(); forwards progress to the UI and records loss.
            progress(i / total, desc=f"Epoch {epoch+1}: loss={loss:.4f}")
            losses.append(loss)

        W, _ = train(corpus, len(word2idx), neg_dist,
                     epochs=3, embed_dim=int(embed_dim), lr=learning_rate,
                     window=int(window_size), num_neg=int(num_neg), callback=callback)
        W_norm = normalize(W)
    except Exception as exc:
        # UI boundary: surface the error in the status box and keep the
        # traceback in the server logs for debugging.
        import traceback
        traceback.print_exc()
        fig.update_layout(title="Training unavailable", template="plotly_white")
        return fig, f"Training failed: {exc}"

    # Publish vocabulary and vectors together, only after training succeeded,
    # so the other tabs never pair a new vocabulary with old vectors.
    state["W"] = W
    state["W_norm"] = W_norm
    state["word2idx"] = word2idx
    state["idx2word"] = idx2word
    state["losses"] = losses

    fig.add_trace(go.Scatter(y=losses, mode="lines", name="Loss",
                             line=dict(color="#4F46E5")))
    fig.update_layout(title="Training Loss", xaxis_title="Step", yaxis_title="Loss",
                      template="plotly_white")
    return fig, f"Done! {W.shape[0]} words x {W.shape[1]} dims"
# --- Tab 2: Explore ---
# Word lists offered by the "Highlight category" dropdown.
_CATEGORIES = {
    "Countries": ["france", "germany", "italy", "spain", "china", "japan",
                  "india", "russia", "england", "canada", "brazil",
                  "australia", "mexico", "korea"],
    "Animals": ["dog", "cat", "horse", "bird", "fish", "lion", "bear",
                "wolf", "snake", "elephant"],
    "Numbers": ["one", "two", "three", "four", "five", "six", "seven",
                "eight", "nine", "ten"],
    "Colors": ["red", "blue", "green", "yellow", "black", "white",
               "brown", "gold", "silver"],
}


def explore_embeddings(method, num_words, category):
    """Project embeddings to 2D and plot them, highlighting a word category.

    Args:
        method: ``"PCA"`` for PCA; any other value uses t-SNE.
        num_words: How many of the first vocabulary rows to plot.
        category: Key of ``_CATEGORIES`` to highlight (unknown keys, such as
            ``"None"``, highlight nothing).

    Returns:
        A Plotly figure, or None when no embeddings are loaded or there are
        fewer than two words to project.

    Category words that lie beyond the first ``num_words`` rows but are in the
    vocabulary are added to the plot. Otherwise the highlight would often be
    empty, because those words fall outside the first rows.
    """
    if state["W"] is None:
        return None
    W = np.asarray(state["W"])
    word2idx = state["word2idx"]
    idx2word = state["idx2word"]

    n = min(int(num_words), len(idx2word))
    highlight_words = set(_CATEGORIES.get(category, []))
    extra = sorted(
        word2idx[w] for w in highlight_words
        if w in word2idx and n <= word2idx[w] < W.shape[0]
    )
    indices = list(range(n)) + extra
    if len(indices) < 2:
        # PCA / t-SNE cannot produce a 2D projection from fewer than 2 points.
        return None

    W_sub = W[indices]
    words_sub = [idx2word[i] for i in indices]
    if method == "PCA":
        coords = PCA(n_components=2).fit_transform(W_sub)
    else:
        # t-SNE requires perplexity < number of samples.
        coords = TSNE(n_components=2, perplexity=min(30, len(indices) - 1),
                      random_state=42).fit_transform(W_sub)

    is_highlighted = [w in highlight_words for w in words_sub]
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=coords[:, 0], y=coords[:, 1], mode="markers",
        marker=dict(size=[10 if h else 4 for h in is_highlighted],
                    color=["#E11D48" if h else "#93C5FD" for h in is_highlighted],
                    opacity=0.7),
        text=words_sub, hoverinfo="text"
    ))
    for (x, y), w, h in zip(coords, words_sub, is_highlighted):
        if h:
            fig.add_annotation(x=x, y=y, text=w,
                               showarrow=False, yshift=12,
                               font=dict(size=10, color="#E11D48"))
    fig.update_layout(title=f"Embedding Space ({method})", template="plotly_white",
                      width=800, height=600, showlegend=False)
    return fig
# --- Tab 3: Analogies ---
def solve_analogy(a, b, c):
    """Solve "a is to b as c is to ?" and return text plus a bar chart.

    Args:
        a, b, c: Query words; surrounding whitespace and case are ignored.

    Returns:
        ``(text, figure)``. ``figure`` is None when the analogy cannot be
        answered: no embeddings loaded, an empty input, out-of-vocabulary
        words, or no candidates returned by ``analogy``.
    """
    if state["W_norm"] is None:
        return "Train or load embeddings first!", None
    a, b, c = ((s or "").strip().lower() for s in (a, b, c))
    if not (a and b and c):
        return "Please fill in all three words.", None

    # Report each unknown word once, in input order.
    missing = [w for w in dict.fromkeys((a, b, c)) if w not in state["word2idx"]]
    if missing:
        return f"Word(s) not in vocabulary: {', '.join(missing)}", None

    results = analogy(a, b, c, state["W_norm"], state["word2idx"], state["idx2word"])
    if not results:
        return f"No candidates found for {a} : {b} :: {c} : ?", None

    text = f"{a} is to {b} as {c} is to...\n\n"
    text += "\n".join(f" {w}: {s:.4f}" for w, s in results)
    words_r, sims_r = zip(*results)
    fig = go.Figure(go.Bar(x=list(sims_r), y=list(words_r), orientation="h",
                           marker_color="#4F46E5"))
    fig.update_layout(title=f"{a} : {b} :: {c} : ?", xaxis_title="Cosine similarity",
                      template="plotly_white", yaxis=dict(autorange="reversed"))
    return text, fig
# --- Tab 4: Nearest Neighbors ---
def find_neighbors(word):
    """Return the nearest neighbours of ``word`` as text plus a bar chart.

    Args:
        word: Query word; surrounding whitespace and case are ignored.

    Returns:
        ``(text, figure)``. ``figure`` is None when the query cannot be
        answered: no embeddings loaded, empty input, an out-of-vocabulary word,
        or no results from ``most_similar``.
    """
    if state["W_norm"] is None:
        return "Train or load embeddings first!", None
    word = (word or "").strip().lower()
    if not word:
        return "Enter a word first!", None
    if word not in state["word2idx"]:
        return f"'{word}' not in vocabulary", None

    results = most_similar(word, state["W_norm"], state["word2idx"], state["idx2word"])
    if not results:
        return f"No neighbors found for '{word}'", None

    text = "\n".join(f"{w}: {s:.4f}" for w, s in results)
    words_r, sims_r = zip(*results)
    fig = go.Figure(go.Bar(x=list(sims_r), y=list(words_r), orientation="h",
                           marker_color="#4F46E5"))
    fig.update_layout(title=f"Nearest neighbors of '{word}'",
                      xaxis_title="Cosine similarity",
                      template="plotly_white", yaxis=dict(autorange="reversed"))
    return text, fig
# --- Build UI ---
# Run once at import time: try to load the pre-trained vectors so the query
# tabs work immediately, and fetch the corpus description shown in the Train
# tab. Both return status strings that are rendered as Markdown below.
load_msg = load_pretrained()
corpus_msg = describe_text8_source()
# Gradio lays components out in the order they are created inside the `with`
# contexts, so statement order here determines the page layout.
with gr.Blocks(title="microembeddings", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# microembeddings\n"
        "*Word2Vec skip-gram from scratch — train, explore, and play with word vectors*\n\n"
        "Companion to the blog post: "
        "[microembeddings: Understanding Word Vectors from Scratch]"
        "(https://kshreyas.dev/post/microembeddings/)"
    )
    gr.Markdown(f"*{load_msg}*")
    gr.Markdown(
        "*Preloaded vectors use gensim Word2Vec on the full 17M-word text8 corpus.* "
        "*The Train tab reruns the NumPy implementation on a 500k-word subset so it stays interactive.*"
    )
    with gr.Tabs():
        with gr.Tab("Train"):
            gr.Markdown(
                "Train word embeddings from scratch on text8 (cleaned Wikipedia).\n\n"
                f"{corpus_msg}"
            )
            with gr.Row():
                dim_slider = gr.Slider(25, 100, value=50, step=25,
                                       label="Embedding dimension")
                win_slider = gr.Slider(1, 10, value=5, step=1, label="Window size")
            with gr.Row():
                lr_slider = gr.Slider(0.001, 0.05, value=0.025, step=0.001,
                                      label="Learning rate")
                neg_slider = gr.Slider(1, 15, value=5, step=1,
                                       label="Negative samples")
            train_btn = gr.Button("Train", variant="primary")
            train_status = gr.Textbox(label="Status", interactive=False)
            loss_plot = gr.Plot(label="Training Loss")
            # Inputs map positionally to run_training's parameters; outputs
            # receive its (figure, status message) return tuple in order.
            train_btn.click(run_training,
                            [dim_slider, win_slider, lr_slider, neg_slider],
                            [loss_plot, train_status])
        with gr.Tab("Explore"):
            gr.Markdown(
                "Visualize the embedding space in 2D. "
                "Similar words cluster together."
            )
            with gr.Row():
                method_radio = gr.Radio(["PCA", "t-SNE"], value="PCA",
                                        label="Projection method")
                num_words_slider = gr.Slider(100, 500, value=200, step=50,
                                             label="Number of words")
                # "None" is not a category key in explore_embeddings, so it
                # highlights nothing.
                cat_dropdown = gr.Dropdown(
                    ["None", "Countries", "Animals", "Numbers", "Colors"],
                    value="None", label="Highlight category"
                )
            explore_btn = gr.Button("Visualize", variant="primary")
            explore_plot = gr.Plot(label="Embedding Space")
            explore_btn.click(explore_embeddings,
                              [method_radio, num_words_slider, cat_dropdown],
                              explore_plot)
        with gr.Tab("Analogies"):
            gr.Markdown(
                "Word vector arithmetic: **A is to B as C is to ?**\n\n"
                "Computed as: `B - A + C ≈ ?`"
            )
            with gr.Row():
                a_input = gr.Textbox(label="A", placeholder="man", value="man")
                b_input = gr.Textbox(label="B", placeholder="king", value="king")
                c_input = gr.Textbox(label="C", placeholder="woman", value="woman")
            analogy_btn = gr.Button("Solve", variant="primary")
            # Clickable example rows that fill the A/B/C textboxes.
            gr.Examples(
                [["man", "king", "woman"], ["france", "paris", "germany"],
                 ["big", "bigger", "small"]],
                inputs=[a_input, b_input, c_input]
            )
            analogy_text = gr.Textbox(label="Results", interactive=False, lines=6)
            analogy_plot = gr.Plot(label="Similarity")
            analogy_btn.click(solve_analogy, [a_input, b_input, c_input],
                              [analogy_text, analogy_plot])
        with gr.Tab("Nearest Neighbors"):
            gr.Markdown("Find the most similar words by cosine similarity.")
            word_input = gr.Textbox(label="Enter a word", placeholder="king")
            nn_btn = gr.Button("Search", variant="primary")
            nn_text = gr.Textbox(label="Results", interactive=False, lines=10)
            nn_plot = gr.Plot(label="Similarity")
            nn_btn.click(find_neighbors, word_input, [nn_text, nn_plot])
# Start the Gradio server. This runs at module level, so importing this file
# also launches the app.
demo.launch()