# %%
def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings):
    """Return the TransE translation residual ``head + relation - tail``.

    Args:
        head: Key of the head entity in ``entity_embeddings``.
        tail: Key of the tail entity in ``entity_embeddings``.
        relation: Key of the relation in ``relation_embeddings``.
        entity_embeddings: Indexable container (in this script, a pandas
            Series) mapping entity keys to embedding tensors.
        relation_embeddings: Indexable container mapping relation keys to
            embedding tensors.

    Returns:
        The residual vector ``head_embedding + relation_embedding -
        tail_embedding``. Callers rank triples by its norm, smallest first.
    """
    head_embedding = entity_embeddings[head]
    tail_embedding = entity_embeddings[tail]
    # Use a separate local instead of rebinding the ``relation_embeddings``
    # parameter to a single vector.
    relation_embedding = relation_embeddings[relation]
    return head_embedding + relation_embedding - tail_embedding


def calculate_similar_nodes(
    node, entity_embeddings, relation_embeddings, top_n=10, relation=0
):
    """Rank every entity as a candidate tail of ``(node, relation, ?)``.

    Args:
        node: Key of the head entity in ``entity_embeddings``.
        entity_embeddings: Container mapping entity keys to embedding
            tensors. For a mapping-like container (anything with ``keys()``,
            such as a pandas Series or dict), its real keys are used.
            Otherwise positions ``0..len-1`` are used.
        relation_embeddings: Container mapping relation keys to embedding
            tensors.
        top_n: Maximum number of results to return.
        relation: Key of the relation to translate by (defaults to ``0``,
            which the original hard-coded).

    Returns:
        Up to ``top_n`` ``(entity_key, residual)`` pairs, ordered by ascending
        ``residual.norm()``. ``residual`` is the vector returned by
        ``transe_distance``, so callers can still call ``.norm()`` on it.
    """
    # Iterate the container's actual keys. ``range(len(...))`` silently
    # assumed a 0..n-1 index, which does not hold for a Series whose index
    # was read from the CSV (and ``node`` itself is an index label).
    if hasattr(entity_embeddings, "keys"):
        candidates = entity_embeddings.keys()
    else:
        candidates = range(len(entity_embeddings))

    scored = [
        (
            candidate,
            transe_distance(
                node, candidate, relation, entity_embeddings, relation_embeddings
            ),
        )
        for candidate in candidates
    ]
    # list.sort is stable, so candidates with equal norms keep iteration order.
    scored.sort(key=lambda pair: pair[1].norm().item())
    return scored[:top_n]


# %%
import ast

import pandas as pd
import torch


def _parse_embedding(text):
    """Parse a stringified list literal from the CSV into a tensor.

    ``ast.literal_eval`` accepts only Python literals, so unlike ``eval`` it
    cannot execute arbitrary code embedded in the CSV file.
    (Presumably each cell holds a flat list of floats; confirm against the
    exporter that wrote the CSV.)
    """
    return torch.tensor(ast.literal_eval(text))


# Load the embeddings from the CSV files. The "embedding" column is stored as
# a string, so convert each cell back to a tensor.
entity_embeddings = pd.read_csv("entity_embeddings.csv", index_col=0)
entity_embeddings["embedding"] = entity_embeddings["embedding"].apply(
    _parse_embedding
)
# A bare ``.head()`` that is not the last expression of a cell is never shown,
# so display it explicitly.
display(entity_embeddings.head())

# Now, load the relation embeddings
relation_embeddings = pd.read_csv("relation_embeddings.csv", index_col=0)
relation_embeddings["embedding"] = relation_embeddings["embedding"].apply(
    _parse_embedding
)
display(relation_embeddings.head())


# %%
def _index_of_uri(frame, uri):
    """Return the index label of the first row of ``frame`` whose "uri" equals ``uri``.

    Raises:
        KeyError: if no row has that URI. A bare ``.index[0]`` would instead
            raise an IndexError that does not say which URI was missing.
    """
    matches = frame[frame["uri"] == uri].index
    if len(matches) == 0:
        raise KeyError(f"No entity with uri {uri!r}")
    return matches[0]


head = _index_of_uri(entity_embeddings, "http://identifiers.org/medgen/C0002395")
tail = _index_of_uri(entity_embeddings, "http://identifiers.org/medgen/C1843013")
relation = 0

distance = transe_distance(
    head,
    tail,
    relation,
    entity_embeddings["embedding"],
    relation_embeddings["embedding"],
)
print(
    f'Distance between {entity_embeddings["label"][head]} ({head}) and {entity_embeddings["label"][tail]} ({tail}) via relation {relation_embeddings["label"][relation]} is {distance.norm().item()}'
)

# %%
# Calculate similar nodes to the head: candidate tails under the same relation
# used above.
similar_nodes = calculate_similar_nodes(
    head,
    entity_embeddings["embedding"],
    relation_embeddings["embedding"],
    relation=relation,
)
print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")

# Print the similar nodes (loop variable renamed so it does not shadow the
# ``distance`` computed in the previous cell).
for i, (node, node_distance) in enumerate(similar_nodes):
    print(
        f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {node_distance.norm().item()}"
    )

# %%