Spaces:
Build error
Build error
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from sentence_transformers.util import cos_sim | |
from sentence_transformers import SentenceTransformer | |
from bokeh.plotting import figure, output_notebook, show, save | |
from bokeh.io import output_file, show | |
from bokeh.models import ColumnDataSource, HoverTool | |
from sklearn.manifold import TSNE | |
def _resource_cache():
    """Return Streamlit's decorator for caching a shared, heavyweight resource.

    Newer Streamlit releases provide ``st.cache_resource``; older ones only
    have ``st.cache``, where ``allow_output_mutation=True`` disables the
    output-mutation check that would otherwise hash the large model object.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def load_model():
    """Load the Spanish NLI-finetuned sentence-embedding model for inference.

    Streamlit re-executes this script on every widget interaction, so the
    model is cached; otherwise it would be reloaded on every button press.

    Returns:
        The ``SentenceTransformer`` model, switched to eval mode.
    """
    model = SentenceTransformer('hackathon-pln-es/bertin-roberta-base-finetuning-esnli')
    model.eval()
    return model
def load_plot_data():
    """Read the precomputed reference sentences used as the plot background.

    Returns:
        tuple: ``(embeddings, frame)``. ``embeddings`` is the array loaded from
        ``semeval2015-embs.npy``. ``frame`` is the DataFrame loaded from
        ``semeval2015-data.csv``. The caller assigns t-SNE coordinates to the
        frame positionally, so the rows are presumably aligned with the
        embedding rows (verify when regenerating these files).
    """
    return np.load('semeval2015-embs.npy'), pd.read_csv('semeval2015-data.csv')
# Page header: title, short description of the model, and credits.
st.title("Sentence Embedding for Spanish with Bertin")
st.write("Sentence embedding for Spanish trained on NLI. Used for Semantic Textual Similarity. Based on the model hackathon-pln-es/bertin-roberta-base-finetuning-esnli.")
st.write("Enter two sentences to see their cosine similarity and a plot showing them in the embedding space.")
st.write("Authors: Anibal Pérez, Emilio Tomás Ariza, Lautaro Gesuelli y Mauricio Mazuecos.")

# The two user sentences to compare; text_area returns "" when left empty.
sent1 = st.text_area('Enter sentence 1')
sent2 = st.text_area('Enter sentence 2')
def _project_embeddings(reference_embs, reference_data, new_embs, new_rows):
    """Project reference and new embeddings together into 2-D with t-SNE.

    ``reference_data`` rows must line up with ``reference_embs`` rows, and
    ``new_rows`` (a list of dicts) must line up with ``new_embs`` rows.

    Returns:
        A new DataFrame: ``reference_data`` with ``new_rows`` appended and
        ``x``/``y`` columns holding the t-SNE coordinates.
    """
    coords = TSNE(n_components=2, learning_rate='auto',
                  init='random').fit_transform(
        np.concatenate([reference_embs, new_embs], axis=0))
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    plot_data = pd.concat([reference_data, pd.DataFrame(new_rows)],
                          ignore_index=True)
    plot_data['x'] = coords[:, 0]
    plot_data['y'] = coords[:, 1]
    return plot_data


def _embedding_figure(plot_data):
    """Build a Bokeh scatter of the projected sentences, hovering shows the text."""
    source = ColumnDataSource(plot_data)
    p = figure(title="Embeddings in space")
    # scatter(marker='circle') replaces circle(size=...), deprecated in Bokeh 3.4.
    p.scatter(
        x='x',
        y='y',
        marker='circle',
        legend_label="Objects",
        color='color',
        fill_alpha=0.5,
        line_color="blue",
        size=14,
        source=source,
    )
    p.add_tools(HoverTool(tooltips=[('sent', '@sent')], mode='mouse'))
    return p


if st.button('Compute similarity'):
    # Reject empty and whitespace-only input before running the model.
    if sent1.strip() and sent2.strip():
        model = load_model()
        encodings = model.encode([sent1, sent2])
        # cos_sim returns a 1x1 tensor for two single vectors.
        sim = cos_sim(encodings[0], encodings[1]).item()
        st.text('Cosine Similarity: {0:.4f}'.format(sim))
        with st.spinner('Generating visualization...'):
            sentembs, data = load_plot_data()
            plot_data = _project_embeddings(
                sentembs,
                data,
                encodings,
                [
                    {'sent': sent1, 'color': '#F0E442'},  # sentence 1
                    {'sent': sent2, 'color': '#D55E00'},  # sentence 2
                ],
            )
            st.bokeh_chart(_embedding_figure(plot_data), use_container_width=True)
    else:
        st.write('Please enter both sentences.')