File size: 3,820 Bytes
85e396d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import gradio as gr
from torch.nn import functional as F
import seaborn

import matplotlib
import platform

from transformers.file_utils import ModelOutput

# On macOS, switch matplotlib to the 'Agg' (non-GUI) backend; figures in this
# app are only returned to Gradio, never shown in a window. This must stay
# above the `import matplotlib.pyplot` line that follows so the backend is
# chosen before pyplot is loaded.
if platform.system() == "Darwin":
    print("MacOS")  # startup log line noting the macOS-specific path was taken
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
from PIL import Image

import matplotlib.font_manager as fm
import util


# Default checkpoint, loaded once at import time. MODEL_BUF is the single
# "currently active model" slot: predict() reads from it and
# change_model_name() replaces its contents. The bare tokenizer/model/config
# globals mirror the initial contents of the slot.
MODEL_NAME = 'jason9693/SoongsilBERT-base-beep'
tokenizer, model, config = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),
    AutoConfig.from_pretrained(MODEL_NAME),
)

MODEL_BUF = dict(
    name=MODEL_NAME,
    tokenizer=tokenizer,
    model=model,
    config=config,
)


# Register every font file matplotlib finds in the working directory, then
# select NanumGothicCoding (presumably one of those files) so the Korean
# tick labels in the attention heatmaps can render.
font_dir = ['./']
for font in fm.findSystemFonts(fontpaths=font_dir):
    print(font)  # log each font file as it is registered
    fm.fontManager.addfont(font)
plt.rc('font', family='NanumGothicCoding')


def visualize_attention(sent, attention_matrix, n_words=10):
    """Draw six attention heatmaps in a 2x3 grid and return the figure.

    Entries 1, 3, 5, 7, 9 and 11 of ``attention_matrix`` are plotted, titled
    "Layer <index>", on a fixed 0..1 color scale without color bars.

    Args:
        sent: Token strings used as tick labels. To limit clutter, x labels
            are drawn only on the bottom row and y labels only on the left
            column.
        attention_matrix: Indexable collection of 2-D matrices; it must hold
            at least 12 entries (index 11 is read). Each matrix is expected to
            be len(sent) x len(sent) so the tick labels line up.
        n_words: Unused; kept so existing callers keep working.

    Returns:
        The matplotlib Figure. It is removed from pyplot's figure manager
        before being returned, so the caller owns it.
    """
    n_rows, n_cols = 2, 3
    fig = plt.figure(figsize=(16, 8))
    for i, layer in enumerate(range(1, 12, 2)):
        row, col = divmod(i, n_cols)
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.set_title("Layer {}".format(layer))
        seaborn.heatmap(
            attention_matrix[layer],
            xticklabels=sent if row == n_rows - 1 else [],
            yticklabels=sent if col == 0 else [],
            square=True,
            vmin=0.0,
            vmax=1.0,
            cbar=False,
            ax=ax,
        )

    fig.tight_layout()
    # Close *this* figure explicitly. plt.close() with no argument closes
    # pyplot's global "current" figure, which under concurrent requests can be
    # a figure created by a different call.
    plt.close(fig)
    return fig


def change_model_name(name):
    """Load checkpoint ``name`` and make it the active model in MODEL_BUF.

    The tokenizer, model and config are all loaded before MODEL_BUF is
    modified. If any load raises, the exception propagates and MODEL_BUF
    still holds the previous, self-consistent model. Previously "name" was
    overwritten first, so a failed load left a buffer whose name no longer
    matched its model, and later predict() calls with that name would skip
    the reload.

    Args:
        name: Model identifier accepted by ``from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)
    config = AutoConfig.from_pretrained(name)
    MODEL_BUF.update(name=name, tokenizer=tokenizer, model=model, config=config)


def predict(model_name, text):
    """Classify ``text`` with ``model_name`` and plot its attention maps.

    If ``model_name`` differs from the checkpoint currently held in
    MODEL_BUF, that checkpoint is loaded first via change_model_name().

    Args:
        model_name: Model identifier (one of the app's dropdown choices).
        text: Input sentence to classify.

    Returns:
        A ``(scores, figure)`` pair:
        - ``scores`` maps each label in ``config.id2label`` to its softmax
          probability.
        - ``figure`` is the matplotlib figure built by visualize_attention()
          from per-layer attention.
    """
    import torch  # module level only imports torch.nn.functional; needed for no_grad()

    if model_name != MODEL_BUF["name"]:
        change_model_name(model_name)

    tokenizer = MODEL_BUF["tokenizer"]
    model = MODEL_BUF["model"]
    config = MODEL_BUF["config"]

    tokenized_text = tokenizer([text], return_tensors='pt')

    input_tokens = tokenizer.convert_ids_to_tokens(tokenized_text.input_ids[0])
    if config.model_type in ('roberta', 'gpt', 'gpt2'):
        # util.bytetokens_to_unicdode presumably decodes byte-level BPE token
        # strings into readable text. It raises KeyError for tokens it cannot
        # map (NOTE(review): presumably checkpoints of these model types that
        # use a non-byte-level vocabulary); the raw tokens are kept then.
        try:
            input_tokens = util.bytetokens_to_unicdode(input_tokens)
        except KeyError:
            pass

    model.eval()
    # Inference only: don't build an autograd graph for every request.
    with torch.no_grad():
        logits, attentions = model(**tokenized_text, output_attentions=True, return_dict=False)

    probs = F.softmax(logits, dim=-1)[0].tolist()
    result = {config.id2label[idx]: prob for idx, prob in enumerate(probs)}

    # `attentions` holds one tensor per layer, shaped (batch, heads, seq, seq)
    # per the Transformers docs. The old `attention[0][0]` selected *heads* of
    # the first layer only, yet the plots are titled "Layer N". Average the
    # heads of the single batch item so entry k really is layer k.
    layer_attention = [layer_att[0].mean(dim=0).numpy() for layer_att in attentions]
    fig = visualize_attention(input_tokens, layer_attention)
    return result, fig


if __name__ == '__main__':
    # Checkpoints offered in the dropdown. predict() loads a choice on demand
    # (via change_model_name) whenever it differs from the one currently in
    # MODEL_BUF.
    model_name_list = [
        'jason9693/SoongsilBERT-base-beep',
        "beomi/beep-klue-roberta-base-hate",
        "beomi/beep-koelectra-base-v3-discriminator-hate",
        "beomi/beep-KcELECTRA-base-hate"
    ]
    text = '읿딴걸 홍볿글 읿랉곭 쌑젩낄고 앉앟있냩'

    # UI: model picker + free-text box -> label scores + attention plot.
    model_dropdown = gr.inputs.Dropdown(model_name_list, label="Model Name")
    default_model = MODEL_BUF["name"]
    sample_rows = [
        [default_model, text],
        [default_model, "4=🦀 4≠🦀"],
    ]

    app = gr.Interface(
        fn=predict,
        inputs=[model_dropdown, 'text'],
        outputs=['label', 'plot'],
        examples=sample_rows,
        title="한국어 혐오성 발화 분류기 (Korean Hate Speech Classifier)",
        description="Korean Hate Speech Classifier with Several Pretrained LM\nCurrent Supported Model:\n1. SoongsilBERT\n2. KcBERT(+KLUE)\n3. KcELECTRA\n4.KoELECTRA.",
    )
    app.launch(inline=False)