Spaces:
Sleeping
Sleeping
# %% | |
def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings):
    """Return the TransE residual vector ``e[head] + r[relation] - e[tail]``.

    Despite the name, the result is the residual *vector*, not a scalar;
    callers take ``.norm()`` of it to obtain the distance score that they
    sort or report.

    Args:
        head: key of the head entity in ``entity_embeddings``.
        tail: key of the tail entity in ``entity_embeddings``.
        relation: key of the relation in ``relation_embeddings``.
        entity_embeddings: indexable collection of entity vectors (e.g. a
            pandas Series of tensors, as built by the loading cell).
        relation_embeddings: indexable collection of relation vectors.

    Returns:
        The element-wise ``head + relation - tail`` vector.
    """
    # Bind the selected row to its own name instead of rebinding the
    # ``relation_embeddings`` parameter (which holds the whole table) to a
    # single vector.
    relation_embedding = relation_embeddings[relation]
    return entity_embeddings[head] + relation_embedding - entity_embeddings[tail]
def calculate_similar_nodes(
    node, entity_embeddings, relation_embeddings, top_n=10, relation=0
):
    """Rank every entity as a candidate tail for ``node`` under one relation.

    Each candidate ``i`` is scored by the norm of
    ``transe_distance(node, i, relation, ...)``. Candidates are returned in
    ascending order of that norm (ties keep iteration order, since the sort is
    stable).

    Args:
        node: key of the head entity in ``entity_embeddings``.
        entity_embeddings: indexable collection of entity vectors. If it
            exposes ``keys()`` (e.g. a pandas Series or a dict), those keys are
            the candidate ids; otherwise positions ``0..len-1`` are used.
        relation_embeddings: indexable collection of relation vectors.
        top_n: number of candidates to return (passed straight to a slice, so
            ``None`` returns all of them).
        relation: key of the relation to translate by. Defaults to ``0``, the
            value that was previously hard-coded.

    Returns:
        list of ``(entity_id, residual_vector)`` tuples; call ``.norm()`` on
        the vector to get the score.
    """
    # Iterate the collection's own keys when it has them, so the ids handed to
    # transe_distance (and returned to the caller) are labels, matching how
    # ``node`` itself is looked up. A Series with a non 0..n-1 index would
    # otherwise raise KeyError or silently use positions.
    if hasattr(entity_embeddings, "keys"):
        candidates = entity_embeddings.keys()
    else:
        candidates = range(len(entity_embeddings))

    scored = []
    for candidate in candidates:
        residual = transe_distance(
            node, candidate, relation, entity_embeddings, relation_embeddings
        )
        scored.append((candidate, residual))

    scored.sort(key=lambda pair: pair[1].norm().item())
    return scored[:top_n]
# %% | |
import ast

import pandas as pd
import torch


def _parse_embedding(text):
    """Convert one serialized "embedding" cell into a tensor.

    Uses ``ast.literal_eval`` rather than ``eval``: it accepts only Python
    literals, so a tampered CSV cannot execute arbitrary code. This assumes
    the cell holds a plain list literal such as "[0.1, -0.2]" (the form
    ``torch.tensor(eval(...))`` previously consumed); TODO confirm against
    the file that wrote the CSV.
    """
    return torch.tensor(ast.literal_eval(text))


# Load the embeddings from the CSV files.
# The "embedding" column is stored as a string; convert each cell to a tensor.
entity_embeddings = pd.read_csv("entity_embeddings.csv", index_col=0)
entity_embeddings["embedding"] = entity_embeddings["embedding"].apply(
    _parse_embedding
)
# A bare expression mid-cell is not rendered by the notebook, so display it
# explicitly (as is done for the relation table below).
display(entity_embeddings.head())

# Now, load the relation embeddings the same way.
relation_embeddings = pd.read_csv("relation_embeddings.csv", index_col=0)
relation_embeddings["embedding"] = relation_embeddings["embedding"].apply(
    _parse_embedding
)
display(relation_embeddings.head())
# %% | |
def _entity_index_for_uri(uri):
    """Return the index label of the first ``entity_embeddings`` row whose "uri" equals ``uri``.

    Raises:
        KeyError: if no row has that URI. The plain ``.index[0]`` lookup would
            otherwise fail with an IndexError that does not name the URI.
    """
    matches = entity_embeddings[entity_embeddings["uri"] == uri].index
    if len(matches) == 0:
        raise KeyError(f"No entity with uri {uri!r} in entity_embeddings")
    return matches[0]


# Find the index of the entity with the uri "http://identifiers.org/medgen/C0002395"
head = _entity_index_for_uri("http://identifiers.org/medgen/C0002395")
# Find the index of the entity with the uri "http://identifiers.org/medgen/C1843013"
tail = _entity_index_for_uri("http://identifiers.org/medgen/C1843013")
relation = 0

# Residual vector for (head, relation, tail); its norm is the reported distance.
distance = transe_distance(
    head,
    tail,
    relation,
    entity_embeddings["embedding"],
    relation_embeddings["embedding"],
)
print(
    f'Distance between {entity_embeddings["label"][head]} ({head}) and {entity_embeddings["label"][tail]} ({tail}) via relation {relation_embeddings["label"][relation]} is {distance.norm().item()}'
)
# %% | |
# Calculate similar nodes to the head: every entity ranked by the norm of
# head + relation 0 - entity (see calculate_similar_nodes).
similar_nodes = calculate_similar_nodes(
    head, entity_embeddings["embedding"], relation_embeddings["embedding"]
)
print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
# Print the similar nodes. Use dedicated loop names so the head->tail
# ``distance`` computed in the previous cell is not overwritten by the loop.
for rank, (node, residual) in enumerate(similar_nodes):
    print(
        f"{rank}: {entity_embeddings['label'][node]} ({node}) "
        f"with distance {residual.norm().item()}"
    )
# %% |