import os

import numpy as np
import streamlit as st
from elasticsearch import Elasticsearch

from embedders.LatinBERT import LatinBERT
from embedders.labse import LaBSE

# Load the embedding models once and cache them across Streamlit reruns.
if 'models' not in st.session_state:
    st.session_state['models'] = dict(
        LaBSE=LaBSE(),
        LatinBERT=LatinBERT(
            bertPath="./embedders/latin_bert/latin_bert",
            tokenizerPath="./embedders/tokenizer/latin.subword.encoder",
        ),
    )

# Elasticsearch connection, configured through the ELASTIC_HOST and
# ELASTIC_AUTH ("user:password") environment variables.
verify_certs = True
host = os.environ["ELASTIC_HOST"]
user_pass = os.environ["ELASTIC_AUTH"].split(":")
es = Elasticsearch(host, basic_auth=user_pass, verify_certs=verify_certs)


def searchCloseSentence(document, startNumber, numCloseSentence=3):
    """Return the document name plus the sentences surrounding `startNumber` in `document`."""
    queryPrevious = {
        "bool": {
            "must": [
                {"term": {"document": document}},
                {"range": {"number": {"gte": startNumber - numCloseSentence, "lt": startNumber}}},
            ]
        }
    }
    queryNext = {
        "bool": {
            "must": [
                {"term": {"document": document}},
                {"range": {"number": {"gt": startNumber, "lte": startNumber + numCloseSentence}}},
            ]
        }
    }
    previous = es.search(index="sentences", query=queryPrevious)
    nexts = es.search(index="sentences", query=queryNext)

    previous_hits = sorted(previous["hits"]["hits"], key=lambda e: e["_source"]["number"])
    previous_context = " ".join([r["_source"]["sentence"] for r in previous_hits])
    subsequent_hits = sorted(nexts["hits"]["hits"], key=lambda e: e["_source"]["number"])
    subsequent_context = " ".join([r["_source"]["sentence"] for r in subsequent_hits])

    document_name_results = es.search(
        index="documents",
        query={"bool": {"must": [{"term": {"id": document}}]}},
    )
    document_name_data = document_name_results["hits"]["hits"][0]["_source"]
    document_name = f"{document_name_data['title']} - {document_name_data['author']}"

    return document_name, previous_context, subsequent_context


def prepareResults(results):
    """Render each hit as markdown, highlighting the matching sentence inside its context."""
    for sentence in results["hits"]["hits"]:
        text = sentence["_source"]["sentence"]
        score = sentence["_score"]
        document = sentence["_source"]["document"]
        number = sentence["_source"]["number"]
        document_name, previous_context, subsequent_context = searchCloseSentence(document, number, 3)
        string_result = f"#### {document_name} (score: {score:.2f})\n{previous_context} **{text}** {subsequent_context}"
        results_placeholder.markdown(string_result)


def search():
    if query == "":
        return

    results_placeholder.markdown(f"Searching with {model_name} query={query}")
    status_indicator.write("Computing query embeddings...")

    # Embed the query with the selected model(s).
    query_vector = None
    embeddingType = None
    if model_name in ["LaBSE", "LatinBERT"]:
        query_vector = st.session_state['models'][model_name](query)[0, :].numpy().tolist()
        embeddingType = "labse_embedding" if model_name == "LaBSE" else "latinBERT_embedding"
    elif model_name in ["LaBSE-LatinBERT-Mean", "LaBSE-LatinBERT-CONCAT"]:
        query_vector_labse = st.session_state['models']['LaBSE'](query)[0, :].numpy().tolist()
        query_vector_latinBERT = st.session_state['models']['LatinBERT'](query)[0, :].numpy().tolist()
        if model_name == "LaBSE-LatinBERT-Mean":
            query_vector = np.mean([query_vector_labse, query_vector_latinBERT], axis=0).tolist()
            embeddingType = "mean_embedding"
        elif model_name == "LaBSE-LatinBERT-CONCAT":
            query_vector = query_vector_latinBERT + query_vector_labse
            embeddingType = "concat_embedding"

    # Rank sentences by cosine similarity between the query vector and the
    # selected embedding field (shifted by +1.0 so scores are non-negative).
    script = {
        "source": f"cosineSimilarity(params.query_vector, '{embeddingType}') + 1.0",
        "params": {"query_vector": query_vector},
    }
    status_indicator.write("Preparing the script for search...")
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": script,
            }
        },
        size=limit,
    )
    status_indicator.write("Prettifying the results...")
    prepareResults(results)


st.header("Serica Intelligent Search")
st.write("This is a fork of [this repo](https://huggingface.co/spaces/galatolo/serica-intelligent-search)")
st.write("Perform an intelligent search using a Sentence Embedding Transformer model on the SERICA database")
model_name = st.selectbox("Model", ["LaBSE", "LatinBERT", "LaBSE-LatinBERT-Mean", "LaBSE-LatinBERT-CONCAT"])
limit = st.number_input("Number of results (sentences)", value=25)
query = st.text_input("Query", value="")
status_indicator = st.empty()
do_search = st.button("Search")
results_placeholder = st.container()

if do_search:
    search()