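# Provenance (Hugging Face Hub file-viewer header, kept as comments so the file parses):
# author: w11wo | commit: "initial commit" (aa805a6) | size: 1.93 kB | scan: no virus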
import streamlit as st
from huggingface_hub import InferenceApi
import pandas as pd
from transformers import pipeline
# CSS passed to st.markdown() in main(): caps image width at 100% and
# forces table header cells to be left-aligned.
STYLE = """
<style>
img {
max-width: 100%;
}
th {
text-align: left!important
}
</style>
"""
# Placeholder that main() requires in the user's sentence before it queries
# the mask-filling model.
MASK_TOKEN = "<mask>"
def display_table(df):
    """Render the mask-filling predictions as a left-aligned static table.

    Only the ``sequence`` and ``score`` columns are shown; ``token`` and
    ``token_str`` are left out of the rendered table.

    Args:
        df: DataFrame of predictions containing at least the ``sequence``,
            ``score``, ``token`` and ``token_str`` columns. It is not
            modified by this function.
    """
    st.subheader("Top 5 Prediction.")
    # Drop on a copy rather than in place: the caller keeps using the same
    # frame after this call, so stripping its columns here would be a hidden
    # side effect.
    visible = df.drop(columns=["token", "token_str"])
    styled = visible.style.set_properties(
        subset=["sequence", "score"], **{"text-align": "left"}
    )
    st.table(styled)
def main():
    """Build the Streamlit page: mask filling followed by emotion analysis.

    Reads a sentence from a text input, requires it to contain the mask
    token, sends it to the hosted mask-filling model, shows the returned
    predictions, then runs the emotion classifier on every predicted
    sequence and shows those results in a second table.
    """
    st.markdown(STYLE, unsafe_allow_html=True)
    st.title("Indonesian RoBERTa Base")
    user_input = st.text_input("Insert a sentence to predict with a mask token: <mask>")

    mask_api = InferenceApi("flax-community/indonesian-roberta-base")
    emot_name = "StevenLimcorn/indonesian-roberta-base-emotion-classifier"
    emot_pipeline = pipeline("sentiment-analysis", model=emot_name, tokenizer=emot_name)

    # Nothing typed yet: show only the input widget.
    if not user_input:
        return

    if MASK_TOKEN not in user_input:
        st.error("Please enter a sentence with the correct mask token: <mask>")
        return

    # On success: a list of dicts with keys sequence, score, token, token_str.
    result = mask_api(inputs=user_input)

    # NOTE(review): the hosted API is expected to answer with a dict carrying
    # an "error" key when it cannot serve the request (e.g. model still
    # loading); building a DataFrame from that payload would raise, so report
    # it to the user instead.
    if isinstance(result, dict) and "error" in result:
        st.error(f"Mask prediction failed: {result['error']}")
        return

    df = pd.DataFrame(result)
    display_table(df)

    st.subheader("Emotion Analysis of the Top 5 Prediction")
    # Collect all rows first and build the frame once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole frame on every iteration.
    emot_rows = [
        {"sequence": sequence, **emot_pipeline(sequence)[0]}
        for sequence in df["sequence"].values
    ]
    emot_df = pd.DataFrame(emot_rows, columns=["sequence", "label", "score"])
    styled = emot_df.style.set_properties(
        subset=["sequence", "label", "score"], **{"text-align": "left"}
    )
    st.table(styled)
# Script entry point: called unconditionally whenever this module is executed.
main()