File size: 2,965 Bytes
d7a6200
 
 
 
 
 
 
9316bdf
d7a6200
0677d87
 
c6e39b5
 
 
 
 
9316bdf
 
2b2a52e
 
0677d87
d7a6200
 
 
 
 
 
 
 
 
5a2a0f4
d7a6200
 
0677d87
d7a6200
 
 
 
 
b1cf2dc
d7a6200
 
5a2a0f4
5340b7e
5a2a0f4
 
 
 
 
 
 
 
 
 
f514dcf
5a2a0f4
 
d7a6200
f90f6f8
a4ace7e
b1cf2dc
d7a6200
 
b1cf2dc
3299f3b
 
 
b1cf2dc
d7a6200
 
b1cf2dc
 
 
a4ace7e
 
fdbf9c1
d7a6200
 
 
 
e11f528
d7a6200
 
f589f6b
 
0677d87
f589f6b
0677d87
b1cf2dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import json
import os
import random
import re

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers_interpret import SequenceClassificationExplainer

def visualize(text):
    """Render a word-attribution visualization for *text* in the Streamlit page.

    Loads the OGBV gender-bias BERT checkpoint, computes per-token
    attributions with transformers-interpret, and embeds the resulting
    HTML via ``streamlit.components.v1.html``.

    Args:
        text: The sentence to explain.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    cls_explainer = SequenceClassificationExplainer(model, tokenizer)
    # Bug fix: the original called cls_explainer(masked_text), but
    # ``masked_text`` is undefined in this scope — use the parameter.
    word_attributions = cls_explainer(text)
    # NOTE(review): ``visualize()`` returns an IPython ``HTML`` object while
    # ``components.html`` expects a str — passing ``.data`` may be required;
    # confirm against the transformers-interpret version in use.
    components.html(cls_explainer.visualize('visualize.html'))

@st.cache(allow_output_mutation=True)
def _get_pipeline():
    """Load the OGBV classifier pipeline once and cache it across reruns.

    Caching here (instead of on ``load_model``) avoids re-downloading and
    re-instantiating the tokenizer/model for every distinct input text.
    ``allow_output_mutation`` is needed because model objects are not
    hashable by Streamlit's cache.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    return pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)


def load_model(text):
    """Classify *text* with the cached OGBV sentiment pipeline.

    Args:
        text: The sentence to classify.

    Returns:
        The pipeline output: a list of ``{'label': ..., 'score': ...}`` dicts.
    """
    return _get_pipeline()(text)

# Strips @mentions (e.g. "@user") from tweets; compiled once at import time.
_MENTION_RE = re.compile(r'@[^\s]+')


def app():
    """Streamlit entry point: pick a text, classify it, optionally explain it."""
    st.title("OGBV-BERT")

    st.sidebar.markdown("""**Warning**: The Data contains offensive text""")
    data = st.sidebar.radio("Pick the evaluation data :", ('Twitter', 'Trac2020'))

    # Both datasets are plain CSVs; only the path and text column differ.
    if data == "Twitter":
        target_text_path = "./input/tweet_list.csv"
        text_column = "text"
    else:
        target_text_path = "trac2_hin_test.csv"
        text_column = "Text"
    target_text_df = pd.read_csv(target_text_path)
    texts = target_text_df[text_column]

    pick_random = st.sidebar.checkbox("Pick any random text")

    if pick_random:
        chosen = texts[random.randint(0, texts.shape[0] - 1)]
    else:
        chosen = st.sidebar.selectbox("Select any of the following text", texts)

    text = _MENTION_RE.sub('', chosen)
    # Twitter rows start with an "RT " prefix — drop it. (The original also
    # sliced Trac2020 rows in the selectbox path, which looks like a
    # copy-paste bug; stripping is now applied to Twitter only.)
    if data == "Twitter":
        text = text[3:]

    masked_text = st.text_area("Please type a sentence to classify", text)

    st.sidebar.markdown("""Find out more at [Github](https://github.com/mlkorra/OGBV-detection)""")

    # Bug fix: a st.button nested inside another button's branch never fires,
    # because the outer button's state resets on the rerun triggered by the
    # inner click. Read the visualization toggle up front instead.
    show_attributions = st.checkbox('Visualize attributions')

    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
            st.write(pred)
        if show_attributions:
            with st.spinner("Visualizing .....") :
                visualize(masked_text)

# Script entry point: launch the Streamlit app when run directly.
if __name__ == "__main__":
    app()