import json
import os
import random
import re

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)
from transformers_interpret import SequenceClassificationExplainer

# Hugging Face checkpoint used for both classification and attribution.
CHECKPOINT = 'mlkorra/OGBV-gender-bert-hi-en'

# An @-mention: '@' followed by a run of non-whitespace characters.
# Raw string so '\S' is a regex escape, not an (invalid) string escape.
_MENTION_RE = re.compile(r'@\S+')


@st.cache(allow_output_mutation=True)
def _load_tokenizer_and_model():
    """Load the tokenizer and classification model for CHECKPOINT.

    Cached by Streamlit so the weights are not re-downloaded/re-loaded on
    every script rerun. ``allow_output_mutation=True`` stops ``st.cache``
    from hashing the (large, mutable) model objects on each access.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT)
    return tokenizer, model


def visualize(text):
    """Render per-token attribution scores for *text* as embedded HTML.

    Args:
        text: The sentence to explain.
    """
    tokenizer, model = _load_tokenizer_and_model()
    cls_explainer = SequenceClassificationExplainer(model, tokenizer)
    # Calling the explainer computes the word attributions that
    # visualize() then renders; the returned list itself is not needed here.
    cls_explainer(text)
    html = cls_explainer.visualize()
    # visualize() returns an IPython HTML display object, while
    # components.html expects the raw markup string held in its `.data`.
    components.html(getattr(html, 'data', html), scrolling=True)


@st.cache
def load_model(text):
    """Classify *text* with the OGBV model.

    Results are memoised per input text by ``st.cache``; the model itself is
    loaded once via ``_load_tokenizer_and_model``.

    Args:
        text: The sentence to classify.

    Returns:
        The raw ``transformers`` pipeline output (for this task, a list of
        ``{'label': ..., 'score': ...}`` dicts).
    """
    tokenizer, model = _load_tokenizer_and_model()
    nlp = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)
    return nlp(text)


@st.cache
def _load_texts(path, column):
    """Read *column* of the CSV at *path* as a clean, 0-indexed str Series.

    Missing cells are dropped (``re.sub`` would raise on a NaN float) and the
    index is reset so positional and label access agree.
    """
    return pd.read_csv(path)[column].dropna().astype(str).reset_index(drop=True)


def _clean_text(raw, strip_prefix):
    """Remove @-mentions from *raw*; if *strip_prefix*, also drop 3 leading chars."""
    text = _MENTION_RE.sub('', raw)
    # NOTE(review): the 3-char slice presumably removes a leading "RT " left
    # on retweets once mentions are stripped — confirm against the data.
    return text[3:] if strip_prefix else text


def app():
    """Build the OGBV-BERT page: choose a text, then classify or explain it."""
    st.title("OGBV-BERT")
    st.sidebar.markdown("""**Warning**: The Data contains offensive text""")
    data = st.sidebar.radio("Pick the evaluation data :", ('Twitter', 'Trac2020'))
    is_twitter = data == "Twitter"
    if is_twitter:
        texts = _load_texts("./input/tweet_list.csv", "text")
    else:
        texts = _load_texts("trac2_hin_test.csv", "Text")

    if texts.empty:
        st.error("No texts found in the selected dataset.")
        return

    pick_random = st.sidebar.checkbox("Pick any random text")
    if pick_random:
        # NOTE(review): a new random text is drawn on every rerun (including
        # the one triggered by "Classify"), which resets the text area below;
        # persisting the choice in st.session_state would fix this if the
        # deployed Streamlit version supports it.
        raw_text = texts.iloc[random.randrange(len(texts))]
    else:
        raw_text = st.sidebar.selectbox("Select any of the following text", texts)
    # Trim the 3-char prefix only for Twitter rows, in both selection modes.
    text = _clean_text(raw_text, strip_prefix=is_twitter)
    masked_text = st.text_area("Please type a sentence to classify", text)

    st.sidebar.markdown("""Find out more at [Github](https://github.com/mlkorra/OGBV-detection)""")

    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
            st.write(pred)

    if st.button('Visualize attributions'):
        with st.spinner("Visualizing ....."):
            visualize(masked_text)


if __name__ == "__main__":
    app()