Spaces:
Runtime error
Runtime error
import nmslib | |
import numpy as np | |
import streamlit as st | |
# import inflect | |
import torch | |
# p = inflect.engine() | |
class FewDocumentsError(Exception):
    """Raised when a query matches too few documents to continue automatically.

    Attributes:
        documents: the documents that did match (document_extraction passes
            its 'OR'-tier matches), so a caller can still proceed with them.
        size: the number of matching documents.
        msg: human-readable explanation intended for the user.
    """

    def __init__(self, documents, size, msg):
        # Forward every constructor argument to Exception so ``self.args``
        # mirrors the signature; without this ``args`` is empty and pickling
        # or copying the exception fails (it is rebuilt as ``cls(*args)``).
        super().__init__(documents, size, msg)
        self.documents = documents
        self.size = size
        self.msg = msg

    def __str__(self):
        return repr(self.msg)
def _meets_size_requirements(document, min_document_size, min_just_one_paragraph_size):
    """Return True if *document* is long enough and has at least one long line."""
    return len(document.split()) > min_document_size and any(
        len(paragraph.split()) > min_just_one_paragraph_size
        for paragraph in document.splitlines()
    )


def document_extraction(dataset, query, keywords, min_document_size,
                        min_just_one_paragraph_size, min_documents=10):
    """Select the documents of *dataset* that best match *query* / *keywords*.

    Each document is matched case-insensitively against three tiers, in
    order of preference:

    * ``'QUERY'`` -- the whole query appears as one whitespace-delimited
      token (so a query containing whitespace can never match this tier);
    * ``'AND'``   -- every keyword appears as a token (vacuously true when
      *keywords* is empty);
    * ``'OR'``    -- at least one keyword appears as a token.

    Only documents with more than *min_document_size* words AND at least one
    line (paragraph) with more than *min_just_one_paragraph_size* words are
    considered. The first tier with at least *min_documents* matches wins.

    Args:
        dataset: sequence of document strings.
        query: the search query.
        keywords: iterable of keyword strings.
        min_document_size: a document must have strictly more words than this.
        min_just_one_paragraph_size: at least one line of the document must
            have strictly more words than this.
        min_documents: matches a tier needs in order to be selected
            (defaults to 10, the previously hard-coded threshold).

    Returns:
        ``(extracted_documents, empty, sizes)``: the original (not
        lowercased) documents of the winning tier, plus dicts mapping each
        tier name to whether it is empty and to its match count.

    Raises:
        FewDocumentsError: if no tier reaches *min_documents*; it carries the
            ``'OR'`` matches so the caller can offer to continue with them.

    If every tier is empty, a Streamlit warning is shown and ``st.stop()`` is
    called to halt the script run.
    """
    # TODO: compare inflected forms (e.g. singular/plural) instead of exact tokens.
    lower_query = query.lower()
    lower_keywords = [keyword.lower() for keyword in keywords]

    documents = {'QUERY': [], 'AND': [], 'OR': []}
    # Keep each match paired with its own original text, so documents that
    # differ only by case stay distinct; this is also a single linear pass.
    for original in dataset:
        lowered = original.lower()
        if not _meets_size_requirements(lowered, min_document_size, min_just_one_paragraph_size):
            continue
        # Tokenise once per document rather than once per keyword.
        tokens = set(lowered.split())
        if lower_query in tokens:
            documents['QUERY'].append(original)
        if all(keyword in tokens for keyword in lower_keywords):
            documents['AND'].append(original)
        if any(keyword in tokens for keyword in lower_keywords):
            documents['OR'].append(original)

    sizes = {tier: len(matches) for tier, matches in documents.items()}
    empty = {tier: size == 0 for tier, size in sizes.items()}

    if all(empty.values()):
        # TODO: raise an exception instead of stopping the Streamlit script here.
        st.warning(f'No document found for the query "{query}", please try with another query')
        st.stop()

    for tier in ('QUERY', 'AND', 'OR'):
        if sizes[tier] >= min_documents:
            return documents[tier], empty, sizes

    number_of_documents = sizes['OR']
    raise FewDocumentsError(
        documents['OR'],
        number_of_documents,
        f'Only {number_of_documents} documents found for the query "{query}"\n'
        f'Please select continue to proceed with {number_of_documents} documents '
        'or try again with another query',
    )
def paragraph_extraction(documents, min_paragraph_size):
    """Return every line of *documents* with more than *min_paragraph_size* words.

    Each line (as split by ``str.splitlines``) is treated as a paragraph.
    Paragraphs are returned in document order, then line order.

    Args:
        documents: iterable of document strings.
        min_paragraph_size: a line is kept only if it has strictly more
            whitespace-separated words than this.

    Returns:
        A list of the qualifying lines.
    """
    # Split each document once instead of re-splitting it for every line.
    return [
        paragraph
        for document in documents
        for paragraph in document.splitlines()
        if len(paragraph.split()) > min_paragraph_size
    ]
def semantic_search(model, query, files, number_of_similar_files):
    """Return the entries of *files* closest to *query* in embedding space.

    The query and the files are embedded with ``model.encode``. A fresh
    nmslib HNSW index over the file embeddings (``angulardist`` space) is
    built on every call, then queried for the ``number_of_similar_files``
    nearest neighbours.

    Returns:
        ``(selected_files, distances)``: the matching entries of *files* in
        the order nmslib reports them, and the matching distances scaled by
        180/pi, i.e. converted to degrees (assumes ``angulardist`` reports
        radians).
    """
    query_vector = model.encode(query)
    file_vectors = model.encode(files)

    index = nmslib.init(method='hnsw', space='angulardist')
    index.addDataPointBatch(file_vectors)
    # 'post' is an HNSW index-time option (post-processing level, per the
    # nmslib docs -- confirm if tuning).
    index.createIndex({'post': 2})

    neighbour_ids, angular_distances = index.knnQuery(query_vector, k=number_of_similar_files)
    matches = list(map(files.__getitem__, neighbour_ids))
    return matches, 180 * angular_distances / np.pi