from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import pandas as pd
import streamlit as st
import pickle

st.title('Entity Extraction from any text')

# Form with st.form(key='form_parameters')

#%%
# Text that will show in the text box as the default value
default_value = "Let's have a machine extract entities from any text"
sent = st.text_area("Text", default_value, height=275)
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=30)
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
num_return_sequences = st.sidebar.number_input('Number of Return Sequences', min_value=1, max_value=5, value=1, step=1)

#%%
# Relation Extraction By End-to-end Language generation (REBEL):
# a linearization approach that reframes relation extraction as a seq2seq task.

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

#%%
# Parse strings generated by REBEL and transform them into triplets,
# e.g. ("Seth", "eats at", "In-n-Out") or ("Billy", "lives", "California").
# REBEL linearizes each triplet as "<triplet> subject <subj> object <obj> relation".
def extract_relations_from_model_output(text):
    relations = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    # strip the sequence-level special tokens before parsing
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),   # subject of the relation, e.g. "Seth"
                    'type': relation.strip(),  # relation, e.g. "eats at"
                    'tail': object_.strip()    # object of the relation, e.g. "In-n-Out"
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations

#%%
# Container for the extracted relations; duplicate triplets are merged so that
# each relation keeps track of every text span it was found in.
class NET():
    def __init__(self):
        self.entities = {}
        self.relations = []

    def add_entity(self, e):
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        r2 = [r for r in self.relations if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")


def from_text_to_net(text, span_length=128, verbose=False):
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    # create net
    net = NET()
    i = 0
    for sentence_pred in decoded_preds:
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                "spans": [spans_boundaries[current_span_index]]
            }
            net.add_relation(relation)
        i += 1

    return net