sethiuss committed on
Commit
10cf5b6
1 Parent(s): 166bcfc

Create app.py

Files changed (1)
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import math
+ import torch
+ import pandas as pd
+ import streamlit as st
+
+ st.title('Entity Extraction from any text')
+
+ # Sidebar: generation parameters
+ max_length = st.sidebar.slider("Max Length", min_value=10, max_value=30)
+ temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
+ top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
+ top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
+ num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=5, value=1, step=1)
+
+ # Form: the input text plus a submit button (st.form requires one);
+ # the sidebar widgets above live outside the form
+ with st.form(key='form_parameters'):
+     # default text shown in the text box
+     default_value = "Let's have a machine extract entities from any text"
+     sent = st.text_area("Text", default_value, height=275)
+     submitted = st.form_submit_button("Extract entities")
+
+ # REBEL (Relation Extraction By End-to-end Language generation) reframes
+ # relation extraction as a seq2seq task: the triplets in a text are
+ # linearized into a single output string of special tokens.
+
+ # Load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
+ model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
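+
+ # The model decodes each input span into a linearized string such as
+ # (hypothetical output, following REBEL's
+ # "<triplet> subject <subj> object <obj> relation" format):
+ #   "<s><triplet> Seth <subj> In-n-Out <obj> eats at </s>"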
+
+ # Parse the strings generated by REBEL and turn them into triplets,
+ # e.g. ("Seth", "eats at", "In-n-Out") or ("Billy", "lives in", "California")
+ def extract_relations_from_model_output(text):
+     relations = []
+     relation, subject, object_ = '', '', ''
+     text = text.strip()
+     current = 'x'
+     text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+     for token in text_replaced.split():
+         if token == "<triplet>":
+             current = 't'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+                 relation = ''
+             subject = ''
+         elif token == "<subj>":
+             current = 's'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),   # subject of the relation, e.g. "Seth"
+                     'type': relation.strip(),  # the relation itself, e.g. "eats at"
+                     'tail': object_.strip()    # object of the relation, e.g. "In-n-Out"
+                 })
+             object_ = ''
+         elif token == "<obj>":
+             current = 'o'
+             relation = ''
+         else:
+             # REBEL linearizes triplets as: <triplet> subject <subj> object <obj> relation
+             if current == 't':
+                 subject += ' ' + token
+             elif current == 's':
+                 object_ += ' ' + token
+             elif current == 'o':
+                 relation += ' ' + token
+     if subject != '' and relation != '' and object_ != '':
+         relations.append({
+             'head': subject.strip(),
+             'type': relation.strip(),
+             'tail': object_.strip()
+         })
+     return relations
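+
+ # Usage sketch on a hypothetical decoded string:
+ #   extract_relations_from_model_output(
+ #       "<s><triplet> Billy <subj> California <obj> lives in </s>")
+ #   -> [{'head': 'Billy', 'type': 'lives in', 'tail': 'California'}]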
+
+
+ # Minimal container for the extracted relations
+ class NET:
+     def __init__(self):
+         self.relations = []
+         self.entities = {}  # entity store used by add_entity
+
+     def add_entity(self, e):
+         self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}
+
+     def are_relations_equal(self, r1, r2):
+         return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
+
+     def exists_relation(self, r1):
+         return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
+
+     def merge_relations(self, r1):
+         # find the already-stored duplicate and merge the span metadata
+         r2 = [r for r in self.relations
+               if self.are_relations_equal(r1, r)][0]
+         spans_to_add = [span for span in r1["meta"]["spans"]
+                         if span not in r2["meta"]["spans"]]
+         r2["meta"]["spans"] += spans_to_add
+
+     def add_relation(self, r):
+         if not self.exists_relation(r):
+             self.relations.append(r)
+         else:
+             self.merge_relations(r)
+
+     def print(self):
+         print("Relations:")
+         for r in self.relations:
+             print(f"  {r}")
+
+
+ # Tokenize the text, split it into overlapping spans, run REBEL on each
+ # span, and collect the extracted relations into a NET
+ def from_text_to_net(text, span_length=128, verbose=False):
+     # tokenize whole text
+     inputs = tokenizer([text], return_tensors="pt")
+
+     # compute span boundaries
+     num_tokens = len(inputs["input_ids"][0])
+     if verbose:
+         print(f"Input has {num_tokens} tokens")
+     num_spans = math.ceil(num_tokens / span_length)
+     if verbose:
+         print(f"Input has {num_spans} spans")
+     # overlap spreads the leftover token budget across the spans,
+     # e.g. 230 tokens with span_length=128 -> 2 spans, overlap 26,
+     # boundaries [0, 128] and [102, 230]
+     overlap = math.ceil((num_spans * span_length - num_tokens) /
+                         max(num_spans - 1, 1))
+     spans_boundaries = []
+     start = 0
+     for i in range(num_spans):
+         spans_boundaries.append([start + span_length * i,
+                                  start + span_length * (i + 1)])
+         start -= overlap
+     if verbose:
+         print(f"Span boundaries are {spans_boundaries}")
+
+     # transform input with spans
+     tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
+                   for boundary in spans_boundaries]
+     tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
+                     for boundary in spans_boundaries]
+     inputs = {
+         "input_ids": torch.stack(tensor_ids),
+         "attention_mask": torch.stack(tensor_masks)
+     }
+
+     # generate relations; these generation settings are fixed here and do
+     # not yet use the sidebar sliders defined above
+     num_return_sequences = 3
+     gen_kwargs = {
+         "max_length": 256,
+         "length_penalty": 0,
+         "num_beams": 3,
+         "num_return_sequences": num_return_sequences
+     }
+     generated_tokens = model.generate(
+         **inputs,
+         **gen_kwargs,
+     )
+
+     # decode relations; keep special tokens, since the parser relies on
+     # <triplet>/<subj>/<obj>
+     decoded_preds = tokenizer.batch_decode(generated_tokens,
+                                            skip_special_tokens=False)
+
+     # create net
+     net = NET()
+     i = 0
+     for sentence_pred in decoded_preds:
+         current_span_index = i // num_return_sequences
+         relations = extract_relations_from_model_output(sentence_pred)
+         for relation in relations:
+             relation["meta"] = {
+                 "spans": [spans_boundaries[current_span_index]]
+             }
+             net.add_relation(relation)
+         i += 1
+
+     return net
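+
+ # Wire the UI to the model: a minimal sketch, assuming the `submitted` and
+ # `sent` names from the form above
+ if submitted:
+     with st.spinner("Extracting relations..."):
+         net = from_text_to_net(sent)
+     if net.relations:
+         # show just the triplet columns; "meta" holds span bookkeeping
+         st.dataframe(pd.DataFrame(net.relations)[["head", "type", "tail"]])
+     else:
+         st.write("No relations found.")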