# NOTE: "Spaces: Sleeping" is Hugging Face Spaces page chrome captured by the
# scrape — it is not part of the application source.
| import gradio as gr | |
| import numpy as np | |
| import json | |
| import plotly.graph_objects as go | |
| from sklearn.decomposition import PCA | |
| from sklearn.manifold import TSNE | |
| from microembeddings import ( | |
| load_text8, build_vocab, prepare_corpus, build_neg_table, | |
| train, normalize, most_similar, analogy, describe_text8_source | |
| ) | |
# --- Global state ---
# Shared embedding state, mutated by load_pretrained()/run_training() and read
# by every other tab handler:
#   W        : raw embedding matrix (vocab_size x embed_dim), or None until set
#   W_norm   : row-normalized copy of W used for cosine-similarity lookups
#   word2idx : word -> row-index mapping
#   idx2word : row-index -> word mapping
#   losses   : per-step training losses (drives the loss plot)
state = {"W": None, "W_norm": None, "word2idx": None, "idx2word": None, "losses": []}
def load_pretrained():
    """Load pre-trained embeddings from disk into the global ``state``.

    Looks for ``pretrained_W.npy`` (embedding matrix) and
    ``pretrained_vocab.json`` (``{"vocab": [...], "losses": [...]}``) in the
    working directory.

    Returns:
        A human-readable status message for the UI. On any load failure the
        message tells the user to train from scratch instead of raising —
        this runs at module import time, so an exception here would prevent
        the whole app from starting.
    """
    try:
        W = np.load("pretrained_W.npy")
        with open("pretrained_vocab.json") as f:
            meta = json.load(f)
        vocab = meta["vocab"]
    except FileNotFoundError:
        return "No pre-trained embeddings found. Train from scratch!"
    except (OSError, KeyError, ValueError) as exc:
        # Corrupt or partial artifacts (bad .npy, malformed JSON, missing
        # "vocab" key) must not crash startup. json.JSONDecodeError is a
        # ValueError subclass, so it is covered here.
        return f"Could not load pre-trained embeddings ({exc}). Train from scratch!"
    state["W"] = W
    state["W_norm"] = normalize(W)
    state["word2idx"] = {w: i for i, w in enumerate(vocab)}
    state["idx2word"] = {i: w for i, w in enumerate(vocab)}
    state["losses"] = meta.get("losses", [])
    return (
        "Loaded pre-trained full-text8 gensim vectors: "
        f"{W.shape[0]} words x {W.shape[1]} dims"
    )
# --- Tab 1: Train ---
def run_training(embed_dim, window_size, learning_rate, num_neg,
                 epochs=3, progress=gr.Progress()):
    """Train skip-gram embeddings on a 500k-word text8 subset.

    Args:
        embed_dim: Embedding dimensionality (coerced to int).
        window_size: Skip-gram context window size (coerced to int).
        learning_rate: SGD learning rate.
        num_neg: Negative samples per positive pair (coerced to int).
        epochs: Passes over the corpus. Defaults to 3, the previously
            hard-coded value, so existing callers are unaffected.
        progress: Gradio progress tracker. ``gr.Progress()`` as a default is
            the documented Gradio injection pattern, not an accidental
            mutable default.

    Returns:
        (loss_figure, status_message) for the Plot and Textbox outputs.
        On failure the figure is returned empty and the message carries the
        error, so the Gradio handler never raises.
    """
    fig = go.Figure()
    try:
        progress(0, desc="Loading text8...")
        words = load_text8(500000)
        word2idx, idx2word, freqs = build_vocab(words)
        corpus = prepare_corpus(words, word2idx, freqs)
        neg_dist = build_neg_table(freqs)
        state["word2idx"] = word2idx
        state["idx2word"] = idx2word
        losses = []
        n_epochs = int(epochs)

        def callback(epoch, i, total, loss):
            # Report overall progress across ALL epochs. The previous
            # i/total fraction reset to zero at every epoch boundary,
            # making the progress bar jump backwards mid-run.
            pct = (epoch + i / total) / n_epochs
            progress(pct, desc=f"Epoch {epoch+1}: loss={loss:.4f}")
            losses.append(loss)

        W, _ = train(corpus, len(word2idx), neg_dist,
                     epochs=n_epochs, embed_dim=int(embed_dim), lr=learning_rate,
                     window=int(window_size), num_neg=int(num_neg), callback=callback)
        state["W"] = W
        state["W_norm"] = normalize(W)
        state["losses"] = losses
        fig.add_trace(go.Scatter(y=losses, mode="lines", name="Loss",
                                 line=dict(color="#4F46E5")))
        fig.update_layout(title="Training Loss", xaxis_title="Step", yaxis_title="Loss",
                          template="plotly_white")
        return fig, f"Done! {W.shape[0]} words x {W.shape[1]} dims"
    except Exception as exc:
        # Broad catch is intentional at this UI boundary: surface the error
        # in the status textbox instead of crashing the event handler.
        fig.update_layout(title="Training unavailable", template="plotly_white")
        return fig, f"Training failed: {exc}"
# --- Tab 2: Explore ---
def explore_embeddings(method, num_words, category):
    """Project the first `num_words` embeddings into 2D and plot them.

    Words belonging to the selected highlight category are drawn larger,
    in red, and labeled; all other words are small blue dots. Returns
    None when no embeddings have been trained or loaded yet.
    """
    if state["W"] is None:
        return None

    n = min(int(num_words), len(state["idx2word"]))
    vectors = state["W"][:n]
    labels = [state["idx2word"][i] for i in range(n)]

    # 2D projection: PCA is fast/deterministic; t-SNE needs perplexity < n.
    if method == "PCA":
        xy = PCA(n_components=2).fit_transform(vectors)
    else:
        xy = TSNE(n_components=2, perplexity=min(30, n - 1),
                  random_state=42).fit_transform(vectors)

    category_words = {
        "Countries": ["france", "germany", "italy", "spain", "china", "japan",
                      "india", "russia", "england", "canada", "brazil",
                      "australia", "mexico", "korea"],
        "Animals": ["dog", "cat", "horse", "bird", "fish", "lion", "bear",
                    "wolf", "snake", "elephant"],
        "Numbers": ["one", "two", "three", "four", "five", "six", "seven",
                    "eight", "nine", "ten"],
        "Colors": ["red", "blue", "green", "yellow", "black", "white",
                   "brown", "gold", "silver"],
    }
    highlighted = set(category_words.get(category, []))

    marker_colors = ["#E11D48" if w in highlighted else "#93C5FD" for w in labels]
    marker_sizes = [10 if w in highlighted else 4 for w in labels]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=xy[:, 0], y=xy[:, 1], mode="markers",
        marker=dict(size=marker_sizes, color=marker_colors, opacity=0.7),
        text=labels, hoverinfo="text"
    ))
    # Annotate only the highlighted words so the plot stays readable.
    for idx, word in enumerate(labels):
        if word in highlighted:
            fig.add_annotation(x=xy[idx, 0], y=xy[idx, 1], text=word,
                               showarrow=False, yshift=12,
                               font=dict(size=10, color="#E11D48"))
    fig.update_layout(title=f"Embedding Space ({method})", template="plotly_white",
                      width=800, height=600, showlegend=False)
    return fig
# --- Tab 3: Analogies ---
def solve_analogy(a, b, c):
    """Solve 'a is to b as c is to ?' via vector arithmetic.

    Returns a (text report, horizontal bar chart) pair, or an error
    message with a None figure when embeddings are missing or a query
    word is out of vocabulary.
    """
    if state["W_norm"] is None:
        return "Train or load embeddings first!", None

    a, b, c = (w.strip().lower() for w in (a, b, c))
    results = analogy(a, b, c, state["W_norm"], state["word2idx"], state["idx2word"])
    if not results:
        missing = [w for w in (a, b, c) if w not in state["word2idx"]]
        return f"Word(s) not in vocabulary: {', '.join(missing)}", None

    report_lines = [f"{a} is to {b} as {c} is to...", ""]
    report_lines.extend(f" {word}: {score:.4f}" for word, score in results)
    text = "\n".join(report_lines)

    neighbor_words, similarities = zip(*results)
    fig = go.Figure(go.Bar(x=list(similarities), y=list(neighbor_words),
                           orientation="h", marker_color="#4F46E5"))
    fig.update_layout(title=f"{a} : {b} :: {c} : ?", xaxis_title="Cosine similarity",
                      template="plotly_white", yaxis=dict(autorange="reversed"))
    return text, fig
# --- Tab 4: Nearest Neighbors ---
def find_neighbors(word):
    """List the most similar words to `word` by cosine similarity.

    Returns a (text listing, horizontal bar chart) pair, or an error
    message with a None figure when embeddings are missing or the word
    is out of vocabulary.
    """
    if state["W_norm"] is None:
        return "Train or load embeddings first!", None

    query = word.strip().lower()
    results = most_similar(query, state["W_norm"], state["word2idx"], state["idx2word"])
    if not results:
        return f"'{query}' not in vocabulary", None

    report = "\n".join(f"{w}: {score:.4f}" for w, score in results)
    neighbor_words, similarities = zip(*results)
    chart = go.Figure(go.Bar(x=list(similarities), y=list(neighbor_words),
                             orientation="h", marker_color="#4F46E5"))
    chart.update_layout(title=f"Nearest neighbors of '{query}'",
                        xaxis_title="Cosine similarity",
                        template="plotly_white", yaxis=dict(autorange="reversed"))
    return report, chart
# --- Build UI ---
# Module-level startup: try to populate `state` from pre-trained artifacts,
# then fetch the corpus description shown on the Train tab.
load_msg = load_pretrained()
corpus_msg = describe_text8_source()

with gr.Blocks(title="microembeddings", theme=gr.themes.Soft()) as demo:
    # Header: title, blog-post link, and pre-trained-vector load status.
    gr.Markdown(
        "# microembeddings\n"
        "*Word2Vec skip-gram from scratch — train, explore, and play with word vectors*\n\n"
        "Companion to the blog post: "
        "[microembeddings: Understanding Word Vectors from Scratch]"
        "(https://kshreyas.dev/post/microembeddings/)"
    )
    gr.Markdown(f"*{load_msg}*")
    gr.Markdown(
        "*Preloaded vectors use gensim Word2Vec on the full 17M-word text8 corpus.* "
        "*The Train tab reruns the NumPy implementation on a 500k-word subset so it stays interactive.*"
    )
    with gr.Tabs():
        # Tab 1: train from scratch — sliders map to run_training's params.
        with gr.Tab("Train"):
            gr.Markdown(
                "Train word embeddings from scratch on text8 (cleaned Wikipedia).\n\n"
                f"{corpus_msg}"
            )
            with gr.Row():
                dim_slider = gr.Slider(25, 100, value=50, step=25,
                                       label="Embedding dimension")
                win_slider = gr.Slider(1, 10, value=5, step=1, label="Window size")
            with gr.Row():
                lr_slider = gr.Slider(0.001, 0.05, value=0.025, step=0.001,
                                      label="Learning rate")
                neg_slider = gr.Slider(1, 15, value=5, step=1,
                                       label="Negative samples")
            train_btn = gr.Button("Train", variant="primary")
            train_status = gr.Textbox(label="Status", interactive=False)
            loss_plot = gr.Plot(label="Training Loss")
            train_btn.click(run_training,
                            [dim_slider, win_slider, lr_slider, neg_slider],
                            [loss_plot, train_status])
        # Tab 2: 2D projection of the embedding space.
        with gr.Tab("Explore"):
            gr.Markdown(
                "Visualize the embedding space in 2D. "
                "Similar words cluster together."
            )
            with gr.Row():
                method_radio = gr.Radio(["PCA", "t-SNE"], value="PCA",
                                        label="Projection method")
                num_words_slider = gr.Slider(100, 500, value=200, step=50,
                                             label="Number of words")
                cat_dropdown = gr.Dropdown(
                    ["None", "Countries", "Animals", "Numbers", "Colors"],
                    value="None", label="Highlight category"
                )
            explore_btn = gr.Button("Visualize", variant="primary")
            explore_plot = gr.Plot(label="Embedding Space")
            explore_btn.click(explore_embeddings,
                              [method_radio, num_words_slider, cat_dropdown],
                              explore_plot)
        # Tab 3: vector-arithmetic analogies (B - A + C).
        with gr.Tab("Analogies"):
            gr.Markdown(
                "Word vector arithmetic: **A is to B as C is to ?**\n\n"
                "Computed as: `B - A + C ≈ ?`"
            )
            with gr.Row():
                a_input = gr.Textbox(label="A", placeholder="man", value="man")
                b_input = gr.Textbox(label="B", placeholder="king", value="king")
                c_input = gr.Textbox(label="C", placeholder="woman", value="woman")
            analogy_btn = gr.Button("Solve", variant="primary")
            gr.Examples(
                [["man", "king", "woman"], ["france", "paris", "germany"],
                 ["big", "bigger", "small"]],
                inputs=[a_input, b_input, c_input]
            )
            analogy_text = gr.Textbox(label="Results", interactive=False, lines=6)
            analogy_plot = gr.Plot(label="Similarity")
            analogy_btn.click(solve_analogy, [a_input, b_input, c_input],
                              [analogy_text, analogy_plot])
        # Tab 4: cosine-similarity nearest-neighbor lookup.
        with gr.Tab("Nearest Neighbors"):
            gr.Markdown("Find the most similar words by cosine similarity.")
            word_input = gr.Textbox(label="Enter a word", placeholder="king")
            nn_btn = gr.Button("Search", variant="primary")
            nn_text = gr.Textbox(label="Results", interactive=False, lines=10)
            nn_plot = gr.Plot(label="Similarity")
            nn_btn.click(find_neighbors, word_input, [nn_text, nn_plot])

demo.launch()