from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import pandas as pd
import streamlit as st
import pickle
st.title('Entity Extraction from any text')
# Form: extraction runs only when the user submits the text
with st.form(key='form_parameters'):
    # text that will show in the text box as default
    default_value = "Let's have a machine extract entities from any text"
    sent = st.text_area("Text", default_value, height=275)
    # every Streamlit form needs a submit button
    submitted = st.form_submit_button("Extract")
#%%
# generation parameters, collected in the sidebar
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=30)
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=5, value=1, step=1)
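# These sidebar values correspond to Hugging Face generation kwargs; a sketch
# of how they would be passed to the model (note: from_text_to_net below
# currently uses its own hardcoded gen_kwargs instead):
# model.generate(**inputs, max_length=max_length, temperature=temperature,
#                top_k=top_k, top_p=top_p, do_sample=True,
#                num_return_sequences=num_return_sequences)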
#%%
# Relation Extraction By End-to-end Language generation (REBEL):
# a linearization approach that reframes relation extraction as a seq2seq task.
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
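# Note: Streamlit reruns this whole script on every interaction, so the model
# above is reloaded each time. A cached variant (sketch, assuming a Streamlit
# version that provides st.cache_resource) would look like:
#
#   @st.cache_resource
#   def load_rebel():
#       return (AutoTokenizer.from_pretrained("Babelscape/rebel-large"),
#               AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large"))
#   tokenizer, model = load_rebel()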
#%%
# Parse strings generated by REBEL and transform them into triplets,
# e.g. ("Seth", "eats", "In-n-Out") or ("Billy", "lives", "California")
def extract_relations_from_model_output(text):
    relations = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),    # subject of the relation, e.g. "Seth"
                    'type': relation.strip(),   # relation type, e.g. "eats at"
                    'tail': object_.strip()     # object of the relation, e.g. "In-n-Out"
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            # REBEL linearizes triplets as: <triplet> head <subj> tail <obj> relation,
            # so tokens after <triplet> belong to the subject, tokens after <subj>
            # to the object, and tokens after <obj> to the relation
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations
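#%%
# Illustrative check of the parser on a hand-written REBEL-style string
# (the tags below mimic REBEL's linearized format, not a real generation):
# extract_relations_from_model_output(
#     "<s><triplet> Billy <subj> California <obj> lives</s>")
# -> [{'head': 'Billy', 'type': 'lives', 'tail': 'California'}]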
#%%
class NET():
    def __init__(self):
        self.entities = {}  # entity store used by add_entity
        self.relations = []

    def add_entity(self, e):
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        # r1 duplicates an existing relation: keep one copy and merge the
        # span metadata so every originating text span is remembered
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
def from_text_to_net(text, span_length=128, verbose=False):
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations, keeping the special tokens the parser relies on
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create net
    net = NET()
    for i, sentence_pred in enumerate(decoded_preds):
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                "spans": [spans_boundaries[current_span_index]]
            }
            net.add_relation(relation)
    return net
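#%%
# Wire the UI to the pipeline: run extraction on the submitted text and show
# the resulting triplets. (Sketch: from_text_to_net hardcodes its own
# generation settings, so the sidebar values are not threaded through here.)
if submitted and sent:
    with st.spinner('Extracting relations...'):
        net = from_text_to_net(sent)
    st.dataframe(pd.DataFrame(
        [{k: r[k] for k in ('head', 'type', 'tail')} for r in net.relations]))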