from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
from app import generate_graph
import spacy
from spacy import displacy

# Node colour per spaCy entity label. spaCy's English pipelines emit "FAC"
# for facilities; "FACILITY" is kept so existing readers of this mapping
# still find it.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    "FAC": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}


@lru_cache(maxsize=None)
def _get_nlp():
    """Load the spaCy English pipeline once and reuse it across calls."""
    return spacy.load("en_core_web_sm")


@lru_cache(maxsize=None)
def _get_triplet_extractor():
    """Build the REBEL text2text pipeline once; reloading the model per text is very slow."""
    return pipeline(
        "text2text-generation",
        model="Babelscape/rebel-large",
        tokenizer="Babelscape/rebel-large",
    )


def _entity_labels(texts: List[str]) -> dict:
    """Run NER over all *texts* and map each entity string to its label.

    Every entity is printed (as before). When the same string is tagged more
    than once, the first label wins, matching the old ``list.index`` lookup.
    """
    doc = _get_nlp()("\n".join(texts))
    labels = {}
    for ent in doc.ents:
        print(ent.text, ent.label_)
        labels.setdefault(ent.text, ent.label_)
    return labels


def generate_knowledge_graph(texts: List[str], filename: str):
    """Extract relation triplets from *texts* and render them as an HTML graph.

    Args:
        texts: Input passages. Each one is sent to REBEL separately. All of
            them, newline-joined, go through spaCy NER to colour the nodes.
        filename: Output path passed to ``Network.show``.

    Returns:
        The set of node names (triplet heads and tails) drawn in the graph.
    """
    ner_labels = _entity_labels(texts)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
        print(generate_partial_graph.cache_info())

    nodes = {t["head"] for t in triplets} | {t["tail"] for t in triplets}

    net = Network(directed=True)
    for node in nodes:
        label = ner_labels.get(node)
        if label is None:
            net.add_node(node, shape="circle")
            continue
        color = DEFAULT_LABEL_COLORS.get(label)
        if color is None:
            # Label without a configured colour: keep the title, default colour.
            net.add_node(node, title=label, shape="circle")
        else:
            net.add_node(node, title=label, shape="circle", color=color)

    # De-duplicate on a tuple key. Plain string concatenation of the fields
    # let distinct triplets collide (e.g. "ab"+"c" == "a"+"bc").
    seen = set()
    for t in triplets:
        key = (t["tail"], t["head"], t["type"])
        if key in seen:
            continue
        seen.add(key)
        # NOTE(review): edges are drawn tail -> head, the reverse of the
        # head/tail naming used by extract_triplets; confirm this is intended.
        net.add_edge(t["tail"], t["head"], title=t["type"], label=t["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth("dynamic")
    net.show(filename)
    return nodes


@lru_cache
def generate_partial_graph(text):
    """Run REBEL on *text* and return the parsed triplets.

    Results are memoised per input string by ``lru_cache``. The returned list
    is shared between calls with the same text, so callers must not mutate it.
    """
    triplet_extractor = _get_triplet_extractor()
    output = triplet_extractor(text, return_tensors=True, return_text=False)
    token_ids = output[0]["generated_token_ids"]["output_ids"]
    extracted_text = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(extracted_text[0])


def extract_triplets(text):
    """Parse REBEL's linearised output into ``{'head', 'type', 'tail'}`` dicts.

    The parser expects ``<triplet> head <subj> tail <obj> relation`` sequences,
    with ``<s>``, ``<pad>`` and ``</s>`` markers stripped first. A further
    ``<subj> tail <obj> relation`` group after a relation reuses the previous
    head.
    """
    triplets = []
    subject = relation = object_ = ""
    current = "x"
    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            current = "t"
            if relation != "":
                triplets.append({"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()})
                relation = ""
            subject = ""
        elif token == "<subj>":
            current = "s"
            if relation != "":
                triplets.append({"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()})
            object_ = ""
        elif token == "<obj>":
            current = "o"
            relation = ""
        elif current == "t":
            subject += " " + token
        elif current == "s":
            object_ += " " + token
        elif current == "o":
            relation += " " + token
    if subject != "" and relation != "" and object_ != "":
        triplets.append({"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()})
    return triplets