from typing import List
from functools import lru_cache

import spacy
from pyvis.network import Network
from transformers import pipeline
| DEFAULT_LABEL_COLORS = { | |
| "ORG": "#7aecec", | |
| "PRODUCT": "#bfeeb7", | |
| "GPE": "#feca74", | |
| "LOC": "#ff9561", | |
| "PERSON": "#aa9cfc", | |
| "NORP": "#c887fb", | |
| "FACILITY": "#9cc9cc", | |
| "EVENT": "#ffeb80", | |
| "LAW": "#ff8197", | |
| "LANGUAGE": "#ff8197", | |
| "WORK_OF_ART": "#f0d0ff", | |
| "DATE": "#bfe1d9", | |
| "TIME": "#bfe1d9", | |
| "MONEY": "#e4e7d2", | |
| "QUANTITY": "#e4e7d2", | |
| "ORDINAL": "#e4e7d2", | |
| "CARDINAL": "#e4e7d2", | |
| "PERCENT": "#e4e7d2", | |
| } | |


def generate_knowledge_graph(texts: List[str], filename: str):
    """Build an interactive knowledge graph from a list of sentences and save it as HTML."""
    # Run spaCy NER over the whole (lower-cased) input so graph nodes can be
    # coloured by entity type.
    nlp = spacy.load("en_core_web_sm")
    doc = nlp("\n".join(texts).lower())
    NERs = [ent.text for ent in doc.ents]
    NER_types = [ent.label_ for ent in doc.ents]

    # Extract (head, relation, tail) triplets sentence by sentence with REBEL.
    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    nodes = list(set(heads + tails))

    net = Network(directed=True, width="700px", height="700px")
    for n in nodes:
        if n in NERs:
            # Colour recognised entities by their NER label, falling back to
            # grey for labels without an entry in DEFAULT_LABEL_COLORS.
            NER_type = NER_types[NERs.index(n)]
            color = DEFAULT_LABEL_COLORS.get(NER_type, "#666666")
            net.add_node(n, title=NER_type, shape="circle", color=color)
        else:
            net.add_node(n, shape="circle")

    # Add each distinct (head, tail, relation) combination as a single edge.
    unique_triplets = set()

    def stringify_trip(x):
        return x["tail"] + x["head"] + x["type"].lower()

    for triplet in triplets:
        if stringify_trip(triplet) not in unique_triplets:
            net.add_edge(triplet["head"].lower(), triplet["tail"].lower(),
                         title=triplet["type"], label=triplet["type"])
            unique_triplets.add(stringify_trip(triplet))

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)  # writes the interactive graph to an HTML file
    return nodes
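

# Usage sketch (hypothetical sentence and output file name; assumes the spaCy model
# has been installed, e.g. `python -m spacy download en_core_web_sm`, and that the
# REBEL model can be downloaded on first use):
#     nodes = generate_knowledge_graph(["Napoleon was born in Ajaccio."], "graph.html")
#     # writes graph.html and returns the list of node labels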


@lru_cache(maxsize=None)
def generate_partial_graph(text: str):
    """Run the REBEL relation-extraction model on one sentence and return its triplets."""
    # Results are memoised per input text so repeated sentences do not re-run
    # the model (presumably why functools.lru_cache is imported).
    triplet_extractor = pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large'
    )
    triples = triplet_extractor(
        text,
        return_tensors=True,
        return_text=False)
    if len(triples) == 0:
        return []
    # Decode the generated token ids back into REBEL's linearised triplet string,
    # keeping the special tokens that mark triplet boundaries.
    a = [triples[0]["generated_token_ids"]]
    extracted_text = triplet_extractor.tokenizer.batch_decode(a)
    extracted_triplets = extract_triplets(extracted_text[0])
    return extracted_triplets
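

# Note on the output format (illustrative): with return_tensors=True the
# text2text-generation pipeline returns a list of dicts containing
# "generated_token_ids"; decoding those ids without skipping special tokens
# yields a linearised string of the form
#     "<s><triplet> head <subj> tail <obj> relation ...</s>"
# which extract_triplets() below parses back into dictionaries.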


def extract_triplets(text):
    """
    Parse the generated text and extract the triplets.

    REBEL emits sequences of the form "<triplet> head <subj> tail <obj> relation",
    possibly repeated for several triplets.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            # A new triplet starts: flush the previous one if it is complete.
            current = 't'
            if relation != '':
                triplets.append(
                    {'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            # The head text is finished; what follows (until <obj>) is the tail.
            current = 's'
            if relation != '':
                triplets.append(
                    {'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            # The tail text is finished; what follows is the relation label.
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append(
            {'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
    return triplets
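

# Worked example (hypothetical decoded string, shown for illustration only):
#     extract_triplets("<s><triplet> Paris <subj> France <obj> capital of</s>")
# returns
#     [{'head': 'Paris', 'type': 'capital of', 'tail': 'France'}]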


if __name__ == "__main__":
    generate_knowledge_graph(
        ["The dog is happy", "The cat is sad"], "test.html")
