text-to-kb / utils.py
fabiochiusano
first commit
0ba9aa2
from pyvis.network import Network
from GoogleNews import GoogleNews
from newspaper import Article, ArticleException
import math
import torch
from kb import KB
import pickle
def extract_relations_from_model_output(text):
    """Parse linearized triples out of one decoded model output string.

    The text is expected to use three marker tokens: the words following
    ``<triplet>`` form the head entity, the words following ``<subj>`` form
    the tail entity and the words following ``<obj>`` form the relation type.
    ``<s>``, ``</s>`` and ``<pad>`` are removed before parsing, and words seen
    before the first marker are ignored.

    Args:
        text: Decoded model output, still containing the marker tokens.

    Returns:
        A list of ``{'head': str, 'type': str, 'tail': str}`` dicts in the
        order they appear in ``text``. Only complete triples are returned:
        a triple missing its head, type or tail (malformed output) is dropped
        instead of being emitted with an empty string.
    """
    relations = []
    # Words collected so far for each part of the triple under construction.
    parts = {'head': [], 'type': [], 'tail': []}
    # Key of `parts` that receives plain words; None until a marker is seen.
    target = None

    def flush():
        # Emit the current triple only when all three parts are present.
        if all(parts.values()):
            relations.append(
                {key: ' '.join(words) for key, words in parts.items()}
            )

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for token in cleaned.split():
        if token == "<triplet>":
            # A new head starts: close the pending triple, then reset head
            # and type. The tail is reset by the next <subj>.
            flush()
            parts['head'] = []
            parts['type'] = []
            target = 'head'
        elif token == "<subj>":
            # Another tail for the same head: close the pending triple first.
            flush()
            parts['tail'] = []
            target = 'tail'
        elif token == "<obj>":
            parts['type'] = []
            target = 'type'
        elif target is not None:
            parts[target].append(token)

    # The last triple is not followed by a marker, so close it explicitly.
    flush()
    return relations
def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    """Build a KB from the relations a model generates for ``text``.

    The text is tokenized once and the token sequence is cut into
    ``span_length``-wide windows whose starts are shifted back by a common
    overlap. Every window is passed to ``model.generate`` (3 beams, 3 returned
    sequences per window); each decoded sequence is parsed with
    ``extract_relations_from_model_output`` and the resulting relations are
    added to a new ``KB``.

    Args:
        text: Raw text to extract relations from.
        model: Object exposing ``generate(input_ids=..., attention_mask=..., ...)``.
        tokenizer: Callable returning ``"input_ids"`` / ``"attention_mask"``
            tensors and exposing ``batch_decode``.
        article_url: Key under which the originating span is stored in each
            relation's ``"meta"`` dict.
        span_length: Width, in tokens, of each window.
        article_title: Forwarded to ``KB.add_relation``.
        article_publish_date: Forwarded to ``KB.add_relation``.
        verbose: Print token count, span count and span boundaries.

    Returns:
        The populated ``KB``.
    """
    encoded = tokenizer([text], return_tensors="pt")
    token_ids = encoded["input_ids"][0]

    # Work out how many windows are needed to cover the token sequence.
    num_tokens = len(token_ids)
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")

    # Spread the surplus (windows * width - tokens) over the gaps between
    # consecutive windows; each window then starts `stride` after the last.
    # NOTE(review): overlap is rounded up, so the last window can end a few
    # tokens before num_tokens.
    surplus = num_spans * span_length - num_tokens
    overlap = math.ceil(surplus / max(num_spans - 1, 1))
    stride = span_length - overlap
    spans_boundaries = [
        [k * stride, k * stride + span_length] for k in range(num_spans)
    ]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # One row per window for both the ids and the attention mask.
    id_rows = [token_ids[lo:hi] for lo, hi in spans_boundaries]
    mask_source = encoded["attention_mask"][0]
    mask_rows = [mask_source[lo:hi] for lo, hi in spans_boundaries]

    num_return_sequences = 3
    generated_tokens = model.generate(
        input_ids=torch.stack(id_rows),
        attention_mask=torch.stack(mask_rows),
        max_length=256,
        length_penalty=0,
        num_beams=3,
        num_return_sequences=num_return_sequences,
    )

    # Keep special tokens: the relation parser relies on the marker tokens.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    kb = KB()
    for seq_index, sentence_pred in enumerate(decoded_preds):
        # Consecutive groups of num_return_sequences outputs share a window.
        span_index = seq_index // num_return_sequences
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {
                article_url: {"spans": [spans_boundaries[span_index]]}
            }
            kb.add_relation(relation, article_title, article_publish_date)
    return kb
def get_article(url):
    """Download and parse the article at ``url``.

    Args:
        url: Address of the article to fetch.

    Returns:
        The ``newspaper.Article`` after ``download()`` and ``parse()`` have
        been called on it.
    """
    fetched = Article(url)
    for step in (fetched.download, fetched.parse):
        step()
    return fetched
def from_url_to_kb(url, model, tokenizer):
    """Fetch the article at ``url`` and extract a KB from its text.

    The article's own ``url`` attribute (not the ``url`` argument), title and
    publish date are forwarded to ``from_text_to_kb``.

    Args:
        url: Address of the article to process.
        model: Generation model passed through to ``from_text_to_kb``.
        tokenizer: Tokenizer passed through to ``from_text_to_kb``.

    Returns:
        The ``KB`` produced by ``from_text_to_kb``.
    """
    article = get_article(url)
    return from_text_to_kb(
        article.text,
        model,
        tokenizer,
        article.url,
        article_title=article.title,
        article_publish_date=article.publish_date,
    )
def get_news_links(query, lang="en", region="US", pages=1):
    """Return the unique article links Google News lists for ``query``.

    Args:
        query: Search string.
        lang: Language code passed to ``GoogleNews``.
        region: Region code passed to ``GoogleNews``.
        pages: Number of result pages to collect, counted from the first.
            Values below 1 yield an empty list.

    Returns:
        A list of distinct links, in the order they were first seen.
    """
    if pages < 1:
        return []

    googlenews = GoogleNews(lang=lang, region=region)
    # search() already loads the first result page, and GoogleNews numbers
    # pages from 1. Only pages 2..pages still have to be requested; asking
    # for page 0 or page 1 again would just repeat the first page.
    googlenews.search(query)
    all_urls = list(googlenews.get_links())
    for page in range(2, pages + 1):
        googlenews.get_page(page)
        all_urls.extend(googlenews.get_links())

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(all_urls))
def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    """Extract a KB from every URL and merge the results into one KB.

    URLs whose processing raises ``ArticleException`` are skipped; any other
    exception propagates.

    Args:
        urls: Sized iterable of article addresses.
        model: Generation model passed through to ``from_url_to_kb``.
        tokenizer: Tokenizer passed through to ``from_url_to_kb``.
        verbose: Print progress and skipped URLs.

    Returns:
        A ``KB`` holding the merged content of all processed articles.
    """
    merged = KB()
    if verbose:
        print(f"{len(urls)} links to visit")

    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            merged.merge_with_kb(from_url_to_kb(url, model, tokenizer))
        except ArticleException:
            # Best effort: an unavailable article must not abort the batch.
            if verbose:
                print(f" Couldn't download article at url {url}")

    return merged
def save_network_html(kb, filename="network.html"):
    """Render ``kb`` as a directed pyvis network and show it as ``filename``.

    Every item of ``kb.entities`` becomes a green circular node and every
    item of ``kb.relations`` becomes an edge from its ``"head"`` to its
    ``"tail"``, labelled and titled with its ``"type"``.

    Args:
        kb: Object exposing ``entities`` and ``relations``.
        filename: Name handed to ``Network.show``.
    """
    graph = Network(directed=True, width="700px", height="700px")

    # Nodes: one per entity, all in the same colour.
    entity_color = "#00FF00"
    for entity in kb.entities:
        graph.add_node(entity, shape="circle", color=entity_color)

    # Edges: one per relation, labelled with the relation type.
    for rel in kb.relations:
        head, tail, kind = rel["head"], rel["tail"], rel["type"]
        graph.add_edge(head, tail, title=kind, label=kind)

    # Layout settings, then output.
    physics = {
        "node_distance": 200,
        "central_gravity": 0.2,
        "spring_length": 200,
        "spring_strength": 0.05,
        "damping": 0.09,
    }
    graph.repulsion(**physics)
    graph.set_edge_smooth('dynamic')
    graph.show(filename)
def save_kb(kb, filename):
    """Pickle ``kb`` into the file at ``filename``, overwriting it.

    Args:
        kb: Object to serialize.
        filename: Path of the file to write (opened in binary mode).
    """
    with open(filename, "wb") as handle:
        pickle.Pickler(handle).dump(kb)
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that maps any pickled class named ``KB`` to this module's ``KB``.

    The module recorded in the pickle is ignored for that name, so a KB
    pickled from a different module path still loads. Every other class is
    resolved by the default ``pickle.Unpickler`` lookup.
    """

    def find_class(self, module, name):
        """Return ``KB`` for the name ``'KB'``, else defer to the base class."""
        return KB if name == 'KB' else super().find_class(module, name)
def load_kb(filename):
    """Load a pickled object from ``filename`` using ``CustomUnpickler``.

    Args:
        filename: Path of the pickle file to read (opened in binary mode).

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, "rb") as handle:
        return CustomUnpickler(handle).load()