File size: 2,474 Bytes
93e1b64
 
 
 
 
 
 
 
a6bd112
93e1b64
 
 
 
 
 
 
 
a6bd112
93e1b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6bd112
 
 
93e1b64
 
 
a6bd112
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# %%
def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings):
    """Return the TransE residual ``head + relation - tail`` for one triple.

    Despite the name, the result is the element-wise difference itself, not a
    scalar; the callers in this file reduce it with ``.norm()``.

    Args:
        head: Key of the head entity in ``entity_embeddings``.
        tail: Key of the tail entity in ``entity_embeddings``.
        relation: Key of the relation in ``relation_embeddings``.
        entity_embeddings: Indexable container of entity vectors.
        relation_embeddings: Indexable container of relation vectors.

    Returns:
        ``entity_embeddings[head] + relation_embeddings[relation]
        - entity_embeddings[tail]``.
    """
    # Look the three vectors up first (head, tail, relation), then combine.
    h_vec, t_vec = entity_embeddings[head], entity_embeddings[tail]
    r_vec = relation_embeddings[relation]
    return (h_vec + r_vec) - t_vec


def calculate_similar_nodes(
    node, entity_embeddings, relation_embeddings, top_n=10, relation=0
):
    """Rank all entities by the norm of their TransE residual from ``node``.

    Every position ``0 .. len(entity_embeddings) - 1`` is used as a candidate
    tail, so ``entity_embeddings`` is assumed to be indexable by those
    integers (TODO confirm for tables whose index is not 0-based). The query
    node itself is not excluded from the candidates.

    Args:
        node: Key of the head entity in ``entity_embeddings``.
        entity_embeddings: Indexable container of entity vectors.
        relation_embeddings: Indexable container of relation vectors.
        top_n: Number of best-ranked candidates to return.
        relation: Key of the relation used for the translation. Defaults to
            0, which was previously the hard-coded value.

    Returns:
        A list of at most ``top_n`` ``(candidate_index, residual)`` tuples in
        ascending order of ``residual.norm()``. The residual is the raw
        vector returned by ``transe_distance``, not a scalar.
    """
    distances = [
        (
            candidate,
            transe_distance(
                node, candidate, relation, entity_embeddings, relation_embeddings
            ),
        )
        for candidate in range(len(entity_embeddings))
    ]
    # Sort on the scalar norm; the tuples keep the full residual for callers.
    distances.sort(key=lambda pair: pair[1].norm().item())
    return distances[:top_n]


# %%
import pandas as pd
import torch


def _parse_embedding(serialized):
    """Turn one serialized embedding cell from the CSV into a ``torch.Tensor``.

    NOTE(security): ``eval`` executes whatever Python expression the cell
    contains, so the CSV files must come from a trusted source. If the cells
    are plain list literals (presumably the case -- confirm against the code
    that wrote the files), ``ast.literal_eval`` would be a safer replacement.
    """
    return torch.tensor(eval(serialized))


# Load the entity table; the first CSV column is used as the row index.
entity_embeddings = pd.read_csv("entity_embeddings.csv", index_col=0)
# The embedding column is stored as text, so convert each cell to a tensor.
entity_embeddings["embedding"] = entity_embeddings["embedding"].apply(
    _parse_embedding
)
# A bare expression in the middle of a cell is discarded, so display explicitly.
display(entity_embeddings.head())

# Now, load the relation embeddings the same way.
relation_embeddings = pd.read_csv("relation_embeddings.csv", index_col=0)
relation_embeddings["embedding"] = relation_embeddings["embedding"].apply(
    _parse_embedding
)
display(relation_embeddings.head())
# %%
# Resolve two entity URIs to their index labels in the entity table, then
# report the norm of the TransE residual between them for relation 0.
uri_column = entity_embeddings["uri"]
# First row whose uri is "http://identifiers.org/medgen/C0002395"
head = entity_embeddings.loc[
    uri_column == "http://identifiers.org/medgen/C0002395"
].index[0]
# First row whose uri is "http://identifiers.org/medgen/C1843013"
tail = entity_embeddings.loc[
    uri_column == "http://identifiers.org/medgen/C1843013"
].index[0]
relation = 0
# transe_distance returns the residual vector; its norm is taken when printing.
distance = transe_distance(
    head=head,
    tail=tail,
    relation=relation,
    entity_embeddings=entity_embeddings["embedding"],
    relation_embeddings=relation_embeddings["embedding"],
)
print(
    f'Distance between {entity_embeddings["label"][head]} ({head}) and {entity_embeddings["label"][tail]} ({tail}) via relation {relation_embeddings["label"][relation]} is {distance.norm().item()}'
)
# %%
# Rank every entity against the head node (default top_n of the helper).
similar_nodes = calculate_similar_nodes(
    node=head,
    entity_embeddings=entity_embeddings["embedding"],
    relation_embeddings=relation_embeddings["embedding"],
)
print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
# One line per result: rank, label, entity index and norm of the residual.
# Each item is an (entity index, residual vector) pair.
for i, (node, distance) in enumerate(similar_nodes):
    print(
        f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {distance.norm().item()}"
    )
# %%