autosumm / extractor /_utils.py
mhsvieira's picture
Pre-load models
78a71e8
raw
history blame
3.55 kB
import nmslib
import numpy as np
import streamlit as st
# import inflect
import torch
# p = inflect.engine()
class FewDocumentsError(Exception):
    """Error carrying the documents that were found when too few matched.

    Attributes:
        documents: The documents that were found.
        size: The number of documents reported by the raiser.
        msg: Human-readable message describing the situation.
    """

    def __init__(self, documents, size, msg):
        # Forward all constructor arguments to ``Exception`` so that ``args``
        # is populated; this lets ``copy``/``pickle`` rebuild the exception
        # by calling the constructor with the same three arguments.
        super().__init__(documents, size, msg)
        self.documents = documents
        self.size = size
        self.msg = msg

    def __str__(self):
        # The message is rendered through ``repr`` (i.e. quoted).
        return repr(self.msg)
def document_extraction(dataset, query, keywords, min_document_size,
                        min_just_one_paragraph_size, min_documents=10):
    """Select the documents of ``dataset`` that match ``query`` or ``keywords``.

    Matching is case-insensitive and works on whitespace-delimited tokens.
    Three candidate lists are built, all restricted to documents that have
    more than ``min_document_size`` words and at least one line with more
    than ``min_just_one_paragraph_size`` words:

    * ``'QUERY'``: documents containing ``query`` as a single token,
    * ``'AND'``: documents containing every keyword,
    * ``'OR'``: documents containing at least one keyword.

    The first list (in that order) holding at least ``min_documents``
    documents is returned.

    Args:
        dataset: Sequence of document strings.
        query: Query string; compared as one token, so a query containing
            whitespace never matches the ``'QUERY'`` tier.
        keywords: Iterable of keyword strings.
        min_document_size: Word count a document must exceed.
        min_just_one_paragraph_size: Word count at least one line of the
            document must exceed.
        min_documents: Minimum number of documents a tier needs in order to
            be selected. Defaults to 10.

    Returns:
        A tuple ``(extracted_documents, empty, sizes)`` where ``empty`` and
        ``sizes`` are dicts keyed by ``'QUERY'``, ``'AND'`` and ``'OR'``
        holding, respectively, whether each tier is empty and its length.

    Raises:
        FewDocumentsError: If no tier reaches ``min_documents``; it carries
            the ``'OR'`` documents and their count.
    """
    # TODO: compare inflected forms
    lower_query = query.lower()
    lower_keywords = [keyword.lower() for keyword in keywords]

    documents = {'QUERY': [], 'AND': [], 'OR': []}

    # Iterate over the original documents directly (instead of looking each
    # lowered text up again with ``list.index``) so that documents differing
    # only by case, or exact duplicates, map back to their own original text.
    for document in dataset:
        lower_document = document.lower()

        tokens = lower_document.split()
        if len(tokens) <= min_document_size:
            continue
        if not any(len(paragraph.split()) > min_just_one_paragraph_size
                   for paragraph in lower_document.splitlines()):
            continue

        # Build the token set once per document; every membership test
        # below reuses it.
        words = set(tokens)

        if lower_query in words:
            documents['QUERY'].append(document)
        # NOTE: with no keywords, ``all`` is vacuously true, so every
        # size-qualified document lands in 'AND' (and none in 'OR').
        if all(keyword in words for keyword in lower_keywords):
            documents['AND'].append(document)
        if any(keyword in words for keyword in lower_keywords):
            documents['OR'].append(document)

    empty = {name: len(matches) == 0 for name, matches in documents.items()}
    sizes = {name: len(matches) for name, matches in documents.items()}

    if all(empty.values()):
        # TODO: throw error
        st.info(empty.values())
        st.warning(f'No document found for the query "{query}", please try with another query')
        st.stop()

    # Prefer the strictest tier that has enough documents.
    for name in ('QUERY', 'AND', 'OR'):
        if sizes[name] >= min_documents:
            return documents[name], empty, sizes

    number_of_documents = sizes['OR']
    raise FewDocumentsError(
        documents['OR'], number_of_documents,
        f'Only {number_of_documents} documents found for the query "{query}"\n'
        f'Please select continue to proceed with {number_of_documents} documents or try again with another query'
    )
def paragraph_extraction(documents, min_paragraph_size):
    """Return every line of every document that is long enough.

    Each document is split into lines with ``str.splitlines``; a line is
    kept when its whitespace-delimited word count is strictly greater than
    ``min_paragraph_size``. Order follows the documents, then the lines
    within each document.

    Args:
        documents: Sequence of document strings.
        min_paragraph_size: Word count a line must exceed to be kept.

    Returns:
        A flat list of the selected lines.
    """
    # Split each document once and iterate its lines directly, rather than
    # re-splitting the whole document for every line index.
    return [
        paragraph
        for document in documents
        for paragraph in document.splitlines()
        if len(paragraph.split()) > min_paragraph_size
    ]
def semantic_search(model, query, files, number_of_similar_files):
    """Find the entries of ``files`` nearest to ``query`` in embedding space.

    Both the query and the files are embedded with ``model.encode``; the
    file embeddings are indexed with an nmslib HNSW index over the
    ``angulardist`` space, which is then queried for the nearest neighbours.

    Args:
        model: Object exposing ``encode``; presumably a sentence-embedding
            model -- confirm against the caller.
        query: Query passed to ``model.encode``.
        files: Indexable collection of items to search, also passed to
            ``model.encode``.
        number_of_similar_files: Number of neighbours requested (``k``).

    Returns:
        A tuple ``(selected_files, distances)``: the matching entries of
        ``files`` in the order returned by the index, and their distances
        multiplied by ``180 / pi`` (radians to degrees, assuming the index
        reports angles in radians).
    """
    query_vector = model.encode(query)
    file_vectors = model.encode(files)

    index = nmslib.init(method='hnsw', space='angulardist')
    index.addDataPointBatch(file_vectors)
    index.createIndex({'post': 2})

    neighbour_ids, distances = index.knnQuery(query_vector, k=number_of_similar_files)

    # Scale the reported distances by 180 / pi.
    distances = 180 * distances / np.pi

    selected_files = []
    for neighbour_id in neighbour_ids:
        selected_files.append(files[neighbour_id])

    return selected_files, distances