"""Build an interactive knowledge graph from free text.

Relation triplets are generated with the REBEL seq2seq model, entities are
typed with spaCy so nodes can be coloured, and the graph is rendered to an
HTML file with pyvis.
"""

from functools import lru_cache
from typing import List

import spacy
import streamlit as st
from pyvis.network import Network
from spacy import displacy
from transformers import pipeline

# Node colour for each spaCy entity label.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}

# Colour for recognised entities whose label is not in DEFAULT_LABEL_COLORS.
_FALLBACK_NODE_COLOR = "#666666"

# Sequence-framing tokens left in the decoded REBEL output; they carry no
# content and are removed before parsing.
_REBEL_FILLER_TOKENS = ("<s>", "<pad>", "</s>")

# Section markers of the linearised REBEL output:
#   <triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]
_TRIPLET_TOKEN = "<triplet>"
_SUBJECT_TOKEN = "<subj>"
_OBJECT_TOKEN = "<obj>"


# NOTE(review): newer Streamlit releases appear to replace
# `experimental_singleton` with `st.cache_resource` - confirm against the
# pinned Streamlit version before migrating.
@st.experimental_singleton(max_entries=1)
def get_pipeline():
    """Return the REBEL text2text-generation pipeline, created once and cached."""
    triplet_extractor = pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )
    return triplet_extractor


@st.experimental_singleton(max_entries=1)
def load_spacy():
    """Return the `en_core_web_sm` spaCy model, loaded once and cached."""
    nlp = spacy.load("en_core_web_sm")
    return nlp


def generate_knowledge_graph(texts: List[str], filename: str) -> List[str]:
    """Extract triplets from *texts*, render them as a graph and save it.

    Args:
        texts: Text passages; each one is sent to the triplet extractor
            separately.
        filename: Path handed to ``Network.show`` for the rendered graph.

    Returns:
        The lower-cased node names (triplet heads and tails), de-duplicated
        and listed in order of first appearance.
    """
    nlp = load_spacy()
    # Node names are lower-cased below, so entity recognition runs on the
    # lower-cased text to produce entity strings that can be matched to them.
    doc = nlp("\n".join(texts).lower())

    # Entity text -> label; the first occurrence wins when a string is
    # tagged more than once.
    entity_labels = {}
    for ent in doc.ents:
        entity_labels.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    nodes = list(dict.fromkeys(heads + tails))

    net = Network(directed=True, width="700px", height="700px")

    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            # Not a recognised entity: plain node with the default colour.
            net.add_node(node, shape="circle")
        else:
            color = DEFAULT_LABEL_COLORS.get(label, _FALLBACK_NODE_COLOR)
            net.add_node(node, title=label, shape="circle", color=color)

    # Edges are keyed on the lower-cased (head, tail, type) tuple so that the
    # key agrees with the lower-cased node names and two different triplets
    # can never collapse to the same key.
    seen_edges = set()
    for triplet in triplets:
        head = triplet["head"].lower()
        tail = triplet["tail"].lower()
        relation = triplet["type"]
        edge_key = (head, tail, relation.lower())
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        net.add_edge(head, tail, title=relation, label=relation)

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes


@lru_cache(maxsize=16)
def generate_partial_graph(text: str) -> List[dict]:
    """Run REBEL on one passage and return its parsed triplets.

    Results are memoised per input string (up to 16 entries), so callers
    should treat the returned list and its dicts as read-only.

    Args:
        text: The passage to extract relations from.

    Returns:
        A list of ``{"head": ..., "type": ..., "tail": ...}`` dicts.
    """
    triplet_extractor = get_pipeline()
    generated = triplet_extractor(text, return_tensors=True, return_text=False)
    # NOTE(review): the nested "output_ids" key depends on the installed
    # transformers version - confirm if the dependency is upgraded.
    token_ids = generated[0]["generated_token_ids"]["output_ids"]
    # Special tokens are kept in the decoded text on purpose:
    # extract_triplets() relies on the section markers to parse it.
    decoded = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(decoded[0])


def extract_triplets(text):
    """Parse linearised REBEL output into a list of triplets.

    The expected layout is
    ``<triplet> head <subj> tail <obj> relation``, where further
    ``<subj> tail <obj> relation`` groups may follow and reuse the same head.

    Args:
        text: Decoded model output, including its special tokens.

    Returns:
        A list of ``{"head": ..., "type": ..., "tail": ...}`` dicts; empty
        when no complete triplet is found.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    # 'x' means no section marker has been seen yet; tokens in that state
    # are ignored.
    current = 'x'

    for filler in _REBEL_FILLER_TOKENS:
        text = text.replace(filler, "")

    for token in text.split():
        if token == _TRIPLET_TOKEN:
            current = 't'
            # A new head begins: flush the triplet collected so far.
            if relation != '':
                triplets.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })
                relation = ''
            subject = ''
        elif token == _SUBJECT_TOKEN:
            current = 's'
            # Another tail for the same head: flush the previous triplet.
            if relation != '':
                triplets.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })
            object_ = ''
        elif token == _OBJECT_TOKEN:
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    # Flush the trailing triplet, but only when all three parts are present.
    if subject != '' and relation != '' and object_ != '':
        triplets.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip(),
        })
    return triplets