File size: 2,591 Bytes
d7a6200 9316bdf d7a6200 0677d87 c6e39b5 9316bdf 2b2a52e 0677d87 d7a6200 0677d87 d7a6200 b1cf2dc d7a6200 b1cf2dc d7a6200 ddcd1d9 d7a6200 f90f6f8 65eb4da b1cf2dc d7a6200 b1cf2dc d7a6200 b1cf2dc 65eb4da d7a6200 e11f528 d7a6200 f589f6b 0677d87 f589f6b 0677d87 b1cf2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import json
import os
import random
import pandas as pd
import streamlit as st
from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline
from transformers_interpret import SequenceClassificationExplainer
import streamlit.components.v1 as components # Import Streamlit
def visualize(text):
    """Render word attributions for ``text`` inside the Streamlit page.

    Loads the ``mlkorra/OGBV-gender-bert-hi-en`` checkpoint, runs a
    ``SequenceClassificationExplainer`` on ``text`` and embeds the resulting
    visualisation with ``components.html``. The explainer is also handed the
    path ``visualize.html`` for its own output file.

    Args:
        text: The sentence whose classification should be explained.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    cls_explainer = SequenceClassificationExplainer(model, tokenizer)

    # Explain the argument we were given. The previous code referenced
    # ``masked_text``, which only exists as a local inside ``app()`` and
    # therefore raised NameError here.
    cls_explainer(text)

    rendered = cls_explainer.visualize('visualize.html')
    # NOTE(review): visualize() appears to return an IPython ``HTML`` wrapper
    # rather than a plain string, while components.html() takes the markup as
    # ``str``. Unwrap ``.data`` when present; a plain string passes through
    # unchanged. Confirm against the installed transformers-interpret version.
    components.html(getattr(rendered, 'data', rendered))
@st.cache(allow_output_mutation=True)
def _load_pipeline():
    """Build the classification pipeline once and reuse it across reruns.

    ``allow_output_mutation=True`` asks ``st.cache`` not to hash the returned
    pipeline object on every access.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)


@st.cache
def load_model(text):
    """Classify ``text`` with the OGBV checkpoint and return the raw results.

    The result is cached per distinct ``text``. Building the tokenizer, model
    and pipeline is delegated to ``_load_pipeline`` so that a cache miss on a
    new sentence no longer re-instantiates the whole model.

    Args:
        text: The sentence to classify.

    Returns:
        Whatever the ``transformers`` pipeline returns for ``text``.
    """
    nlp = _load_pipeline()
    return nlp(text)
import re
def app():
    """Run the OGBV-BERT Streamlit page.

    Reads candidate sentences from the ``text`` column of
    ``./input/tweet_list.csv``, lets the user pick one (at random or from a
    sidebar select box), shows it in an editable text area, and offers two
    actions on the edited text: classification via ``load_model`` and an
    attribution view via ``visualize``.
    """
    st.title("OGBV-BERT")

    target_text_path = "./input/tweet_list.csv"
    texts = pd.read_csv(target_text_path)["text"]

    # With no rows there is nothing to pick from; both selection paths below
    # would otherwise fail (empty random range / ``None`` from the select box).
    if texts.empty:
        st.warning("No texts found in " + target_text_path)
        return

    if st.sidebar.checkbox("Pick any random text"):
        # Positional lookup, so the pick does not depend on the index labels.
        raw_text = texts.iloc[random.randrange(len(texts))]
    else:
        raw_text = st.sidebar.selectbox("Select any of the following text", texts)

    # Strip @mentions. Raw string so ``\s`` reaches the regex engine without
    # being an invalid string escape.
    text = re.sub(r'@[^\s]+', '', raw_text)
    # NOTE(review): drops the first three characters of the cleaned text;
    # presumably a fixed prefix in the dataset. Confirm against the CSV.
    text = text[3:]
    masked_text = st.text_area("Please type a sentence to classify", text)

    st.sidebar.markdown("""Find out more at [Github](https://github.com/mlkorra/OGBV-detection)""")

    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
        st.write(pred)

    if st.button('Visualize attributions'):
        with st.spinner("Visualizing ....."):
            visualize(masked_text)
# Script entry point: build the page when the file is executed directly.
if __name__ == "__main__":
    app()