|
import os |
|
import streamlit as st |
|
from elasticsearch import Elasticsearch |
|
|
|
from embedders.labse import LaBSE |
|
|
|
def _fetch_context(document, number, window=3):
    """Return the text surrounding sentence ``number`` of ``document``.

    Uses a single query for sentences numbered ``number - window`` through
    ``number + window`` of the same document (instead of separate
    "previous" and "subsequent" queries), then splits the hits around
    ``number`` client-side.

    Args:
        document: Document identifier stored in the ``document`` field.
        number: Position of the matched sentence (``number`` field).
        window: How many sentences to show on each side of the match.

    Returns:
        ``(previous_context, subsequent_context)``: the neighbouring sentences
        before and after the match, each in ascending ``number`` order.
    """
    response = es.search(
        index="sentences",
        query={
            "bool": {
                # Filter context: these clauses only select sentences; the
                # relevance score of the context hits is never used.
                "filter": [
                    {"term": {"document": document}},
                    {"range": {"number": {"gte": number - window, "lte": number + window}}},
                ]
            }
        },
        # Up to `window` sentences on each side plus the match itself.
        size=2 * window + 1,
    )

    hits = sorted(response["hits"]["hits"], key=lambda hit: hit["_source"]["number"])
    previous = [hit["_source"]["sentence"] for hit in hits if hit["_source"]["number"] < number]
    subsequent = [hit["_source"]["sentence"] for hit in hits if hit["_source"]["number"] > number]

    # Join with a space so consecutive sentences do not run together; Markdown
    # collapses the doubled whitespace if sentences already carry their own.
    return " ".join(previous), " ".join(subsequent)


def _fetch_document_name(document, cache):
    """Return a ``"title - author"`` label for ``document``, memoised in ``cache``.

    Args:
        document: Value matched against the ``id`` field of the ``documents`` index.
        cache: Dict reused across calls within one search, so hits from the
            same document trigger only one lookup.

    Returns:
        The label string; ``"Document <id>"`` when no metadata record exists.
    """
    if document not in cache:
        response = es.search(
            index="documents",
            query={"bool": {"filter": [{"term": {"id": document}}]}},
            size=1,
        )
        hits = response["hits"]["hits"]
        if hits:
            data = hits[0]["_source"]
            cache[document] = f"{data['title']} - {data['author']}"
        else:
            # Missing metadata should not abort rendering of the whole result list.
            cache[document] = f"Document {document}"
    return cache[document]


def search():
    """Run a semantic search for the current query and render the results.

    Reads the Streamlit widget values ``model_name``, ``query`` and ``limit``
    and the Elasticsearch client ``es`` from module globals. Progress messages
    are written to ``status_indicator``. For each hit, one Markdown entry
    (document label, score, and the matched sentence in bold between its
    neighbouring sentences) is appended to ``results_placeholder``.
    """
    if not query.strip():
        status_indicator.write("Please enter a query.")
        return

    status_indicator.write(f"Loading model {model_name} (it can take ~1 minute the first time)...")
    # The selectbox value is the name of a model class imported at module
    # level (e.g. LaBSE); look it up by name and instantiate it.
    model = globals()[model_name]()

    status_indicator.write("Computing query embeddings...")
    # NOTE(review): assumes the model returns a 2-D array-like with one row per
    # input text; row 0 is the embedding of the query — confirm in embedders.
    query_vector = model(query)[0, :].tolist()

    status_indicator.write("Performing query...")
    target_field = f"{model_name}_features"
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # script_score results must be non-negative, so shift the
                    # cosine similarity from [-1, 1] into [0, 2].
                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
                    "params": {"query_vector": query_vector},
                },
            }
        },
        size=int(limit),
    )

    document_names = {}
    for result in results["hits"]["hits"]:
        source = result["_source"]
        sentence = source["sentence"]
        document = source["document"]
        number = source["number"]
        score = result["_score"]

        previous_context, subsequent_context = _fetch_context(document, number)
        document_name = _fetch_document_name(document, document_names)

        results_placeholder.markdown(
            f"#### {document_name} (score: {score:.2f})\n"
            f"{previous_context} **{sentence}** {subsequent_context}"
        )

    status_indicator.write("Results ready...")
|
|
|
# Elasticsearch client. ELASTIC_AUTH is read as "user:password"; split on the
# first colon only so a password containing ":" is not broken into pieces.
es = Elasticsearch(
    os.environ["ELASTIC_HOST"],
    basic_auth=tuple(os.environ["ELASTIC_AUTH"].split(":", 1)),
)

st.header("Serica Intelligent Search")
st.write("Perform an intelligent search using a Sentence Embedding Transformer model on the SERICA database")

# Each value must name a model class imported at module level; search() looks
# it up by name and reads the "<name>_features" vector field.
model_name = st.selectbox("Model", ["LaBSE"])

# The second positional parameter of st.number_input is min_value, not the
# default value. Passing 10 positionally made 10 the minimum, so users could
# never ask for fewer results. Set the default explicitly and allow >= 1.
limit = st.number_input("Number of results", min_value=1, value=10, step=1)

query = st.text_input("Query", value="")

# Placeholders that search() writes into: one status line, one results area.
status_indicator = st.empty()
do_search = st.button("Search")
results_placeholder = st.container()

if do_search:
    search()