import torch.nn as nn
import torch
from transformers import AutoTokenizer
import networkx as nx
import plotly.graph_objects as go
import random


def find_similar_embeddings(target_embedding, n=10):
    """
    Find the n most similar embeddings to the target embedding using
    cosine similarity against the model's full embedding matrix.

    Args:
        target_embedding: The embedding vector to compare against.
        n: Number of similar embeddings to return (default 10).

    Returns:
        List of (word, similarity_score) tuples sorted by similarity,
        highest first.
    """
    # Convert target to tensor if not already
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.tensor(target_embedding)

    # Get all embeddings from the model (vocab_size x embedding_dim)
    all_embeddings = model.embedding.weight

    # Cosine similarity between the target and every vocabulary embedding
    similarities = torch.nn.functional.cosine_similarity(
        target_embedding.unsqueeze(0), all_embeddings
    )

    # Top n most similar entries
    top_n_similarities, top_n_indices = torch.topk(similarities, n)

    # Decode each index back to its token string
    results = []
    for idx, score in zip(top_n_indices, top_n_similarities):
        word = tokenizer.decode(idx)
        results.append((word, score.item()))
    return results


def prompt_to_embeddings(prompt: str):
    """
    Tokenize *prompt* and look up the embedding vector for each token.

    Returns:
        (token_id_list, embeddings, token_str): the token ids, the model
        output (the model IS a single embedding lookup, so the output is
        the embedding tensor), and the decoded string of each token id.
    """
    # Tokenize the input text
    tokens = tokenizer(prompt, return_tensors="pt")
    input_ids = tokens["input_ids"]

    # Forward pass == embedding lookup for this model
    embeddings = model(input_ids)

    # Decode each token id to its string form for later display
    token_id_list = tokenizer.encode(prompt, add_special_tokens=True)
    token_str = [tokenizer.decode(t_id, skip_special_tokens=True) for t_id in token_id_list]
    return token_id_list, embeddings, token_str


class EmbeddingModel(nn.Module):
    """Minimal module wrapping a single nn.Embedding lookup table."""

    def __init__(self, vocab_size, embedding_dim):
        super(EmbeddingModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, input_ids):
        return self.embedding(input_ids)


# ---- Script: load the saved embedding matrix into the model ----
vocab_size = 151936
dimensions = 1536
embeddings_filename = r"python\code\files\embeddings_qwen.pth"
tokenizer_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Initialize the custom embedding model
model = EmbeddingModel(vocab_size, dimensions)

# Load the saved embeddings from the file.
# NOTE(review): torch.load unpickles arbitrary objects — only load files
# from a trusted source (or pass weights_only=True for tensor-only files).
saved_embeddings = torch.load(embeddings_filename)

# Ensure the 'weight' key exists in the saved embeddings dictionary
if 'weight' not in saved_embeddings:
    raise KeyError("The saved embeddings file does not contain 'weight' key.")
embeddings_tensor = saved_embeddings['weight']

# Check that the loaded matrix matches the model definition
if embeddings_tensor.size() != (vocab_size, dimensions):
    raise ValueError(f"The dimensions of the loaded embeddings do not match the model's expected dimensions ({vocab_size}, {dimensions}).")

# Assign the extracted embeddings tensor to the model's embedding layer
model.embedding.weight.data = embeddings_tensor

# Put the model in eval mode (inference only)
model.eval()

token_id_list, prompt_embeddings, prompt_token_str = prompt_to_embeddings(
    """We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely"""
)

# For every prompt token (index 0 is the special token the tokenizer
# prepends, so it is skipped), collect the nearest-neighbor tokens in
# embedding space, excluding the token itself (case/space-insensitive).
tokens_and_neighbors = {}
for i in range(1, len(prompt_embeddings[0])):
    token_results = find_similar_embeddings(prompt_embeddings[0][i], n=40)
    similar_embs = []
    for word, score in token_results:
        if word.strip().lower() != prompt_token_str[i].strip().lower():
            similar_embs.append(word)
    tokens_and_neighbors[prompt_token_str[i]] = similar_embs

# Cache an embedding vector for every token and neighbor; [0][1] picks the
# first real token after the prepended special token.
all_token_embeddings = {}
for token, neighbors in tokens_and_neighbors.items():
    # Get embedding for the original token
    token_id, token_emb, _ = prompt_to_embeddings(token)
    all_token_embeddings[token] = token_emb[0][1]
    # Get embeddings for each neighbor token
    for neighbor in neighbors:
        neighbor_id, neighbor_emb, _ = prompt_to_embeddings(neighbor)
        all_token_embeddings[neighbor] = neighbor_emb[0][1]
# Create the graph: one edge from each prompt token to each of its
# nearest-neighbor tokens.
G = nx.Graph()
for token, neighbors in tokens_and_neighbors.items():
    for neighbor in neighbors:
        G.add_edge(token, neighbor)

# Generate positions with a force-directed layout for an atlas-like spread.
# (nx.spring_layout(G, k=2) is an alternative that also works on Colab.)
pos = nx.forceatlas2_layout(G, max_iter=36)

# Visualization dimensions used to scale the layout coordinates
viz_width = 1500
viz_height = 500

# Edge coordinates, scaled; None breaks the polyline between segments.
edge_x, edge_y = [], []
for src, dst in G.edges():
    x0, y0 = pos[src]
    x1, y1 = pos[dst]
    edge_x.extend([x0 * viz_width, x1 * viz_width, None])
    edge_y.extend([y0 * viz_height, y1 * viz_height, None])

# Node coordinates (scaled) and degrees
node_x = [pos[node][0] * viz_width for node in G.nodes()]
node_y = [pos[node][1] * viz_height for node in G.nodes()]
node_degrees = dict(G.degree())

# Map every node to the index of its connected component once (O(V))
# instead of scanning all components per node; the index drives the
# node color on the colorscale.
component_of = {}
for comp_idx, component in enumerate(nx.connected_components(G)):
    for member in component:
        component_of[member] = comp_idx

node_component_indices = []  # colorscale value per node
node_opacities = []          # marker opacity per node
node_labels = []             # text label per node
hover_labels = []            # hover text per node
text_opacities = []          # label text opacity per node
for node in G.nodes():
    node_component_indices.append(component_of[node])
    if node in tokens_and_neighbors:
        # Main prompt token: prominent marker, fully visible label
        node_opacities.append(0.9)
        text_opacities.append(1.0)
    else:
        # Neighbor token: dimmer marker, label hidden (hover only)
        node_opacities.append(0.6)
        text_opacities.append(0.0)
    node_labels.append(node)
    hover_labels.append(node)

# Node size grows with degree
node_sizes = [degree + 5 for degree in node_degrees.values()]

# Node trace colored by connected component
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=node_labels,
    textposition="top center",
    textfont=dict(
        # Per-node label opacity via the alpha channel
        color=[f'rgba(0,0,0,{opacity})' for opacity in text_opacities]
    ),
    marker=dict(
        size=node_sizes,
        color=node_component_indices,
        colorscale='plasma',
        opacity=node_opacities,
        line_width=0.5,
    ),
    customdata=[[hover_labels[i], ' | '.join(G.neighbors(node))]
                for i, node in enumerate(G.nodes())],
    # <br> is Plotly's hover line break; a literal newline is not rendered
    hovertemplate="%{customdata[0]}<br>Similar tokens: %{customdata[1]}",
    hoverlabel=dict(namelength=0),
)

# Edge trace drawn as plain grey lines with no hover
edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color='grey'),
    hoverinfo='none',
    mode='lines',
)

# Assemble the figure: no axes, equal aspect ratio, white background
fig = go.Figure(data=[edge_trace, node_trace], layout=go.Layout(
    width=1200,
    height=400,
    paper_bgcolor='white',
    plot_bgcolor='white',
    showlegend=False,
    margin=dict(l=0, r=0, t=0, b=0),
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
               scaleanchor="x", scaleratio=1),
))
fig.show()

# Export an embeddable HTML fragment (the page supplies plotly.js)
fig.write_html(
    r"src\fragments\token_visualization.html",
    include_plotlyjs=False,
    full_html=False,
    config={
        'displayModeBar': False,
        'responsive': True,
        'scrollZoom': False,
    },
)