# OGBV-Bert / app.py
# mlkorra's picture
# Update app.py
# 5340b7e
import json
import os
import random
import pandas as pd
import streamlit as st
from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline
from transformers_interpret import SequenceClassificationExplainer
import streamlit.components.v1 as components # Import Streamlit
def visualize(text):
    """Render word-attribution scores for ``text`` inside the Streamlit page.

    Loads the ``mlkorra/OGBV-gender-bert-hi-en`` tokenizer and model, runs a
    ``SequenceClassificationExplainer`` over ``text`` and embeds the
    explainer's visualization in the page through ``components.html``.

    Args:
        text: The sentence whose classification should be explained.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    cls_explainer = SequenceClassificationExplainer(model, tokenizer)

    # Explain the function's own argument. The previous code read
    # ``masked_text``, a name that only exists inside ``app()``, so every
    # call failed with NameError. The explainer is called before
    # ``visualize()`` so that it has attributions to draw.
    cls_explainer(text)

    # 'visualize.html' is handed to the explainer as in the original code;
    # presumably the library also saves the markup there -- TODO confirm.
    rendered = cls_explainer.visualize('visualize.html')

    # NOTE(review): visualize() appears to return an IPython ``HTML`` object
    # that keeps its markup in ``.data``, while components.html() expects a
    # plain string. Fall back to the value itself when there is no ``.data``.
    components.html(getattr(rendered, 'data', rendered))
# Cache decorator for the model pipeline. ``st.cache_resource`` is used when
# the installed Streamlit provides it; otherwise fall back to the legacy
# ``st.cache`` with output hashing disabled so the model object is not hashed.
_cache_pipeline = getattr(st, 'cache_resource', None) or st.cache(allow_output_mutation=True)


@_cache_pipeline
def _load_pipeline():
    """Build the classification pipeline once and reuse it across reruns."""
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)


def load_model(text):
    """Classify ``text`` with the ``mlkorra/OGBV-gender-bert-hi-en`` checkpoint.

    The cache used to wrap this function and was therefore keyed on ``text``:
    every previously unseen sentence reloaded the tokenizer and the model.
    The cache now holds the pipeline itself, so only inference runs per call.

    Args:
        text: The sentence to classify.

    Returns:
        Whatever the ``transformers`` pipeline returns for ``text``.
    """
    nlp = _load_pipeline()
    return nlp(text)
#MASK_TOKEN = tokenizer.mask_token
#masked_text = masked_text.replace("<mask>", MASK_TOKEN)
#result_sentence = nlp(masked_text)
#return result_sentence[0]["sequence"], result_sentence[0]["token_str"]
import re
def _clean_text(raw_text, strip_prefix):
    """Remove ``@mention`` tokens and optionally drop the first 3 characters."""
    # Raw string: the original non-raw literal contained the invalid escape
    # sequence ``\s``. The pattern itself is unchanged.
    cleaned = re.sub(r'@[^\s]+', '', raw_text)
    return cleaned[3:] if strip_prefix else cleaned


def app():
    """Build the OGBV-BERT Streamlit page.

    The sidebar selects the evaluation data set ('Twitter' or 'Trac2020') and
    whether the example sentence is picked at random or chosen from a select
    box. The cleaned sentence pre-fills an editable text area; the "Classify"
    button runs ``load_model`` on the text area's content and the
    "Visualize attributions" button runs ``visualize`` on it.
    """
    st.title("OGBV-BERT")
    st.sidebar.markdown("""**Warning**: The Data contains offensive text""")

    data = st.sidebar.radio("Pick the evaluation data :", ('Twitter', 'Trac2020'))
    is_twitter = data == "Twitter"
    if is_twitter:
        texts = pd.read_csv("./input/tweet_list.csv")["text"]
    else:
        texts = pd.read_csv("trac2_hin_test.csv")["Text"]

    pick_random = st.sidebar.checkbox("Pick any random text")
    if pick_random:
        # .iloc: the random number is a position, not an index label.
        raw_text = texts.iloc[random.randint(0, texts.shape[0] - 1)]
    else:
        raw_text = st.sidebar.selectbox("Select any of the following text", texts)

    # NOTE(review): the random-pick branch only dropped the 3-character prefix
    # for Twitter rows, while the select-box branch dropped it for Trac2020
    # rows as well. Both branches now follow the Twitter-only rule -- confirm
    # against the data files.
    text = _clean_text(raw_text, strip_prefix=is_twitter)
    masked_text = st.text_area("Please type a sentence to classify", text)

    st.sidebar.markdown("""Find out more at [Github](https://github.com/mlkorra/OGBV-detection)""")

    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
            st.write(pred)

    if st.button('Visualize attributions'):
        with st.spinner("Visualizing ....."):
            visualize(masked_text)
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    app()