Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from elasticsearch import Elasticsearch | |
import numpy as np | |
import pandas as pd | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.manifold import TSNE | |
import plotly.express as plx | |
def compare():
    """Render a 2-D t-SNE scatter plot of sentence embeddings for the selected documents.

    Works on module-level state rather than parameters: reads ``multiselect``,
    ``model``, ``limit`` and ``documents``, queries the ``sentences`` index
    through ``es``, reports progress via ``status_indicator`` and draws the
    result (or an error message) into ``plot_placeholder``.

    Returns:
        None. All output goes to the Streamlit placeholders.
    """
    if not multiselect:
        plot_placeholder.error("Select at least one document")
        return

    # Embeddings are stored per model under "<model>_features".
    target_field = f"{model}_features"
    document_ids = [documents[title] for title in multiselect]

    status_indicator.write("Retrieving embeddings...")
    results = [
        es.search(
            index="sentences",
            query={
                "constant_score": {
                    "filter": {
                        "term": {
                            "document": document_id
                        }
                    }
                }
            },
            size=limit,
        )
        for document_id in document_ids
    ]

    status_indicator.write("Merging embeddings...")
    features = []
    classes = []
    sentences = []
    for result, title in zip(results, multiselect):
        hits = result["hits"]["hits"]
        if not hits:
            # A document without sentences would contribute a shape-(0,) array,
            # which np.concatenate cannot merge with the (n, d) arrays.
            continue
        features.append(np.asarray([hit["_source"][target_field] for hit in hits]))
        classes.extend([title] * len(hits))
        sentences.extend(hit["_source"]["sentence"] for hit in hits)

    n_samples = len(sentences)
    if n_samples < 2:
        # t-SNE needs at least two points to compute pairwise similarities.
        plot_placeholder.error("Not enough sentences found for the selected documents")
        return
    features = np.concatenate(features)

    status_indicator.write("Computing TSNE...")
    features = StandardScaler().fit_transform(features)
    # TSNE requires perplexity < n_samples; keep the library default (30)
    # whenever there are enough points and shrink it for small selections.
    perplexity = float(min(30, n_samples - 1))
    tsne = TSNE(n_components=2, metric="cosine", init="pca", perplexity=perplexity)
    features = tsne.fit_transform(features)

    # Shorten legend labels. NOTE(review): titles sharing the same first 10
    # characters end up in the same colour group - confirm this is acceptable.
    classes = [label[:10] + "..." for label in classes]
    df = pd.DataFrame.from_dict(dict(
        x=features[:, 0],
        y=features[:, 1],
        classes=classes,
        sentences=sentences,
    ))

    status_indicator.write("All done...")
    plot_placeholder.plotly_chart(plx.scatter(
        data_frame=df,
        x="x",
        y="y",
        color="classes",
        hover_name="sentences",
    ))
# Connect to Elasticsearch. ELASTIC_AUTH holds "user:password"; split only on
# the first colon so passwords that themselves contain ":" stay intact.
es = Elasticsearch(
    os.environ["ELASTIC_HOST"],
    basic_auth=tuple(os.environ["ELASTIC_AUTH"].split(":", 1)),
)

# Fetch the document catalogue. Elasticsearch returns only 10 hits unless a
# size is given, so request up to the default result-window limit instead.
results = es.search(index="documents", query={"match_all": {}}, size=10000)
results = [result["_source"] for result in results["hits"]["hits"]]
# Map the label shown in the UI ("<title> - <author>") to the document id.
documents = {f"{result['title']} - {result['author']}": result['id'] for result in results}

st.sidebar.header("Semantic compare")
st.sidebar.write("Select documents from the SERICA library to semantically compare them. Hover above the data points to see the respective sentences")
multiselect = st.sidebar.multiselect("Documents", list(documents.keys()))
model = st.sidebar.selectbox("Model", ["LaBSE"])
# The second positional argument of number_input is min_value, not the default
# value; pass both explicitly so users can also pick fewer than 1000 sentences.
# The upper bound matches Elasticsearch's default index.max_result_window.
limit = st.sidebar.number_input(
    "Sentences per document",
    min_value=1,
    max_value=10000,
    value=1000,
)

# Placeholders that compare() fills with the chart and with progress messages.
plot_placeholder = st.empty()
status_indicator = st.sidebar.empty()

if st.sidebar.button("Compare"):
    compare()