# COVID_NLI / utils.py
# hitz02's picture
# Upload 3 files
# 07ab211
import pandas as pd
import numpy as np
import pickle
import glob
import json
from pandas.io.json import json_normalize
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
def get_full_sentence(spacy_nlp, para_text, start_index, end_index):
    """
    Return the full sentence(s) of a paragraph that enclose a character span.

    Args:
        spacy_nlp: callable that takes the paragraph text and returns an object
            whose ``.sents`` iterable yields items exposing ``start_char`` and
            ``end_char`` offsets (a spaCy pipeline provides exactly this).
        para_text: the paragraph (body text) the span was found in.
        start_index: character offset where the span starts.
        end_index: character offset where the span ends.

    Returns:
        The substring of ``para_text`` running from the start of the sentence
        containing ``start_index`` to the end of the sentence containing
        ``end_index``. If no sentence contains an offset, the corresponding
        paragraph boundary is used instead.
    """
    # Fall back to the whole paragraph when an offset matches no sentence.
    sent_start = 0
    sent_end = len(para_text)
    for sent in spacy_nlp(para_text).sents:
        if sent.start_char <= start_index <= sent.end_char:
            sent_start = sent.start_char
        if sent.start_char <= end_index <= sent.end_char:
            sent_end = sent.end_char
    # `end_char` is already an exclusive offset, so it is used directly as the
    # slice bound; adding 1 would drag in the character after the sentence.
    return para_text[sent_start:sent_end]
def fetch_stage1(query, model, list_of_articles):
    """
    Score every article by its abstract entry most similar to the query.

    Args:
        query: the query text; it is encoded with ``model.encode``.
        model: object exposing ``encode(list_of_texts)``; the first returned
            row is used as the query embedding.
        list_of_articles: one entry per article, each a sequence of embedding
            vectors (they are stacked with ``np.vstack``). Falsy entries
            (e.g. empty lists) are skipped.

    Returns:
        ``(query_embedding, results)`` where ``results`` is a list of
        ``(article_index, best_entry_index, cosine_similarity)`` tuples sorted
        by similarity, highest first.
    """
    all_abs_distances = []
    # The model works on batches, so wrap the query and unwrap the single row.
    query_embedding = model.encode([query])[0]

    for idx_of_article, article in enumerate(list_of_articles):
        if not article:
            continue

        cosine_dists = scipy.spatial.distance.cdist(
            [query_embedding], np.vstack(article), "cosine"
        ).reshape(-1, 1)

        # Cosine similarity is 1 - cosine distance.
        similarities = [
            (position, 1 - cosine_dists[position][0])
            for position, _ in enumerate(article)
        ]
        similarities.sort(key=lambda pair: pair[1], reverse=True)

        if similarities:
            best_position, best_score = similarities[0]
            all_abs_distances.append((idx_of_article, best_position, best_score))

    ranked = sorted(all_abs_distances, key=lambda hit: hit[2], reverse=True)
    return query_embedding, ranked
def fetch_stage2(results, model, embeddings, query_embedding, top_k=20):
    """
    Re-rank the best stage-1 articles by their body paragraph closest to the query.

    Args:
        results: stage-1 hits; only element ``[0]`` (the article index) of each
            entry is read. Only the first ``top_k`` entries are considered.
        model: object exposing ``encode(texts, show_progress_bar=False)``,
            used to embed the body paragraphs.
        embeddings: indexable by article index; ``embeddings[i][2]`` is the
            sequence of body-text entries, and element ``[0]`` of each entry
            is the paragraph text.
        query_embedding: the already-encoded query vector.
        top_k: how many leading entries of ``results`` to process
            (defaults to 20).

    Returns:
        A list of ``(article_index, best_paragraph_index, cosine_similarity)``
        tuples sorted by similarity, highest first. Articles without any body
        text contribute no entry.
    """
    all_text_distances = []
    for top in results[:top_k]:
        article_idx = top[0]
        body_texts = [text[0] for text in embeddings[article_idx][2]]
        if not body_texts:
            # Nothing to compare against; np.vstack would raise on empty input.
            continue

        body_text_embeddings = model.encode(body_texts, show_progress_bar=False)
        # A single query row yields a (1, n_paragraphs) distance matrix;
        # cosine similarity is 1 - cosine distance.
        similarities = 1 - scipy.spatial.distance.cdist(
            [query_embedding], np.vstack(body_text_embeddings), "cosine"
        )[0]

        # Only the best paragraph is kept, so a full sort is unnecessary.
        # argmax returns the first maximum, like a stable descending sort.
        best_idx = int(np.argmax(similarities))
        all_text_distances.append((article_idx, best_idx, similarities[best_idx]))

    all_text_distances.sort(key=lambda hit: hit[2], reverse=True)
    return all_text_distances
def fetch_stage3(results, query, embeddings, comprehension_model, spacy_nlp):
    """
    Run the comprehension model on the top 20 retrieved paragraphs and
    collect the sentences that contain its answers.

    Args:
        results: stage-2 hits; elements ``[0]`` (article index) and ``[1]``
            (paragraph index) of each entry are read. Only the first 20
            entries are processed.
        query: the question text passed to the model.
        embeddings: indexable by article index; the paragraph text is read
            from ``embeddings[article][2][paragraph][0]``.
        comprehension_model: callable invoked with a
            ``{"context", "question"}`` dict plus ``topk=1`` and
            ``show_progress_bar=False``; its result is read through the keys
            ``"answer"``, ``"score"``, ``"start"`` and ``"end"``
            (presumably a question-answering pipeline -- confirm with caller).
        spacy_nlp: sentence splitter forwarded to ``get_full_sentence``.

    Returns:
        A list of ``(article_index, score_rounded_to_4_places, sentence)``
        tuples sorted by score, highest first. Predictions with an empty
        answer or a rounded score that is not positive are dropped.
    """
    answers = []
    for hit in results[0:20]:
        article_idx, body_text_idx = hit[0], hit[1]
        context = embeddings[article_idx][2][body_text_idx][0]
        query_ = {"context": context, "question": query}
        pred = comprehension_model(query_, topk=1, show_progress_bar=False)

        # Skip paragraphs for which the model produced no answer text.
        if not pred["answer"]:
            continue

        score = round(pred["score"], 4)
        if score > 0:
            # Expand the answer span to the sentence(s) surrounding it.
            sent = get_full_sentence(
                spacy_nlp, query_['context'], pred["start"], pred["end"]
            )
            answers.append((article_idx, score, sent))

    answers.sort(key=lambda entry: entry[1], reverse=True)
    return answers