# Spaces status: Runtime error
import base64
import re
import textwrap

import pandas as pd
import plotly.express as px
import streamlit as st
from scipy.spatial import distance
from sentence_transformers import SentenceTransformer

# from dash import Dash, dcc, html, dash_table, Input, Output, State
# Text box for the user's question.
question = st.text_input('Ask Marcus Aurelius a question')


@st.cache_data
def _load_csv(path):
    """Read ``path`` with ``pd.read_csv``, cached by Streamlit across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the CSV would be re-read for every question asked.
    """
    return pd.read_csv(path)


@st.cache_resource
def _load_model(name):
    """Load the SentenceTransformer ``name`` once and share it across reruns."""
    return SentenceTransformer(name)


# Passage table; update_table reads its 'book', 'verse' and 'text' columns.
dat = _load_csv('meditations_processed.csv')
# Pre-computed passage embeddings, one column per dimension named "0", "1", ...
space = _load_csv('sentence_embeddings.csv')

# Model used by update_table to embed the question. Presumably the same model
# that produced sentence_embeddings.csv -- confirm, otherwise the distances
# are not comparable.
model_name = 'all-mpnet-base-v2'
model = _load_model(model_name)
def update_table(value, top_k=10, *, data=None, embeddings=None, encoder=None):
    """Return the passages whose embeddings are closest to a question.

    The question is embedded with ``encoder`` and compared by cosine
    distance against every row of ``embeddings``. Those rows are joined to
    ``data`` with ``pd.concat(axis=1)`` after ``data.reset_index()``, i.e.
    aligned on the row index.

    Args:
        value: The question text. ``None``, ``""`` or a whitespace-only
            string yields no result.
        top_k: Number of closest passages to return (default 10).
        data: Passage table with ``book``, ``verse`` and ``text`` columns.
            Defaults to the module-level ``dat``.
        embeddings: Pre-computed passage embeddings with columns named
            ``"0"``, ``"1"``, ... up to the encoder's embedding width.
            Defaults to the module-level ``space``.
        encoder: Object with an ``encode`` method (e.g. a
            ``SentenceTransformer``). Defaults to the module-level ``model``.

    Returns:
        A DataFrame with columns ``book``, ``verse``, ``text`` and
        ``cos_dist`` holding the ``top_k`` rows with the smallest cosine
        distance in ascending order, or ``None`` when ``value`` is blank.
    """
    if not value or not value.strip():
        return None

    # Resolve module-level defaults at call time so callers can inject
    # their own passages, embeddings and encoder.
    if data is None:
        data = dat
    if embeddings is None:
        embeddings = space
    if encoder is None:
        encoder = model

    # Encode a one-element batch so the result is 2-D (1, dim), the shape
    # cdist expects. No progress bar: this is a single interactive query.
    query = encoder.encode([value], show_progress_bar=False)
    # Embedding columns are named "0".."dim-1"; take the width from the
    # encoder output instead of hard-coding 768.
    embedding_cols = [str(i) for i in range(len(query[0]))]

    out = pd.concat([data.reset_index(), embeddings], axis=1)
    # One vectorised distance computation instead of a Python loop over rows.
    out['cos_dist'] = distance.cdist(
        query, out[embedding_cols].to_numpy(dtype=float), metric='cosine'
    )[0]

    # Stable sort so tied distances keep their original row order.
    out = out.sort_values('cos_dist', kind='stable').head(top_k)
    return out[['book', 'verse', 'text', 'cos_dist']]
final = update_table(question)
# update_table returns None until a non-blank question is entered; skip the
# table in that case instead of passing None to st.table.
if final is not None:
    st.table(final)