# Streamlit app: entity/relation extraction with REBEL.
# (Removed non-code residue scraped from the Hugging Face Spaces page header.)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import pandas as pd
import streamlit as st
import pickle

st.title('Entity Extraction from any text')

# Sidebar generation parameters.
# NOTE(review): these widget values are collected but never passed to
# model.generate() below, which hard-codes its own kwargs — confirm intent.
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=30)
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=5, value=1, step=1)

#%%
# Input form. Fixes vs. the scraped original: the `with st.form(...)` line was
# missing its colon (SyntaxError), and Streamlit requires every form to contain
# a form_submit_button, otherwise it raises StreamlitAPIException.
with st.form(key='form_parameters'):
    # Default text shown in the text box.
    default_value = "Let's have a machine extract entities form any text"
    sent = st.text_area("Text", default_value, height=275)
    submitted = st.form_submit_button("Extract")

#%%
# Relation Extraction By End-to-end Language generation (REBEL):
# a linearization approach reframing relation extraction as a seq2seq task.
# Load model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
def extract_relations_from_model_output(text):
    """Parse a REBEL-linearized string into relation triplets.

    REBEL emits sequences of the form
    ``<triplet> head <subj> tail <obj> relation [<subj> tail2 <obj> relation2 ...]``.
    Returns a list of dicts like
    ``{'head': 'Billy', 'type': 'lives', 'tail': 'California'}``.
    """
    relations = []
    # Fix vs. original: the tuple unpack listed `relation` twice
    # (`relation, subject, relation, object_`); three accumulators suffice.
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'  # which field plain tokens are appended to: t/s/o
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            # Every <triplet> after the first closes a completed triplet.
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # A second <subj> under the same head closes the previous pair.
            if relation != '':
                relations.append({
                    'head': subject.strip(),   # subject of relation, e.g. "Seth"
                    'type': relation.strip(),  # relation label, e.g. "eats at"
                    'tail': object_.strip()    # object of relation, e.g. "In-n-Out"
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            # Plain token: accumulate into the currently open field.
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the trailing triplet, if complete.
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations
#%% | |
class NET():
    """In-memory store of extracted relations (and, optionally, entities).

    Relations are dicts with 'head', 'type', 'tail' and a 'meta' dict
    holding the list of source spans they were extracted from.
    """

    def __init__(self):
        self.relations = []
        # Fix vs. original: add_entity wrote to self.entities, which was
        # never initialized and raised AttributeError on first use.
        self.entities = {}

    def add_entity(self, e):
        # Index the entity by its title; remaining keys become attributes.
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def are_relations_equal(self, r1, r2):
        # Equality on the triplet itself; 'meta' is intentionally ignored.
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        # Fold r1's span metadata into the stored duplicate relation.
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def add_relation(self, r):
        # Deduplicate on (head, type, tail); duplicates only merge spans.
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f" {r}")
def from_text_to_net(text, span_length=128, verbose=False):
    """Run REBEL over `text` and collect all extracted relations into a NET.

    Long inputs are split into spans of `span_length` tokens (with overlap
    spread evenly between spans) so each fits one generation call.
    Relies on module-level `tokenizer`, `model`, and `NET`.
    """
    # Tokenize the whole text once.
    inputs = tokenizer([text], return_tensors="pt")

    # Compute span boundaries over the token sequence.
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    # Distribute the slack (num_spans * span_length - num_tokens) evenly
    # as overlap between consecutive spans; max() guards the 1-span case.
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # Re-batch input ids / attention masks: one row per span.
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # Generate candidate linearized triplets for every span.
    # NOTE(review): this local hard-coded 3 shadows the Streamlit
    # `num_return_sequences` sidebar value — confirm which is intended.
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # Keep special tokens: the parser needs the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # Fold every decoded prediction into the net, tagging each relation with
    # the span it came from (predictions come in groups of
    # num_return_sequences per span, hence the integer division).
    net = NET()
    for i, sentence_pred in enumerate(decoded_preds):
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                "spans": [spans_boundaries[current_span_index]]
            }
            net.add_relation(relation)
    return net