context-game / app.py
Allob's picture
Update app.py
b5f7332
raw
history blame contribute delete
No virus
1.83 kB
import streamlit as st
import plotly.express as px
import pandas as pd
import random
import logging
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
@st.cache_resource
def load_model(name):
    """Construct the sentence-embedding model identified by ``name``.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    hand back the already-built model on later script reruns instead of
    constructing it again.

    Args:
        name: Model identifier passed straight to ``SentenceTransformer``.

    Returns:
        The ``SentenceTransformer`` instance for ``name``.
    """
    model = SentenceTransformer(name)
    return model
@st.cache_data
def load_words_dataset():
    """Return the ``"Word"`` column of the WordNet definitions dataset.

    Loads the ``train`` split of ``marksverdhei/wordnet-definitions-en-2021``
    via ``datasets.load_dataset``. The result is cached with
    ``st.cache_data`` so the dataset is not reloaded on every rerun.
    """
    return load_dataset(
        "marksverdhei/wordnet-definitions-en-2021",
        split="train",
    )["Word"]
@st.cache_data
def choose_secret_word():
    """Pick one entry at random from the word list.

    Because the function takes no arguments and is wrapped in
    ``st.cache_data``, the randomly chosen word is reused on later calls
    for as long as the cache entry exists, rather than being re-drawn on
    every script rerun.
    """
    candidates = load_words_dataset()
    return random.choice(candidates)
# Word list used as the pool of possible secret words.
all_words = load_words_dataset()

# Embedding models the guesses are scored with; each one gets its own
# similarity column in the results table.
model_names = [
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "BAAI/bge-small-en-v1.5",
]

# Model name -> loaded model, in the same order as ``model_names``.
models = {model_name: load_model(model_name) for model_name in model_names}

# Normalise the secret word (lower-case, no surrounding whitespace).
secret_word = choose_secret_word().lower().strip()

# One embedding of the secret word per model, index-aligned with
# ``model_names`` (dicts preserve insertion order, so iterating the values
# visits the models in that same order).
secret_embedding = [model.encode(secret_word) for model in models.values()]

# Debug aid: writes the answer to the server's stdout, not to the page.
print("Secret word ", secret_word)
# Guess history is kept in session state so it survives Streamlit reruns.
# Each entry is [word, similarity for model 0, similarity for model 1, ...],
# with the similarities in the same order as ``model_names``.
if 'words' not in st.session_state:
    st.session_state['words'] = []

st.write('Try to guess a secret word by semantic similarity')
word = st.text_input("Input a word")

# Normalise once and use the same form for the duplicate check, the
# embedding and the stored row, so "Cat", "cat " and "cat" count as a
# single guess instead of three rows with identical scores.
guess = word.lower().strip()
used_words = {entry[0] for entry in st.session_state['words']}

# The button only needs to be rendered: clicking it triggers a rerun, and
# the guess itself is read from the text box above.
st.button("Guess")

# Skip empty/whitespace-only input (previously a click on "Guess" with an
# empty box added a blank row) and words that were already tried.
if guess and guess not in used_words:
    word_embedding = [models[name].encode(guess) for name in model_names]
    similarities = [
        util.pytorch_cos_sim(secret_vec, guess_vec).cpu().numpy()[0][0]
        for secret_vec, guess_vec in zip(secret_embedding, word_embedding)
    ]
    st.session_state['words'].append([guess] + similarities)

similarity_columns = ["Similarity for " + name for name in model_names]

# Rank guesses by the first configured model. The column name is derived
# from ``model_names`` rather than hard-coded, so editing the model list
# cannot leave a stale sort key behind.
words_df = pd.DataFrame(
    st.session_state['words'],
    columns=["word"] + similarity_columns,
).sort_values(by=similarity_columns[0], ascending=False)

st.dataframe(words_df, use_container_width=True)