# Hugging Face Space header residue (commit 9a78960, "little add", by GabMartino)
import os
import numpy as np
import streamlit as st
from elasticsearch import Elasticsearch
from embedders.LatinBERT import LatinBERT
from embedders.labse import LaBSE
# Cache the heavy embedding models in session state: Streamlit re-executes
# this script on every interaction, so instantiate them only once.
if 'models' not in st.session_state:
    st.session_state['models'] = {
        "LaBSE": LaBSE(),
        "LatinBERT": LatinBERT(
            bertPath="./embedders/latin_bert/latin_bert",
            tokenizerPath="./embedders/tokenizer/latin.subword.encoder",
        ),
    }

# Elasticsearch client built from environment configuration.
# ELASTIC_AUTH is expected as "user:password" — TODO confirm with deployment.
elastic_host = os.environ["ELASTIC_HOST"]
elastic_credentials = os.environ["ELASTIC_AUTH"].split(":")
es = Elasticsearch(elastic_host, basic_auth=elastic_credentials, verify_certs=True)
def _contextQuery(document, bounds):
    """Build a bool query matching sentences of *document* whose `number`
    field falls inside *bounds* (an Elasticsearch range clause body)."""
    return {
        "bool": {
            "must": [
                {"term": {"document": document}},
                {"range": {"number": bounds}},
            ]
        }
    }


def searchCloseSentence(document, startNumber, numCloseSentence=3):
    """Fetch the sentences surrounding sentence *startNumber* of *document*.

    Queries the "sentences" index for up to *numCloseSentence* sentences
    before and after the given one, and the "documents" index for the
    document's title and author.

    Returns:
        (document_name, previous_context, subsequent_context) where the
        contexts are the concatenated sentence texts in document order.
    """
    previous = es.search(
        index="sentences",
        query=_contextQuery(document, {
            "gte": startNumber - numCloseSentence,
            "lt": startNumber,
        }),
    )
    nexts = es.search(
        index="sentences",
        query=_contextQuery(document, {
            # Bug fix: upper bound was hard-coded to startNumber + 3,
            # silently ignoring the numCloseSentence parameter.
            "lte": startNumber + numCloseSentence,
            "gt": startNumber,
        }),
    )
    # Elasticsearch does not guarantee hit order here; sort by sentence number
    # so the context reads in document order.
    previous_hits = sorted(previous["hits"]["hits"], key=lambda e: e["_source"]["number"])
    previous_context = "".join(r["_source"]["sentence"] for r in previous_hits)
    subsequent_hits = sorted(nexts["hits"]["hits"], key=lambda e: e["_source"]["number"])
    subsequent_context = "".join(r["_source"]["sentence"] for r in subsequent_hits)
    document_name_results = es.search(
        index="documents",
        query={"bool": {"must": [{"term": {"id": document}}]}},
    )
    # Assumes the document id is unique, so the first hit is the document.
    document_name_data = document_name_results["hits"]["hits"][0]["_source"]
    document_name = f"{document_name_data['title']} - {document_name_data['author']}"
    return document_name, previous_context, subsequent_context
def prepareResults(results):
    """Render each search hit into the global ``results_placeholder``.

    *results* is a raw Elasticsearch response from the "sentences" index.
    Each hit is shown as a markdown card: document title, score, and the
    matched sentence bolded inside its surrounding context.
    """
    for sentence in results['hits']['hits']:
        text = sentence['_source']['sentence']
        score = sentence['_score']
        document = sentence['_source']['document']
        number = sentence['_source']['number']
        # Fetch the document title plus the sentences around the hit so the
        # match can be displayed in context.
        document_name, previous_context, subsequent_context = searchCloseSentence(document, number, 3)
        string_result = f"#### {document_name} (score: {score:.2f})\n{previous_context} **{text}** {subsequent_context}"
        results_placeholder.markdown(string_result)
def search():
    """Embed the global ``query`` with the selected model and run a
    cosine-similarity search over the "sentences" index, rendering the
    results via ``prepareResults``.

    Reads the module-level widgets ``query``, ``model_name`` and ``limit``;
    writes progress to ``status_indicator`` and output to
    ``results_placeholder``.
    """
    if query == "":
        return
    results_placeholder.markdown(f"Searching with {model_name} query={query}")
    status_indicator.write("Computing query embeddings...")
    query_vector = None
    embeddingType = None
    if model_name in ["LaBSE", "LatinBERT"]:
        # Single model: first row of the returned tensor is the sentence embedding.
        query_vector = st.session_state['models'][model_name](query)[0, :].numpy().tolist()
        embeddingType = "labse_embedding" if model_name == "LaBSE" else "latinBERT_embedding"
    elif model_name in ["LaBSE-LatinBERT-Mean", "LaBSE-LatinBERT-CONCAT"]:
        # Combined models need both embeddings.
        query_vector_labse = st.session_state['models']['LaBSE'](query)[0, :].numpy().tolist()
        query_vector_latinBERT = st.session_state['models']['LatinBERT'](query)[0, :].numpy().tolist()
        if model_name == "LaBSE-LatinBERT-Mean":
            query_vector = np.mean([query_vector_labse, query_vector_latinBERT], axis=0).tolist()
            embeddingType = "mean_embedding"
        elif model_name == "LaBSE-LatinBERT-CONCAT":
            # Concatenation order is LatinBERT then LaBSE — presumably matching
            # how "concat_embedding" was indexed; verify against the indexer.
            query_vector = query_vector_latinBERT + query_vector_labse
            embeddingType = "concat_embedding"
    # script_score re-scores every sentence by cosine similarity; the +1.0
    # keeps scores non-negative as Elasticsearch requires.
    script = {
        "source": f"cosineSimilarity(params.query_vector, '{embeddingType}') + 1.0",
        "params": {"query_vector": query_vector}
    }
    status_indicator.write("Preparing the script for search...")
    results = es.search(
        index='sentences',
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": script
            }
        },
        size=limit
    )
    status_indicator.write("Prettifying the results ...")
    prepareResults(results)
# --- Streamlit UI ---
st.header("Serica Intelligent Search")
st.write("This is a fork of this repo(https://huggingface.co/spaces/galatolo/serica-intelligent-search)")
st.write("Perform an intelligent search using a Sentence Embedding Transformer model on the SERICA database")
# These widget values are read as module-level globals by search()/prepareResults().
model_name = st.selectbox("Model", ["LaBSE", "LatinBERT", "LaBSE-LatinBERT-Mean", "LaBSE-LatinBERT-CONCAT"])
limit = st.number_input("Number of results (sentences) ", 25)
query = st.text_input("Query", value="")
status_indicator = st.empty()
do_search = st.button("Search")
results_placeholder = st.container()
if do_search:
    search()