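"""Gradio demo: multi-label toxic comment classification with a fine-tuned BERT model."""
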
import gradio as gr
import torch.nn as nn
import torch
from transformers import BertTokenizerFast as BertTokenizer, BertModel
import pytorch_lightning as pl


BERT_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
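# the six labels the classifier outputs, in the order used during training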
LABEL_COLUMNS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']


class ToxicCommentTagger(pl.LightningModule):

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        # inference path used by predict(): pooled [CLS] -> linear head -> sigmoid probabilities per label
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = torch.sigmoid(self.classifier(output.pooler_output))
        loss = self.criterion(output, labels) if labels is not None else 0
        return loss, output


def predict(model, tokenizer, sentence):
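    """Score `sentence` against every label in LABEL_COLUMNS.

    The text is tokenized without special tokens, split into 510-token chunks,
    wrapped with [CLS]/[SEP], padded to 512 tokens, and scored as a single batch;
    the returned score for each label is the maximum over all chunks.
    """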

    # encode the raw sentence without [CLS]/[SEP] (they are added per chunk below);
    # short inputs are padded to 510 tokens, long ones are kept whole for chunking
    encoding = tokenizer.encode_plus(
        sentence,
        add_special_tokens=False,
        max_length=510,
        return_token_type_ids=False,
        padding="max_length",
        return_attention_mask=True,
        return_tensors='pt'
    )

    # target chunk size, including the [CLS] and [SEP] tokens added below
    chunksize = 512

    # split into chunks of 510 tokens; convert to list because split() returns an immutable tuple
    input_id_chunks = list(encoding['input_ids'][0].split(chunksize - 2))
    mask_chunks = list(encoding['attention_mask'][0].split(chunksize - 2))

    # loop through each chunk
    for i in range(len(input_id_chunks)):
        # wrap the chunk with BERT's [CLS] (id 101) and [SEP] (id 102) token IDs
        input_id_chunks[i] = torch.cat([
            torch.tensor([101]), input_id_chunks[i], torch.tensor([102])
        ])
        # extend the attention mask to cover the two added special tokens
        mask_chunks[i] = torch.cat([
            torch.tensor([1]), mask_chunks[i], torch.tensor([1])
        ])
        # get required padding length
        pad_len = chunksize - input_id_chunks[i].shape[0]
        # check if tensor length satisfies required chunk size
        if pad_len > 0:
            # if padding length is more than 0, we must add padding
            input_id_chunks[i] = torch.cat([
                input_id_chunks[i], torch.Tensor([0] * pad_len)
            ])
            mask_chunks[i] = torch.cat([
                mask_chunks[i], torch.Tensor([0] * pad_len)
            ])

    # stack the chunks into a single batch (one row per chunk)
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)

    input_dict = {
        'input_ids': input_ids.long(),
        'attention_mask': attention_mask.int()
    }

    # score every chunk in one batch; forward() returns (loss, per-label probabilities)
    _, test_prediction = model(**input_dict)
    test_prediction = test_prediction.numpy()

    # aggregate across chunks: keep the highest score seen for each label
    output = {}
    for chunk in test_prediction:
        for label, prediction in zip(LABEL_COLUMNS, chunk):
            if label in output:
                output[label] = max(prediction, output[label])
            else:
                output[label] = prediction

    return output


# Load the fine-tuned weights and switch to inference mode
# (eval + frozen parameters, so no gradients are tracked).
model = ToxicCommentTagger.load_from_checkpoint(
    './best-checkpoint.ckpt',
    n_classes=len(LABEL_COLUMNS)
)

model.eval()
model.freeze()


# static list of category names displayed in the JSON output panel
all_categories = {
    'all_categories': [
        'toxicity',
        'obscene',
        'threat',
        'insult',
        'identity_attack',
        'sexual_explicit'
    ]
}


# example inputs preloaded in the Gradio UI: [sentence, threshold percentage]
examples = [
    ['Yes indeed. She sort of reminds me of the elder lady that played the part in the movie "Titanic" who was telling her story!!! And I wouldn\'t have wanted to cover who I really am!! I would be proud!!!! WE should be proud of our race no matter what it is!!',
     50],
    ['The trans women reading this tweet right now is beautiful', 50],
    ["Question: These 4 broads who criticize America, what country did they flee to get here? And now they want to make OUR America like THEIR former HELL HOLE. I don't think so!!!!!!!!!!  Let them explain their GRATITUDE for letting them in OUR country.",
     50],
    ['Well being subtle you could not have been born because white people don\'t matter',
     50],
    ["For starters bend over the one in pink and kick that ass and pussy to get a taste until she's begging for a dick inside her.",
     50]
]


def toxicity(sentence, threshold):
    # Gradio callback: return the labels scoring above the threshold (given as a percentage),
    # plus the static category listing for the JSON output component
    predicts = predict(model, tokenizer, sentence)
    return [label for label in predicts if predicts[label] > threshold / 100], all_categories

gr.Interface(
    fn=toxicity,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here..."),
        gr.Slider(0, 100)
    ],
    outputs=[
        'text',
        gr.JSON(all_categories)
    ],
    examples=examples
).launch()
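# launch() serves the app locally by default; pass share=True for a temporary public link.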