import pandas as pd
import os
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
import streamlit as st
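# Keep the results table in session state so it persists across Streamlit reruns.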
if 'df' not in st.session_state:
    st.session_state.df = pd.DataFrame(columns=['Tweet', 'Toxicity Class', 'Probability'])


from torch.nn import BCEWithLogitsLoss
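# A comment can carry several toxicity labels at once, so the stock
# single-label cross-entropy head is swapped for BCEWithLogitsLoss,
# which scores each of the six labels independently.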

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds)
        # Classify from the pooled [CLS] representation.
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]

        if labels is not None:
            # BCEWithLogitsLoss expects float targets, one column per label.
            labels = labels.to(dtype=torch.float32)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            # The Trainer reads the loss from the first element of the tuple.
            return (loss,) + outputs
        # At inference time, return raw logits for the caller to squash.
        return logits

def load_data(file_path):
    # Expects a CSV with a comment_text column followed by six 0/1 toxicity columns.
    data = pd.read_csv(file_path)
    return data

def tokenize_data(data, tokenizer):
    # Tokenize the whole corpus in one batch, padded/truncated to 256 tokens.
    return tokenizer(data['comment_text'].tolist(), padding=True, truncation=True, max_length=256, return_tensors='pt')

class ToxicDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # The encodings are already tensors; clone/detach a slice rather than
        # re-wrapping with torch.tensor(), which raises a copy warning.
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)

def train_model(model, tokenizer, dataset):
    # The Hugging Face Trainer handles batching, optimization, and checkpointing;
    # anything not set here falls back to the Trainer defaults.
    # (tokenizer is currently unused but kept to match the call site.)
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=48,
        logging_dir='./logs',
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()


def append_to_dataframe(df, append_row):
    # Simple append; pd.concat returns a new frame, so no caching is needed here.
    return pd.concat([df, append_row], ignore_index=True)

# st.cache is the legacy caching API (newer Streamlit uses st.cache_resource);
# it keeps this expensive load/train step from re-running on every interaction.
@st.cache(allow_output_mutation=True)
def load_and_train_model():
    model_save_path = './fine_tuned_model'
    tokenizer_save_path = './fine_tuned_tokenizer'
    # Reuse a previously trained model if one has been saved to disk.
    if os.path.exists(model_save_path):
        print('Loading existing model...')
        model = CustomBertForSequenceClassification.from_pretrained(model_save_path)
        tokenizer = BertTokenizer.from_pretrained(tokenizer_save_path)
        return model, tokenizer

    print("Loading dataset...")
    file_path = 'train.csv'
    data = load_data(file_path)
    # Columns 3 onward hold the six 0/1 toxicity labels.
    labels = data.iloc[:, 2:].values.tolist()

    print("Building tokenizer and model...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Shrink the architecture so training is feasible on modest hardware.
    # Note that changing these sizes means the weights are randomly initialized
    # and trained from scratch, not fine-tuned from the pretrained checkpoint.
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=6)
    config.hidden_size = 128
    config.num_attention_heads = 2
    config.intermediate_size = 512
    config.num_hidden_layers = 2

    model = CustomBertForSequenceClassification(config)

    print("Tokenizing...")
    encodings = tokenize_data(data, tokenizer)
    dataset = ToxicDataset(encodings, labels)
    print("Training...")
    train_model(model, tokenizer, dataset)
    print("Saving...")
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(tokenizer_save_path)
    return model, tokenizer
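
# Streamlit UI: load (or train) the model once, then classify on demand.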

st.title("Toxic Tweet Classifier")

model, tokenizer = load_and_train_model()

model_options = ['BERT Fine-Tuned']
selected_model = st.selectbox("Select the fine-tuned model:", model_options)

input_tweet = st.text_input("Enter a tweet to classify:")

if st.button("Classify"):
    with st.spinner("Classifying..."):
        inputs = tokenizer(input_tweet, return_tensors='pt', padding=True, truncation=True, max_length=256)
        # Inference only; no gradients needed.
        with torch.no_grad():
            logits = model(**inputs)
        # The model was trained with BCEWithLogitsLoss (multi-label), so each
        # label gets an independent sigmoid probability; softmax would wrongly
        # force the six non-exclusive labels to compete for probability mass.
        probabilities = torch.sigmoid(logits).tolist()[0]

        # Report the single most probable label.
        label_prob = max(zip(model.config.id2label.values(), probabilities), key=lambda x: x[1])

        label_map = {
            "LABEL_0": "Toxic",
            "LABEL_1": "Severe Toxic",
            "LABEL_2": "Obscene",
            "LABEL_3": "Threat",
            "LABEL_4": "Insult",
            "LABEL_5": "Identity Hate"
        }
        st.write(input_tweet)
        st.write(label_map[label_prob[0]])
        st.write(label_prob[1])
        st.session_state.df = append_to_dataframe(st.session_state.df, pd.DataFrame({'Tweet': [input_tweet], 'Toxicity Class': [label_map[label_prob[0]]], 'Probability': [label_prob[1]]}))
        st.write(st.session_state.df)
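
# To run locally (assuming this script is saved as app.py):
#   streamlit run app.py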