# Commit d25a081 ("Second version") by espejelomar
import streamlit as st
import pandas as pd
from backend import inference
from backend.config import MODELS_ID, QA_MODELS_ID, SEARCH_MODELS_ID
st.title('Demo using Flax-Sentence-Transformers')
st.sidebar.title('Tasks')
menu = st.sidebar.radio("", options=['Identifying misleading vaccine texts'], index=0)
st.markdown('''
Hi! This is the demo for the [flax sentence embeddings](https://huggingface.co/flax-sentence-embeddings) created for the **Flax/JAX community week 🤗**.
We trained three general-purpose flax-sentence-embeddings models: a **distilroberta base**, an **mpnet base** and a **minilm-l6**.
All three were trained on the full 1B+ training corpus with the v3 setup.
In addition, we trained 20 models focused on general-purpose embeddings, question answering and code search.
View our models here: https://huggingface.co/flax-sentence-embeddings
''')
if menu == "Identifying misleading vaccine texts":
    st.header('Identifying misleading vaccine texts')
    st.markdown('''
**Instructions**: Compare a given text against keywords that identify 'misleading' texts about vaccination. In the background, we create an embedding for each text and then use cosine similarity to compute a similarity score between your text and the keywords.
We use the keywords identified by **Goran Muric, Yusong Wu and Emilio Ferrara (2021), 'COVID-19 Vaccine Hesitancy on Social Media: Building a Public Twitter Dataset of Anti-vaccine Content, Vaccine Misinformation and Conspiracies'**.
For more information on sentence embeddings, see the [SBERT project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
''')
    select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
    anchor = st.text_input(
        'Please enter the text/tweet you want to evaluate:'
    )
    if st.button('Tell me the similarity.'):
        # Score the input text with every selected model; results are keyed by model name.
        results = {model: inference.tweets_vaccine(anchor, model, MODELS_ID) for model in select_models}
        # Collect each model's score into its own column of a single-row DataFrame.
        df_total = pd.DataFrame(index=[0])
        for model, result in results.items():
            df_total[model] = list(result['score'].values)
        st.write('Here are the results for the selected models:')
        st.write(df_total)
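
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of `backend`).
# The in-app instructions say each text is embedded and the input text is
# compared with the keywords via cosine similarity. This helper sketches only
# that scoring step, assuming the embeddings were already produced by a
# sentence-embedding model and are passed in as plain vectors.
# `cosine_similarity_scores` is a hypothetical name; `inference.tweets_vaccine`
# may compute or aggregate its scores differently. The app does not call it.
# ---------------------------------------------------------------------------
import numpy as np


def cosine_similarity_scores(anchor_emb, keyword_embs, eps=1e-12):
    """Return the cosine similarity between one anchor vector and each keyword vector.

    anchor_emb:   array-like of shape (d,), embedding of the text/tweet.
    keyword_embs: array-like of shape (k, d), one embedding per keyword.
    Returns a NumPy array of shape (k,) with values in [-1, 1] (up to rounding).
    """
    anchor = np.asarray(anchor_emb, dtype=float)
    keywords = np.asarray(keyword_embs, dtype=float)
    # Scale every vector to unit length (eps avoids division by zero for an
    # all-zero vector), so each dot product below is the cosine of the angle.
    anchor = anchor / max(np.linalg.norm(anchor), eps)
    keywords = keywords / np.maximum(np.linalg.norm(keywords, axis=1, keepdims=True), eps)
    # (k, d) @ (d,) -> (k,): one similarity score per keyword.
    return keywords @ anchor


# Example (kept as a comment so it does not run on every Streamlit rerun):
#   cosine_similarity_scores([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
#   -> array([1., 0.])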