|
|
|
|
|
import os |
|
|
|
import numpy as np |
|
import streamlit as st |
|
from elasticsearch import Elasticsearch |
|
|
|
from embedders.LatinBERT import LatinBERT |
|
from embedders.labse import LaBSE |
|
|
|
|
|
# Load the embedding models once per Streamlit session: constructing them is
# expensive, so cache them in session_state to survive script reruns.
if 'models' not in st.session_state:

    st.session_state['models'] = dict(

        LaBSE=LaBSE(),

        LatinBERT=LatinBERT(bertPath="./embedders/latin_bert/latin_bert", tokenizerPath="./embedders/tokenizer/latin.subword.encoder")

    )

# Whether to verify TLS certificates when connecting to Elasticsearch.
verify_certs=True

# Connection settings come from the environment:
#   ELASTIC_HOST — the Elasticsearch URL
#   ELASTIC_AUTH — credentials in "user:password" form
host = os.environ["ELASTIC_HOST"]

user_pass = os.environ["ELASTIC_AUTH"].split(":")

es = Elasticsearch(host, basic_auth=user_pass, verify_certs=verify_certs)
|
|
|
|
|
def searchCloseSentence(document, startNumber, numCloseSentence=3):
    """Fetch the sentences surrounding a hit to build its display context.

    Parameters
    ----------
    document : document id the hit belongs to (matched against the
        ``document`` field of the ``sentences`` index).
    startNumber : sentence number of the hit inside the document.
    numCloseSentence : how many sentences of context to fetch on each
        side of the hit (default 3).

    Returns
    -------
    tuple
        ``(document_name, previous_context, subsequent_context)`` where
        ``document_name`` is ``"<title> - <author>"`` and the two context
        strings are the neighbouring sentences concatenated in document
        order.
    """

    def _context_query(number_range):
        # Sentences of the same document whose number falls in the range.
        return {
            "bool": {
                "must": [
                    {"term": {"document": document}},
                    {"range": {"number": number_range}},
                ]
            }
        }

    queryPrevious = _context_query({
        "gte": startNumber - numCloseSentence,
        "lt": startNumber,
    })
    # BUG FIX: the upper bound was hard-coded as startNumber + 3, silently
    # ignoring the numCloseSentence parameter; use the parameter instead.
    queryNext = _context_query({
        "lte": startNumber + numCloseSentence,
        "gt": startNumber,
    })

    previous = es.search(
        index="sentences",
        query=queryPrevious
    )
    nexts = es.search(
        index="sentences",
        query=queryNext
    )

    # Elasticsearch returns hits by score, not position, so sort by sentence
    # number before joining into a single context string.
    previous_hits = sorted(previous["hits"]["hits"], key=lambda e: e["_source"]["number"])
    previous_context = "".join([r["_source"]["sentence"] for r in previous_hits])

    subsequent_hits = sorted(nexts["hits"]["hits"], key=lambda e: e["_source"]["number"])
    subsequent_context = "".join([r["_source"]["sentence"] for r in subsequent_hits])

    # Look up the document metadata (title/author) to label the result.
    document_name_results = es.search(
        index="documents",
        query=_context_query.__wrapped__ if False else {
            "bool": {
                "must": [{"term": {"id": document}}]
            }
        }
    )

    document_name_data = document_name_results["hits"]["hits"][0]["_source"]
    document_name = f"{document_name_data['title']} - {document_name_data['author']}"

    return document_name, previous_context, subsequent_context
|
|
|
def prepareResults(results):
    """Render each Elasticsearch hit, wrapped in its surrounding context,
    into the module-level ``results_placeholder`` container."""
    for hit in results['hits']['hits']:
        source = hit['_source']
        # Pull three sentences of context on each side of the match.
        document_name, previous_context, subsequent_context = searchCloseSentence(
            source['document'], source['number'], 3
        )
        # Bold the matched sentence between its neighbouring sentences.
        rendered = (
            f"#### {document_name} (score: {hit['_score']:.2f})\n"
            f"{previous_context} **{source['sentence']}** {subsequent_context}"
        )
        results_placeholder.markdown(rendered)
|
|
|
def search():
    """Run a semantic search for the current UI query with the selected model.

    Reads module-level widget state (``query``, ``model_name``, ``limit``),
    embeds the query text with the chosen model(s), scores every sentence in
    the ``sentences`` index by cosine similarity via a script_score query,
    and renders the results through prepareResults().
    """
    if query == "":
        # Empty query: nothing to search for.
        return
    results_placeholder.markdown(f"Searching with {model_name} query={query}")
    # Plain strings here — the originals were f-strings with no placeholders.
    status_indicator.write("Computing query embeddings...")

    query_vector = None
    embeddingType = None
    if model_name in ["LaBSE", "LatinBERT"]:
        # Single-model search: one embedding, matched against its index field.
        query_vector = st.session_state['models'][model_name](query)[0, :].numpy().tolist()
        embeddingType = "labse_embedding" if model_name == "LaBSE" else "latinBERT_embedding"
    elif model_name in ["LaBSE-LatinBERT-Mean", "LaBSE-LatinBERT-CONCAT"]:
        # Combined search: embed with both models, then merge the vectors.
        query_vector_labse = st.session_state['models']['LaBSE'](query)[0, :].numpy().tolist()
        query_vector_latinBERT = st.session_state['models']['LatinBERT'](query)[0, :].numpy().tolist()

        if model_name == "LaBSE-LatinBERT-Mean":
            query_vector = np.mean([query_vector_labse, query_vector_latinBERT], axis=0).tolist()
            embeddingType = "mean_embedding"
        elif model_name == "LaBSE-LatinBERT-CONCAT":
            # Concatenation order (LatinBERT first) must match how the
            # documents were indexed — presumably so; TODO confirm.
            query_vector = query_vector_latinBERT + query_vector_labse
            embeddingType = "concat_embedding"

    # "+ 1.0" keeps scores non-negative (Elasticsearch rejects negative scores).
    script = {
        "source": f"cosineSimilarity(params.query_vector, '{embeddingType}') + 1.0",
        "params": {"query_vector": query_vector}
    }
    status_indicator.write("Preparing the script for search...")
    results = es.search(
        index='sentences',
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": script
            }
        },
        size=limit
    )
    status_indicator.write("Prettifying the results ...")
    prepareResults(results)
|
|
|
# --- Streamlit UI (executed top-to-bottom on every rerun) ---
st.header("Serica Intelligent Search")

st.write("This is a fork of this repo(https://huggingface.co/spaces/galatolo/serica-intelligent-search)")

st.write("Perform an intelligent search using a Sentence Embedding Transformer model on the SERICA database")

# The widget values below are read as module-level globals by search().
model_name = st.selectbox("Model", ["LaBSE", "LatinBERT", "LaBSE-LatinBERT-Mean", "LaBSE-LatinBERT-CONCAT"])

limit = st.number_input("Number of results (sentences) ", 25)  # default: 25 hits

query = st.text_input("Query", value="")

status_indicator = st.empty()  # placeholder for progress messages



do_search = st.button("Search")

results_placeholder = st.container()  # search results are rendered here



if do_search:

    search()
|
|
|
|
|
|