import os
import streamlit as st
from elasticsearch import Elasticsearch

from embedders.labse import LaBSE
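
# Streamlit front-end for SERICA intelligent search: the query is embedded with a
# sentence-transformer model and compared against precomputed sentence embeddings
# stored in Elasticsearch.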

def search():
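    """Embed the query, rank indexed sentences by cosine similarity, and render
    each hit together with its surrounding context and document metadata."""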
    status_indicator.write(f"Loading model {model_name} (it can take ~1 minute the first time)...")
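    # Resolve the embedder class (e.g. LaBSE) by name and instantiate it.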
    model = globals()[model_name]()
    
    status_indicator.write("Computing query embeddings...")
    # Take the first row of the returned batch of embeddings.
    query_vector = model(query)[0, :].tolist()
    
    status_indicator.write("Performing query...")
    # Embeddings are stored in a per-model field, e.g. "LaBSE_features".
    target_field = f"{model_name}_features"
    # Rank every sentence by cosine similarity to the query vector; the +1.0 offset
    # keeps the script score non-negative, as Elasticsearch requires.
    results = es.search(
        index="sentences",
        query={
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": f"cosineSimilarity(params.query_vector, '{target_field}') + 1.0",
                    "params": {"query_vector": query_vector}
                }
            }
        },
        size=limit
    )

    # Render each hit with up to three sentences of preceding and following context.
    for result in results["hits"]["hits"]:
        sentence = result['_source']['sentence']
        score = result['_score']
        document = result['_source']['document']
        number = result['_source']['number']

        # Fetch the up-to-three sentences that immediately precede the hit in the same document.
        previous = es.search(
            index="sentences",
            query={
                "bool": {
                    "must": [
                        {"term": {"document": document}},
                        {"range": {"number": {"gte": number - 3, "lt": number}}},
                    ]
                }
            }
        )

        previous_hits = sorted(previous["hits"]["hits"], key=lambda e: e["_source"]["number"])
        previous_context = " ".join(r["_source"]["sentence"] for r in previous_hits)


        # Fetch the up-to-three sentences that immediately follow the hit in the same document.
        subsequent = es.search(
            index="sentences",
            query={
                "bool": {
                    "must": [
                        {"term": {"document": document}},
                        {"range": {"number": {"gt": number, "lte": number + 3}}},
                    ]
                }
            }
        )

        subsequent_hits = sorted(subsequent["hits"]["hits"], key=lambda e: e["_source"]["number"])
        subsequent_context = " ".join(r["_source"]["sentence"] for r in subsequent_hits)


        # Look up the title and author of the document the sentence belongs to.
        document_name_results = es.search(
            index="documents",
            query={
                "bool": {
                    "must": [
                        {"term": {"id": document}}
                    ]
                }
            }
        )

        document_name_data = document_name_results["hits"]["hits"][0]["_source"]
        document_name = f"{document_name_data['title']} - {document_name_data['author']}"

        results_placeholder.markdown(f"#### {document_name} (score: {score:.2f})\n{previous_context} **{sentence}** {subsequent_context}")


    status_indicator.write("Results ready...")

# Connect to Elasticsearch; ELASTIC_AUTH is expected in "username:password" form.
es = Elasticsearch(os.environ["ELASTIC_HOST"], basic_auth=os.environ["ELASTIC_AUTH"].split(":"))

st.header("Serica Intelligent Search")
st.write("Perform an intelligent search on the SERICA database using a sentence embedding Transformer model")
model_name = st.selectbox("Model", ["LaBSE"])
limit = st.number_input("Number of results", min_value=1, value=10)
query = st.text_input("Query", value="")
status_indicator = st.empty()
do_search = st.button("Search")
results_placeholder = st.container()

if do_search:
    search()