Spaces:
Sleeping
Sleeping
| # %% | |
def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings):
    """Return the translation residual ``head + relation - tail``.

    ``head`` and ``tail`` are used to index ``entity_embeddings`` and
    ``relation`` is used to index ``relation_embeddings``. The result is
    whatever ``+`` and ``-`` yield for those three values; it is not reduced
    to a scalar here (callers in this file take ``.norm()`` of it).
    """
    source_vec = entity_embeddings[head]
    target_vec = entity_embeddings[tail]
    # Kept in a separate local so the collection passed in is not rebound.
    translation_vec = relation_embeddings[relation]
    return source_vec + translation_vec - target_vec
def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=10, relation=0):
    """Rank candidate tail entities by the size of their residual from ``node``.

    Args:
        node: Key of the head entity in ``entity_embeddings``.
        entity_embeddings: Indexable collection of entity vectors. Candidates
            are addressed as ``0 .. len(entity_embeddings) - 1``.
        relation_embeddings: Indexable collection of relation vectors.
        top_n: Upper bound of the slice applied to the ranked list.
        relation: Key of the relation used for the translation. Defaults to
            ``0``, the value that used to be hard-coded.

    Returns:
        A list of ``(candidate, residual)`` tuples sorted by ascending
        ``residual.norm().item()`` and cut with ``[:top_n]``. ``node`` itself
        is not excluded from the candidates.
    """
    # NOTE(review): candidates are looked up with the integers 0..n-1. When a
    # pandas Series is passed this presumably relies on its index being those
    # same integers -- confirm against the CSV.
    scored = [
        (
            candidate,
            transe_distance(
                node, candidate, relation, entity_embeddings, relation_embeddings
            ),
        )
        for candidate in range(len(entity_embeddings))
    ]
    scored.sort(key=lambda pair: pair[1].norm().item())
    return scored[:top_n]
| # %% | |
import ast

import pandas as pd
import torch


def load_embeddings(csv_path):
    """Read an embeddings CSV and convert its ``embedding`` column to tensors.

    The first CSV column becomes the frame index. Each ``embedding`` cell is
    parsed as a Python literal (presumably a list of floats -- confirm against
    whatever wrote the CSV) and wrapped in ``torch.tensor``.
    """
    frame = pd.read_csv(csv_path, index_col=0)
    # ast.literal_eval accepts literals only, so a tampered CSV cannot run
    # arbitrary code the way eval() on the cell text could.
    frame["embedding"] = frame["embedding"].apply(
        lambda text: torch.tensor(ast.literal_eval(text))
    )
    return frame


# Load the entity embeddings from the CSV file
entity_embeddings = load_embeddings("entity_embeddings.csv")
# display() is the notebook-provided helper already used for the relation
# preview below; a bare expression in the middle of a cell shows nothing.
display(entity_embeddings.head())

# Now, load the relation embeddings
relation_embeddings = load_embeddings("relation_embeddings.csv")
display(relation_embeddings.head())
| # %% | |
# Look up both entities by URI and measure the residual between them.
uri_column = entity_embeddings["uri"]

# Index label of the first row whose uri is "http://identifiers.org/medgen/C0002395"
head = entity_embeddings.loc[
    uri_column == "http://identifiers.org/medgen/C0002395"
].index[0]

# Index label of the first row whose uri is "http://identifiers.org/medgen/C1843013"
tail = entity_embeddings.loc[
    uri_column == "http://identifiers.org/medgen/C1843013"
].index[0]

relation = 0

distance = transe_distance(
    head, tail, relation,
    entity_embeddings["embedding"], relation_embeddings["embedding"],
)

# Report the labels, the index labels and the norm of the residual vector.
print(
    f'Distance between {entity_embeddings["label"][head]} ({head}) and {entity_embeddings["label"][tail]} ({tail}) via relation {relation_embeddings["label"][relation]} is {distance.norm().item()}'
)
| # %% | |
# Rank the entities closest to `head` (default top_n) and report them.
similar_nodes = calculate_similar_nodes(
    head,
    entity_embeddings["embedding"],
    relation_embeddings["embedding"],
)

print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")

# One line per hit: position in the ranking, label, entity key, residual norm.
for i, hit in enumerate(similar_nodes):
    node, distance = hit
    print(f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {distance.norm().item()}")
| # %% |