File size: 2,188 Bytes
d7a6200
 
 
 
 
 
 
 
0677d87
 
 
 
 
 
 
 
d7a6200
 
 
 
 
 
 
 
 
 
 
 
 
 
0677d87
d7a6200
 
 
 
 
b1cf2dc
d7a6200
 
b1cf2dc
d7a6200
 
 
 
 
 
b1cf2dc
d7a6200
 
b1cf2dc
 
 
 
d7a6200
 
b1cf2dc
 
 
d7a6200
 
 
 
 
e11f528
d7a6200
 
f589f6b
 
0677d87
f589f6b
0677d87
b1cf2dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import os
import random
import pandas as pd
import streamlit as st
from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline
from transformers_interpret import SequenceClassificationExplainer

@st.cache
def visualize(text):
    """Explain the classifier's prediction for *text* and save it as HTML.

    Loads the ``mlkorra/OGBV-gender-bert-hi-en`` tokenizer and model, runs a
    ``SequenceClassificationExplainer`` over *text*, and writes the rendered
    attributions to ``visualize.html`` in the current working directory.

    Args:
        text: The sentence whose word attributions should be computed.

    Returns:
        The word attributions produced by the explainer for *text*.
    """
    # The model and tokenizer are created here because nothing at module
    # scope provides them; relying on globals raised NameError on every call.
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    cls_explainer = SequenceClassificationExplainer(model, tokenizer)
    # The explainer must be called on the input before visualize() has
    # anything to render.
    word_attributions = cls_explainer(text)
    # NOTE(review): the function is wrapped in st.cache, so a repeated call
    # with the same text may be answered from the cache without rewriting
    # the file -- confirm this is acceptable for the consumer of the HTML.
    cls_explainer.visualize('visualize.html')
    return word_attributions


@st.cache(allow_output_mutation=True)
def _load_pipeline():
    """Build the classification pipeline for the OGBV checkpoint once."""
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    # allow_output_mutation=True keeps Streamlit from hashing the (large,
    # stateful) pipeline object on every cache lookup.
    return pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)


@st.cache
def load_model(text):
    """Classify *text* with the ``mlkorra/OGBV-gender-bert-hi-en`` model.

    Despite its name, this function returns predictions rather than the
    model. The pipeline itself is built by ``_load_pipeline`` so that the
    tokenizer and model weights are not re-instantiated for every new input
    string (this function's own cache is keyed on *text*).

    Args:
        text: The sentence to classify.

    Returns:
        Whatever the ``transformers`` pipeline returns for *text*.
    """
    nlp = _load_pipeline()
    return nlp(text)

import re
def app():
    """Render the OGBV-BERT demo page.

    Reads candidate sentences from the ``text`` column of
    ``./input/tweet_list.csv``, lets the user pick one (at random or from a
    sidebar select box), shows it in an editable text area, and offers two
    actions on the edited sentence: classification via ``load_model`` and
    attribution visualization via ``visualize``.
    """
    st.title("OGBV-BERT")

    target_text_path = "./input/tweet_list.csv"
    target_text_df = pd.read_csv(target_text_path)
    # Missing values cannot be cleaned with re.sub, so drop them up front.
    texts = target_text_df["text"].dropna()
    st.sidebar.title("Place")
    pick_random = st.sidebar.checkbox("Pick any random text")

    if texts.empty:
        # Without this guard the random pick raises ValueError and the
        # select box yields None, which re.sub rejects.
        st.error("No texts found in " + target_text_path)
        return

    if pick_random:
        # Positional lookup: works whatever index labels the column carries.
        # NOTE(review): a new sentence is drawn on every script rerun,
        # including the rerun triggered by a button click -- confirm whether
        # the pick should be kept stable between interactions.
        raw_text = texts.iloc[random.randrange(len(texts))]
    else:
        raw_text = st.sidebar.selectbox("Select any of the following text", texts)

    # Strip @mentions (raw string avoids the invalid "\s" escape warning).
    text = re.sub(r'@[^\s]+', '', raw_text)
    # Drops the first three characters of what remains -- presumably a fixed
    # prefix in the dataset; TODO confirm against the CSV.
    text = text[3:]
    masked_text = st.text_area("Please type a sentence to classify", text)

    # pd.set_option('max_colwidth',30)
    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
            st.write(pred)

    # Kept as a sibling of the "Classify" button: a button nested inside
    # another button's block never fires, because clicking it reruns the
    # script with the outer button reporting False.
    if st.button('Visualize attributions'):
        with st.spinner("Visualizing ....."):
            visualize(masked_text)

                   

# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()