matijap's picture
Result index changed to represent order for consistency.
999e0ba
raw
history blame contribute delete
No virus
1.54 kB
import faiss
import joblib
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
@st.experimental_singleton
def load_model():
    """Load the sentence-embedding model once and share it across reruns.

    ``st.experimental_singleton`` is used rather than ``st.experimental_memo``
    because the model is a large, non-data resource: memo serialises the
    cached return value and gives each caller a deserialised copy, while
    singleton hands back the very same object on every script rerun.

    Returns:
        The ``SentenceTransformer`` created from the
        ``TamedWicked/MathBERT_hr`` checkpoint.
    """
    return SentenceTransformer("TamedWicked/MathBERT_hr")
@st.experimental_memo
def load_knowledge_base_df():
    """Read the knowledge-base table from its parquet file.

    The result is memoised by Streamlit, so the file is only parsed on the
    first call.
    """
    knowledge_base = pd.read_parquet("data/knowledge_base.parquet")
    return knowledge_base
@st.experimental_singleton
def load_knowledge_base_index():
    """Build the FAISS L2 index over the precomputed knowledge-base embeddings.

    Cached with ``st.experimental_singleton`` because the index is a native
    resource object, not plain data: it should be built once and shared
    instead of being serialised and copied on every script rerun.

    Returns:
        A ``faiss.IndexFlatL2`` holding every row of the embeddings matrix,
        in file order.
    """
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only ship embeddings artifacts that come from a trusted source.
    embeddings = joblib.load("data/knowledge_base_embeddings.pkl")
    # FAISS only accepts C-contiguous float32 matrices. This is a no-op (no
    # copy) when the stored array already has that layout and dtype.
    embeddings = np.ascontiguousarray(embeddings, dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
def vector_search(query: "list[str] | str", model: SentenceTransformer, index, num_results=10):
    """Embed the query text(s) and look up their nearest neighbours in *index*.

    Args:
        query: One query string or an iterable of query strings.
        model: Encoder whose ``encode`` method turns the texts into vectors.
        index: Object exposing ``search(vectors, k=...)`` (a FAISS index in
            this app).
        num_results: Number of neighbours requested per query.

    Returns:
        The ``(D, I)`` pair returned by ``index.search``: distances and row
        positions, one row per query. FAISS pads ``I`` with ``-1`` when fewer
        than ``num_results`` neighbours exist, so callers should drop
        negative positions before indexing with them.
    """
    # list("abc") would explode a bare string into single characters and run
    # one search per character, so wrap a lone string instead.
    texts = [query] if isinstance(query, str) else list(query)
    vectors = model.encode(texts, show_progress_bar=False, convert_to_numpy=True)
    # FAISS requires C-contiguous float32 input; this converts only if needed
    # rather than always copying twice.
    vectors = np.ascontiguousarray(vectors, dtype="float32")
    distances, positions = index.search(vectors, k=num_results)
    return distances, positions
def show_df_as_html(df: pd.DataFrame):
    """Render *df* as an HTML table string using pandas' default settings."""
    html_table = df.to_html()
    return html_table
def show_df_as_markdown(df: pd.DataFrame):
    """Render *df* as a Markdown table string via ``DataFrame.to_markdown``."""
    markdown_table = df.to_markdown()
    return markdown_table
# Shared resources: each loader is cached by Streamlit, so reruns triggered
# by a new query reuse what was loaded the first time.
model: SentenceTransformer = load_model()
df: pd.DataFrame = load_knowledge_base_df()
# The loader returns a FAISS index, not a NumPy array.
knowledge_index: faiss.IndexFlatL2 = load_knowledge_base_index()

query = st.text_input(
    "Your math query:",
    value="Jesu li strukture koje su elementarno ekvivalentne izomorfne?",
)
if query:
    D, I = vector_search([query], model, knowledge_index, num_results=5)
    # FAISS pads the result with -1 when it finds fewer neighbours than
    # requested; iloc[-1] would silently show the last knowledge-base row as
    # if it were a hit, so drop those placeholders first.
    hit_positions = I[0][I[0] >= 0]
    result = df[["Speech", "start_link"]].iloc[hit_positions]
    # Re-label the rows 1..n so the displayed index is the rank of each hit
    # rather than its position in the knowledge base.
    result.index = range(1, len(result) + 1)
    st.write(show_df_as_markdown(result))