blacksquadece committed on
Commit
5f141a5
1 Parent(s): 333a749

Create rebel.py

Browse files
Files changed (1) hide show
  1. rebel.py +122 -0
rebel.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from transformers import pipeline
3
+ from pyvis.network import Network
4
+ from functools import lru_cache
5
+ import spacy
6
+ from spacy import displacy
7
+
8
+
9
+ DEFAULT_LABEL_COLORS = {
10
+ "ORG": "#7aecec",
11
+ "PRODUCT": "#bfeeb7",
12
+ "GPE": "#feca74",
13
+ "LOC": "#ff9561",
14
+ "PERSON": "#aa9cfc",
15
+ "NORP": "#c887fb",
16
+ "FACILITY": "#9cc9cc",
17
+ "EVENT": "#ffeb80",
18
+ "LAW": "#ff8197",
19
+ "LANGUAGE": "#ff8197",
20
+ "WORK_OF_ART": "#f0d0ff",
21
+ "DATE": "#bfe1d9",
22
+ "TIME": "#bfe1d9",
23
+ "MONEY": "#e4e7d2",
24
+ "QUANTITY": "#e4e7d2",
25
+ "ORDINAL": "#e4e7d2",
26
+ "CARDINAL": "#e4e7d2",
27
+ "PERCENT": "#e4e7d2",
28
+ }
29
+
30
def generate_knowledge_graph(texts: List[str], filename: str):
    """Build a directed knowledge graph from ``texts`` and render it with pyvis.

    Relation triplets are extracted from each text with
    ``generate_partial_graph``. Every distinct lowercased head or tail becomes
    a node, and every distinct (head, tail, relation) combination becomes one
    labelled edge. Nodes whose text matches a spaCy named entity are coloured
    by entity label.

    Args:
        texts: Input passages; each one is sent to the triplet extractor
            separately.
        filename: Passed straight to ``Network.show`` as the output target.

    Returns:
        The list of node identifiers (lowercased entity strings). Order is
        unspecified because it comes from a ``set``.
    """
    nlp = spacy.load("en_core_web_sm")
    # Node ids are lowercased below, so NER runs on lowercased text to make
    # entity strings directly comparable with node ids.
    doc = nlp("\n".join(texts).lower())

    # Keep the first label seen for each entity string.
    entity_labels = {}
    for ent in doc.ents:
        entity_labels.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))

    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    nodes = list(set(heads + tails))

    net = Network(directed=True, width="700px", height="700px")

    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            net.add_node(node, shape="circle")
        else:
            color = DEFAULT_LABEL_COLORS.get(label, "#666666")
            net.add_node(node, title=label, shape="circle", color=color)

    # De-duplicate on exactly what identifies an edge in the graph: the
    # lowercased endpoints plus the lowercased relation. A tuple key avoids
    # the collisions that plain string concatenation can produce.
    seen_edges = set()
    for t in triplets:
        head = t["head"].lower()
        tail = t["tail"].lower()
        edge_key = (head, tail, t["type"].lower())
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        net.add_edge(head, tail, title=t["type"], label=t["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes
77
+
78
+
79
@lru_cache(maxsize=1)
def _get_triplet_extractor():
    """Create the REBEL text2text pipeline once and reuse it afterwards."""
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )


@lru_cache(maxsize=16)
def generate_partial_graph(text: str):
    """Extract relation triplets from a single text with the REBEL model.

    Results for the 16 most recently used texts are cached. The cached list
    object is returned as-is on a cache hit, so callers should treat it as
    read-only.

    Args:
        text: The passage to run relation extraction on.

    Returns:
        The list of ``{'head', 'type', 'tail'}`` dicts produced by
        ``extract_triplets`` from the first decoded generation.
    """
    # Building the pipeline is the expensive part; doing it inside this
    # function would repeat it on every cache miss.
    triplet_extractor = _get_triplet_extractor()
    generation = triplet_extractor(text, return_tensors=True, return_text=False)[0]
    token_ids = generation["generated_token_ids"]["output_ids"]
    # Decode without skipping special tokens: the <triplet>/<subj>/<obj>
    # markers are what extract_triplets parses.
    decoded = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(decoded[0])
86
+
87
+
88
def extract_triplets(text):
    """Parse REBEL's linearised output into a list of relation triplets.

    The expected layout is::

        <triplet> HEAD <subj> TAIL <obj> RELATION [<subj> TAIL <obj> RELATION]...

    Note that, in this format, the ``<subj>`` marker introduces the tail
    entity and ``<obj>`` introduces the relation name. One head may be
    followed by several ``<subj> ... <obj> ...`` pairs, each giving its own
    triplet with the same head.

    Only complete triplets (non-empty head, relation and tail) are returned.
    Incomplete fragments from malformed or truncated generations are dropped,
    and no field is carried over into a later segment where it does not
    belong.

    Args:
        text: Decoded model output, possibly containing ``<s>``, ``</s>`` and
            ``<pad>`` tokens.

    Returns:
        A list of dicts with keys ``'head'``, ``'type'`` and ``'tail'``, in
        the order they appear in ``text``.
    """
    triplets = []
    head, relation, tail = [], [], []
    target = None  # the token list currently being filled; None before any marker

    def flush():
        # Emit the pending triplet only when every part is present.
        if head and relation and tail:
            triplets.append({
                'head': ' '.join(head),
                'type': ' '.join(relation),
                'tail': ' '.join(tail),
            })

    for special in ("<s>", "<pad>", "</s>"):
        text = text.replace(special, "")

    for token in text.split():
        if token == "<triplet>":
            # A new head starts: finish the previous triplet and start clean.
            flush()
            head.clear()
            relation.clear()
            tail.clear()
            target = head
        elif token == "<subj>":
            # A new tail for the same head: finish the previous pair first.
            flush()
            relation.clear()
            tail.clear()
            target = tail
        elif token == "<obj>":
            relation.clear()
            target = relation
        elif target is not None:
            target.append(token)

    flush()
    return triplets
122
+