# klinic / graph.py — author: acmc
# Commit 93e1b64 ("First commit in HuggingFace"), 8.11 kB
# %%
import rdflib
import pandas as pd
def get_graph():
    """Build an RDF graph of MedGen concept labels and concept-to-concept relations.

    Reads ``MGCONSO.RRF`` and ``MGREL.RRF`` from the current working directory.
    Both are pipe-delimited with a header row. The header's first column name
    carries a leading ``#``, and every line ends with ``|``, which produces an
    empty trailing column that is dropped.

    * Concepts: for each row with ``SUPPRESS != "Y"``, ``ISPREF == "Y"``,
      ``STT == "PF"`` and ``TS == "P"``, adds
      ``(medgen:<CUI>, rdfs:label, "<STR>")``.
    * Relations: for each row with ``SUPPRESS != "Y"``, adds
      ``(medgen:<CUI1>, medgen:<REL>, medgen:<CUI2>)``. Rows with
      ``REL == "RL"`` use the bare (relative) predicate ``related`` instead.

    Returns:
        rdflib.Graph: the populated graph, with the prefix ``medgen`` bound to
        ``http://identifiers.org/medgen/``.
    """
    import csv

    base = "http://identifiers.org/medgen/"
    # RRF files use no CSV quoting. Under pandas' default quoting, a field that
    # starts with '"' is treated as quoted: the quotes are stripped, and an
    # unbalanced quote can swallow the following rows. Reading every field
    # verbatim as text also stops names such as "NA" from becoming NaN.
    rrf_options = {
        "sep": "|",
        "header": 0,
        "quoting": csv.QUOTE_NONE,
        "dtype": str,
        "keep_default_na": False,
    }

    df_concepts = pd.read_csv("MGCONSO.RRF", **rrf_options)
    # Strip the '#' from the first header name and drop the empty trailing column.
    df_concepts = df_concepts.rename(columns={"#CUI": "CUI"}).iloc[:, :-1]
    print(df_concepts.head())

    g = rdflib.Graph()
    g.bind("medgen", base)

    # Vectorised filter instead of DataFrame.iterrows(), which is slow on a
    # dump of this size.
    preferred = df_concepts[
        (df_concepts.SUPPRESS != "Y")
        & (df_concepts.ISPREF == "Y")
        & (df_concepts.STT == "PF")
        & (df_concepts.TS == "P")
    ]
    for cui, label in zip(preferred.CUI, preferred.STR):
        g.add((rdflib.URIRef(f"{base}{cui}"), rdflib.RDFS.label, rdflib.Literal(label)))

    df_relations = pd.read_csv("MGREL.RRF", **rrf_options)
    df_relations = df_relations.rename(columns={"#CUI1": "CUI1"}).iloc[:, :-1]
    print(df_relations.head())

    # "RL" rows share the relative predicate "related", which is also used by
    # apply_rules_to_graph and as the PyKEEN relation label.
    related = rdflib.URIRef("related")
    active = df_relations[df_relations.SUPPRESS != "Y"]
    for cui1, rel, cui2 in zip(active.CUI1, active.REL, active.CUI2):
        predicate = related if rel == "RL" else rdflib.URIRef(f"{base}{rel}")
        g.add((rdflib.URIRef(f"{base}{cui1}"), predicate, rdflib.URIRef(f"{base}{cui2}")))
    return g
def apply_rules_to_graph(g):
    """Link siblings in the ``medgen:RN`` hierarchy with symmetric ``related`` edges.

    Rule: if ``parent medgen:RN child1`` and ``parent medgen:RN child2`` hold
    with ``child1 != child2``, then add both ``child1 related child2`` and
    ``child2 related child1``. ``related`` is the same relative URI that
    ``get_graph`` uses for MedGen "RL" relations.

    Args:
        g: rdflib graph; it is modified in place.

    Returns:
        The same graph object ``g``, for chaining.
    """
    # Project only the children: the parent is irrelevant to the added edges,
    # and projecting it would defeat DISTINCT for pairs that share several parents.
    query = """
    PREFIX medgen: <http://identifiers.org/medgen/>
    SELECT DISTINCT ?child1 ?child2 WHERE {
        ?parent medgen:RN ?child1 .
        ?parent medgen:RN ?child2 .
        FILTER (?child1 != ?child2)
    }
    """
    related = rdflib.URIRef("related")
    # Materialise the result before mutating the graph. rdflib may evaluate
    # SELECT results lazily, so adding triples while the query is still being
    # iterated could modify the store under the running evaluation.
    pairs = [(row.child1, row.child2) for row in g.query(query)]
    for child1, child2 in pairs:
        # Both orderings of each pair come back from the query anyway. Adding
        # both directions is harmless because an rdflib graph is a set of triples.
        g.add((child1, related, child2))
        g.add((child2, related, child1))
    return g
def get_labels_of_entities():
    """Return a dict mapping MedGen concept URIs to their preferred names.

    Reads ``MGCONSO.RRF`` from the current working directory and keeps rows
    with ``SUPPRESS != "Y"``, ``ISPREF == "Y"``, ``STT == "PF"`` and
    ``TS == "P"``. Keys are ``http://identifiers.org/medgen/<CUI>``; values are
    the ``STR`` column. If several qualifying rows share a CUI, the last one wins.
    """
    import csv

    # RRF files use no CSV quoting. Read every field verbatim as text, so that
    # names starting with '"' keep their quotes and names such as "NA" stay strings.
    df_concepts = pd.read_csv(
        "MGCONSO.RRF",
        sep="|",
        header=0,
        quoting=csv.QUOTE_NONE,
        dtype=str,
        keep_default_na=False,
    )
    # Strip the '#' from the first header name and drop the empty trailing
    # column produced by the line-final '|'.
    df_concepts = df_concepts.rename(columns={"#CUI": "CUI"}).iloc[:, :-1]
    preferred = df_concepts[
        (df_concepts.SUPPRESS != "Y")
        & (df_concepts.ISPREF == "Y")
        & (df_concepts.STT == "PF")
        & (df_concepts.TS == "P")
    ]
    return {
        f"http://identifiers.org/medgen/{cui}": label
        for cui, label in zip(preferred.CUI, preferred.STR)
    }
def generate_triples_file(graph: rdflib.Graph, *, path="triples_medgen.tsv", predicates=None):
    """Write selected triples of ``graph`` to a tab-separated file.

    Each output line is ``subject<TAB>predicate<TAB>object``. Triples are
    grouped by predicate, in the order the predicates are given.

    Args:
        graph: rdflib graph to export.
        path: output file path. The default is the file later loaded for PyKEEN.
        predicates: iterable of predicate URIs (strings or ``URIRef``) to export.
            Defaults to ``related`` followed by ``medgen:RN``, ``medgen:RB``,
            ``medgen:PAR`` and ``medgen:CHD``.
    """
    if predicates is None:
        base = "http://identifiers.org/medgen/"
        # "related" is the relative URI used for RL relations and sibling links.
        predicates = [rdflib.URIRef("related")] + [
            rdflib.URIRef(f"{base}{rel}") for rel in ("RN", "RB", "PAR", "CHD")
        ]
    with open(path, "w", encoding="utf-8") as f:
        for predicate in predicates:
            f.writelines(
                f"{s}\t{p}\t{o}\n"
                for s, p, o in graph.triples((None, rdflib.URIRef(predicate), None))
            )
def save_adjacency_matrix(*, triples_path="triples_medgen.tsv", output_path="adjacency_matrix.mat"):
    """Write a 0/1 subject-by-object adjacency matrix built from a triples TSV.

    Rows are the distinct subjects (column 0) and columns are the distinct
    objects (column 2), each in order of first appearance. A cell is 1 when at
    least one triple, with any predicate, links that subject to that object;
    otherwise it is 0. The matrix is dense, so memory grows with
    subjects x objects.

    Args:
        triples_path: headerless tab-separated ``s p o`` file, as written by
            ``generate_triples_file``.
        output_path: destination, written tab-separated with row and column labels.
    """
    import numpy as np

    df = pd.read_csv(triples_path, sep="\t", header=None)
    subjects = df[0].unique()
    objects = df[2].unique()
    # Map every triple to its (row, column) position in one vectorised pass,
    # instead of one .loc assignment per row.
    row_idx = pd.Index(subjects).get_indexer(df[0])
    col_idx = pd.Index(objects).get_indexer(df[2])
    matrix = np.zeros((len(subjects), len(objects)), dtype=np.int64)
    matrix[row_idx, col_idx] = 1
    adj_matrix = pd.DataFrame(matrix, index=subjects, columns=objects)
    adj_matrix.to_csv(output_path, sep="\t")
# %%
# Build the MedGen graph from MGCONSO.RRF / MGREL.RRF in the working directory.
g = get_graph()
# %%
# Add symmetric "related" edges between nodes sharing a medgen:RN parent
# (apply_rules_to_graph mutates g in place and returns it).
g = apply_rules_to_graph(g)
# %%
# Concept URI -> preferred label; used when exporting entity embeddings below.
labels_of_entities = get_labels_of_entities()
# %%
# Write the related/RN/RB/PAR/CHD triples to triples_medgen.tsv, which the
# PyKEEN cell loads.
generate_triples_file(g)
# %%
# Train a knowledge-graph embedding model on the exported MedGen triples.
# NOTE(review): TuckER and TransH are imported but unused in this cell —
# presumably kept so `model=` can be swapped.
from pykeen.triples import TriplesFactory
from pykeen.models import TuckER, TransE, TransH
from pykeen.pipeline import pipeline
# Load the tab-separated subject/predicate/object file written by generate_triples_file.
tf = TriplesFactory.from_path("triples_medgen.tsv")
print(f"Triples count: {tf.num_triples}")
# 80/10/10 split with a fixed seed; the three factories are unpacked as
# training, testing, validation in that order.
training, testing, validation = tf.split([0.8, 0.1, 0.1], random_state=42, randomize_cleanup=False)
result = pipeline(
    training=training,
    testing=testing,
    validation=validation,
    model=TransE,
    stopper="early",
    epochs=500,  # upper bound on training epochs; the early stopper
    # (stopper="early") can end training before this
)
# NOTE(review): the directory name mentions "complex", but the model trained
# above is TransE — confirm the intended output location.
result.save_to_directory("doctests/test_unstratified_stopped_complex")
# %%
# Rank candidate tail entities for (alzheimers, related, ?) with the trained
# model and print the top hits together with their rdfs:label values from g.
import torch

alzheimers = "http://identifiers.org/medgen/C1843013"
# What does the model predict for Alzheimer's disease?
model = result.model
alzheimers_id = tf.entity_to_id[alzheimers]
relation_id = tf.relation_to_id["related"]
# Build the (head, relation) batch on the model's device, so this also works
# when the model lives on a GPU.
batch_to_predict = torch.tensor([[alzheimers_id, relation_id]], device=model.device)
with torch.no_grad():
    alzheimers_pred = model.predict_t(hr_batch=batch_to_predict)
print(alzheimers_pred.shape)
# Top-k predictions (k capped at the number of candidate entities, since
# topk fails if k exceeds the size of the last dimension).
top10 = torch.topk(alzheimers_pred, min(10, alzheimers_pred.shape[-1]), largest=True)
entities = tf.entity_id_to_label
print(top10.indices)
scores = top10.values[0].tolist()
indices = top10.indices[0].tolist()
for rank, (score, entity_idx) in enumerate(zip(scores, indices), start=1):
    uri = entities[entity_idx]
    # Query the graph directly instead of building SPARQL from an f-string.
    labels = [str(label) for label in g.objects(rdflib.URIRef(uri), rdflib.RDFS.label)]
    print(f"{rank}: {uri} (score={score:.4f}) {labels}")
# %%
# Export the trained entity and relation embeddings to CSV files.
from pykeen.nn.representation import Embedding

# Look embeddings up in explicit id order (row k <-> id k) on the model's own
# device. This replaces the hard-coded .cuda() (which fails on CPU-only
# machines) and no longer relies on the iteration order of the dict values.
device = model.device
entity_ids = torch.arange(tf.num_entities, device=device)
relation_ids = torch.arange(tf.num_relations, device=device)
with torch.no_grad():
    entity_embeddings: torch.Tensor = model.entity_representations[0]._embeddings(entity_ids)
    relation_embeddings: torch.Tensor = model.relation_representations[0]._embeddings(
        relation_ids
    )
print(f"Entity embeddings shape: {entity_embeddings.shape}")
print(f"Relation embeddings shape: {relation_embeddings.shape}")

# One row per entity id: embedding vector, preferred label ("" if unknown), URI.
entity_uris = [tf.entity_id_to_label[i] for i in range(len(entity_embeddings))]
df = pd.DataFrame(
    {
        "embedding": entity_embeddings.detach().cpu().tolist(),
        "label": [labels_of_entities.get(uri, "") for uri in entity_uris],
        "uri": entity_uris,
    },
    index=range(len(entity_embeddings)),
)
## Save the DataFrame
df.to_csv("entity_embeddings.csv")

# One row per relation id; relation labels double as their URIs.
relation_labels = [tf.relation_id_to_label[i] for i in range(len(relation_embeddings))]
df = pd.DataFrame(
    {
        "embedding": relation_embeddings.detach().cpu().tolist(),
        "label": relation_labels,
        "uri": list(relation_labels),
    },
    index=range(len(relation_embeddings)),
)
## Save the DataFrame
df.to_csv("relation_embeddings.csv")
# %%
# Ad-hoc lookup of a name through pyobo. The return value is not assigned or
# printed, so it is only visible when this is run as an interactive `# %%` cell.
# NOTE(review): MeSH identifiers usually carry a letter prefix (e.g. "D...");
# confirm that "16793" is the intended identifier.
import pyobo
pyobo.get_name("mesh", "16793")
# %%