Knowledge-graphs / rebel.py
khaerens's picture
add free text and code clean up
f3c6abd
raw
history blame
No virus
4.05 kB
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy
from spacy import displacy
# Fill colour for graph nodes, keyed by spaCy named-entity label.
# Labels that share a colour are grouped together (e.g. all numeric types).
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    # Current spaCy English models emit "FAC"; "FACILITY" is kept as an
    # alias so lookups with either spelling resolve to the same colour.
    "FAC": "#9cc9cc",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}
@lru_cache(maxsize=1)
def _load_ner_model():
    """Load the spaCy English pipeline once and reuse it across calls."""
    return spacy.load("en_core_web_sm")


def generate_knowledge_graph(texts: List[str], filename: str):
    """Build an interactive knowledge graph from *texts* and write it to *filename*.

    Each text is turned into (head, type, tail) triplets by
    ``generate_partial_graph``. Heads and tails become lowercase nodes,
    coloured by the named-entity type spaCy assigns to the same lowercase
    text when that type has an entry in ``DEFAULT_LABEL_COLORS``. Each
    distinct (head, tail, type) combination, compared case-insensitively,
    becomes one directed edge labelled with the relation type.

    Args:
        texts: Input passages; each one is processed (and cached) separately.
        filename: Output path handed to ``Network.show``.

    Returns:
        The unique lowercase node names, in first-seen order.
    """
    nlp = _load_ner_model()
    # The text is lowercased before NER so entity strings can be matched
    # against the lowercased node names below.
    doc = nlp("\n".join(texts).lower())

    # Map entity text -> label, keeping the first label seen for a given text.
    ner_type_by_text = {}
    for ent in doc.ents:
        print(ent.text, ent.label_)
        ner_type_by_text.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
    print(generate_partial_graph.cache_info())

    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    nodes = list(dict.fromkeys(heads + tails))

    net = Network(directed=True, width="700px", height="700px")
    for node in nodes:
        ner_type = ner_type_by_text.get(node)
        # Entities whose label has no configured colour fall back to a plain
        # node instead of raising KeyError.
        color = DEFAULT_LABEL_COLORS.get(ner_type)
        if color is not None:
            net.add_node(node, title=ner_type, shape="circle", color=color)
        else:
            net.add_node(node, shape="circle")

    # A tuple key avoids the ambiguity of plain string concatenation, and
    # lowercasing every part matches how the nodes themselves are named.
    seen_edges = set()
    for triplet in triplets:
        head = triplet["head"].lower()
        tail = triplet["tail"].lower()
        edge_key = (head, tail, triplet["type"].lower())
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        net.add_edge(head, tail, title=triplet["type"], label=triplet["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes
@lru_cache(maxsize=1)
def _get_triplet_extractor():
    """Build the REBEL text2text pipeline once; later calls reuse it."""
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )


@lru_cache
def generate_partial_graph(text: str):
    """Extract relation triplets from a single passage with the REBEL model.

    Results are memoised per input string, so repeated passages are not
    sent through the model again. The returned list is the cached object
    itself: callers should copy it (or only read it) rather than mutate it.

    Args:
        text: The passage to extract relations from.

    Returns:
        A list of ``{'head', 'type', 'tail'}`` dicts as produced by
        ``extract_triplets``.
    """
    print(text[0:20], hash(text))
    # Reuse one pipeline instance instead of reloading the model on every
    # cache miss.
    triplet_extractor = _get_triplet_extractor()
    generated = triplet_extractor(text, return_tensors=True, return_text=False)
    token_ids = generated[0]["generated_token_ids"]["output_ids"]
    extracted_text = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(extracted_text[0])
def _append_if_complete(triplets, head_tokens, relation_tokens, tail_tokens):
    """Append a triplet to *triplets* only when head, relation and tail are all present."""
    if head_tokens and relation_tokens and tail_tokens:
        triplets.append({
            'head': ' '.join(head_tokens),
            'type': ' '.join(relation_tokens),
            'tail': ' '.join(tail_tokens),
        })


def extract_triplets(text):
    """Parse linearised model output into a list of relation triplets.

    The expected layout is
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]``,
    repeated for each head. ``<s>``, ``<pad>`` and ``</s>`` markers are
    stripped first, and tokens before the first marker are ignored.

    Args:
        text: Decoded model output containing the marker tokens above.

    Returns:
        A list of dicts with keys ``'head'``, ``'type'`` and ``'tail'``, in
        the order they appear. Incomplete groups (any of the three parts
        missing) are skipped.
    """
    triplets = []
    head_tokens, tail_tokens, relation_tokens = [], [], []
    # The list that plain tokens are currently appended to; None until the
    # first marker has been seen.
    target = None

    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            # A new head starts: emit the pending triplet and clear all parts
            # so nothing from the previous group leaks into the next one.
            _append_if_complete(triplets, head_tokens, relation_tokens, tail_tokens)
            head_tokens, tail_tokens, relation_tokens = [], [], []
            target = head_tokens
        elif token == "<subj>":
            # Another tail for the same head: emit the pending triplet first.
            _append_if_complete(triplets, head_tokens, relation_tokens, tail_tokens)
            tail_tokens = []
            target = tail_tokens
        elif token == "<obj>":
            relation_tokens = []
            target = relation_tokens
        elif target is not None:
            target.append(token)

    _append_if_complete(triplets, head_tokens, relation_tokens, tail_tokens)
    return triplets