fnavales committed
Commit: 464d4fc
1 Parent(s): 8235435

Update app.py

Files changed (1)
app.py +0 -107
app.py CHANGED
@@ -1,112 +1,5 @@
  import gradio as gr
- <<<<<<< HEAD
  from detoxify import Detoxify
- =======
- import torch.nn as nn
- import torch
- from transformers import BertTokenizerFast as BertTokenizer, BertModel
- import pytorch_lightning as pl
-
-
- BERT_MODEL_NAME = 'bert-base-uncased'
- tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
- LABEL_COLUMNS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
- MAX_TOKEN_COUNT = 300
-
-
- class ToxicCommentTagger(pl.LightningModule):
-
-     def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
-         super().__init__()
-         self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
-         self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
-         self.n_training_steps = n_training_steps
-         self.n_warmup_steps = n_warmup_steps
-         self.criterion = nn.BCELoss()
-
-
-     def forward(self, input_ids, attention_mask, labels=None):
-         output = self.bert(input_ids, attention_mask=attention_mask)
-         output = self.classifier(output.pooler_output)
-         output = torch.sigmoid(output)
-         loss = 0
-         if labels is not None:
-             loss = self.criterion(output, labels)
-         return loss, output
-
-
- def predict(model, tokenizer, sentence):
-
-     encoding = tokenizer.encode_plus(
-         sentence,
-         add_special_tokens=False,
-         max_length=MAX_TOKEN_COUNT,
-         return_token_type_ids=False,
-         padding="max_length",
-         return_attention_mask=True,
-         return_tensors='pt'
-     )
-
-     # define target chunksize
-     chunksize = MAX_TOKEN_COUNT
-
-     # split into chunks of chunksize - 2 tokens (leaving room for [CLS] and [SEP]); convert to list (split returns an immutable tuple)
-     input_id_chunks = list(encoding['input_ids'][0].split(chunksize - 2))
-     mask_chunks = list(encoding['attention_mask'][0].split(chunksize - 2))
-
-     # loop through each chunk
-     for i in range(len(input_id_chunks)):
-         # add CLS and SEP tokens to input IDs
-         input_id_chunks[i] = torch.cat([
-             torch.tensor([101]), input_id_chunks[i], torch.tensor([102])
-         ])
-         # add attention tokens to attention mask
-         mask_chunks[i] = torch.cat([
-             torch.tensor([1]), mask_chunks[i], torch.tensor([1])
-         ])
-         # get required padding length
-         pad_len = chunksize - input_id_chunks[i].shape[0]
-         # check if tensor length satisfies required chunk size
-         if pad_len > 0:
-             # if padding length is more than 0, we must add padding
-             input_id_chunks[i] = torch.cat([
-                 input_id_chunks[i], torch.Tensor([0] * pad_len)
-             ])
-             mask_chunks[i] = torch.cat([
-                 mask_chunks[i], torch.Tensor([0] * pad_len)
-             ])
-
-     input_ids = torch.stack(input_id_chunks)
-     attention_mask = torch.stack(mask_chunks)
-
-     input_dict = {
-         'input_ids': input_ids.long(),
-         'attention_mask': attention_mask.int()
-     }
-
-     _, test_prediction = model(**input_dict)
-     test_prediction = test_prediction.numpy()
-
-     output = {}
-     for chunk in test_prediction:
-         for label, prediction in zip(LABEL_COLUMNS, chunk):
-             if label in output:
-                 output[label] = max(prediction, output[label])
-             else:
-                 output[label] = prediction
-
-     return output
-
-
- model = ToxicCommentTagger.load_from_checkpoint(
-     './best-checkpoint.ckpt',
-     n_classes=len(LABEL_COLUMNS)
- )
-
- model.eval()
- model.freeze()
-
- >>>>>>> 2a04af3d9d5ddbaa3eb1631c0e56d215462a7e36

  all_categories = {'all_categories': [
      'toxicity',
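
The commit resolves the merge conflict by keeping the Detoxify branch. A minimal sketch of how the surviving pieces plausibly fit together follows; the 'original' model variant, the score_comment name, and the Interface wiring are assumptions for illustration, since only the two imports and the all_categories labels are visible in this diff:

    import gradio as gr
    from detoxify import Detoxify

    # assumption: the Space loads Detoxify's 'original' checkpoint;
    # 'unbiased' and 'multilingual' variants also exist
    model = Detoxify('original')

    def score_comment(sentence):
        # Detoxify.predict returns a dict mapping each toxicity label to a score
        return {label: float(score) for label, score in model.predict(sentence).items()}

    # hypothetical wiring; a 'label' output renders the score dict as ranked labels
    gr.Interface(fn=score_comment, inputs='text', outputs='label').launch()

One trade-off worth noting: the deleted branch chunked long inputs into MAX_TOKEN_COUNT windows and kept the per-label maximum across chunks, whereas Detoxify applies the tokenizer's standard truncation, so text beyond the model's window is dropped rather than scored.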
 