|
import pandas as pd |
|
import os |
|
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments |
|
import torch |
|
import streamlit as st |
|
# Initialize the per-session results table exactly once; Streamlit reruns the
# whole script on each interaction, so guard against re-creating it.
if 'df' not in st.session_state:

    st.session_state.df = pd.DataFrame(columns=['Tweet', 'Toxicity Class', 'Probability'])
|
|
|
|
|
from torch.nn import BCEWithLogitsLoss |
|
|
|
class CustomBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier adapted for multi-label classification.

    Replaces the stock (single-label) loss with BCEWithLogitsLoss so each of
    the ``num_labels`` heads is an independent sigmoid classifier.
    """

    def __init__(self, config):
        super().__init__(config)
        # One independent binary objective per label column.
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        """Run the encoder and classifier head.

        Returns raw logits when ``labels`` is None (inference), otherwise a
        tuple ``(loss, logits, *extra_encoder_outputs)`` for the Trainer.
        """
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        # encoder_outputs[1] is the pooled [CLS] representation.
        logits = self.classifier(encoder_outputs[1])

        if labels is None:
            # Inference path: callers expect the bare logits tensor.
            return logits

        # BCEWithLogitsLoss requires float targets.
        float_labels = labels.to(dtype=torch.float32)
        loss = self.loss_fct(
            logits.view(-1, self.num_labels),
            float_labels.view(-1, self.num_labels),
        )
        return (loss, logits) + encoder_outputs[2:]
|
|
|
def load_data(file_path):
    """Read a CSV file into a pandas DataFrame.

    Args:
        file_path: Path to the CSV file on disk.

    Returns:
        The parsed DataFrame.
    """
    return pd.read_csv(file_path)
|
|
|
def tokenize_data(data, tokenizer):
    """Tokenize the ``comment_text`` column into fixed-length PyTorch tensors.

    Args:
        data: DataFrame containing a ``comment_text`` column.
        tokenizer: HuggingFace-style tokenizer callable.

    Returns:
        The tokenizer's batch output (padded/truncated to 256 tokens).
    """
    texts = data['comment_text'].tolist()
    return tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt',
    )
|
|
|
class ToxicDataset(torch.utils.data.Dataset):
    """Dataset pairing tokenizer encodings with multi-hot label rows.

    Each item is a dict of encoding tensors plus a float32 ``labels`` tensor,
    the format expected by a BCE-with-logits multi-label head.
    """

    def __init__(self, encodings, labels):
        # encodings: mapping of field name -> per-example sequences (tensors
        # when the tokenizer was called with return_tensors='pt').
        # labels: per-example multi-hot label rows.
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {}
        for key, val in self.encodings.items():
            value = val[idx]
            # BUGFIX: torch.tensor(tensor) emits the "copy construct from a
            # tensor" UserWarning; clone().detach() is the recommended
            # equivalent for tensor inputs. Non-tensor sequences still go
            # through torch.tensor as before.
            item[key] = value.clone().detach() if torch.is_tensor(value) else torch.tensor(value)
        # BCEWithLogitsLoss requires float targets.
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)
|
|
|
def train_model(model, tokenizer, dataset):
    """Fine-tune ``model`` on ``dataset`` with HuggingFace's Trainer.

    Args:
        model: Model to train (updated in place).
        tokenizer: Unused here; kept for interface compatibility with callers.
        dataset: torch Dataset yielding encoding dicts with a 'labels' key.
    """
    args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=48,
        logging_dir='./logs',
    )
    Trainer(model=model, args=args, train_dataset=dataset).train()
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def append_to_dataframe(df, append_row):
    """Return a new DataFrame with ``append_row`` concatenated onto ``df``.

    NOTE(review): caching an append operation is unusual — the cache key
    includes the growing ``df``, so hits only occur on identical inputs.
    Verify the decorator is intentional.
    """
    print('called append_to_dataframe')
    merged = pd.concat([df, append_row], ignore_index=True)
    return merged
|
|
|
@st.cache(allow_output_mutation=True)
def load_and_train_model():
    """Load the fine-tuned model and tokenizer from disk, or train them.

    If a saved checkpoint exists under ./fine_tuned_model it is loaded;
    otherwise a shrunken BERT is fine-tuned on train.csv and saved.

    Returns:
        (model, tokenizer) tuple.
    """
    model_save_path = './fine_tuned_model'
    tokenizer_save_path = './fine_tuned_tokenizer'

    # Fast path: reuse a previously fine-tuned checkpoint.
    if os.path.exists(model_save_path):
        print('loading existing model')
        return (
            CustomBertForSequenceClassification.from_pretrained(model_save_path),
            BertTokenizer.from_pretrained(tokenizer_save_path),
        )

    print("Loading dataset...")
    data = load_data(r'train.csv')
    # Label columns are assumed to start at column index 2 (after id and
    # comment_text) — TODO confirm against the train.csv schema.
    labels = data.iloc[:, 2:].values.tolist()

    print("Tokenize.")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Shrunken BERT configuration to keep fine-tuning cheap.
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=6)
    config.num_hidden_layers = 2
    config.num_attention_heads = 2
    config.hidden_size = 128
    config.intermediate_size = 512
    model = CustomBertForSequenceClassification(config)

    print("Fine-tuning BERT model...")
    print('tokenizing')
    encodings = tokenize_data(data, tokenizer)
    print('dataset=')
    dataset = ToxicDataset(encodings, labels)
    print('starting training...')
    train_model(model, tokenizer, dataset)

    print('saving...')
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(tokenizer_save_path)
    return model, tokenizer
|
|
|
# Page header.
st.title("Toxic Tweet Classifier")



# Load (or lazily fine-tune) the classifier once; st.cache memoizes the call
# across Streamlit reruns. `model` and `tokenizer` are used by the classify
# handler below.
model, tokenizer = load_and_train_model()



# Only one model is offered; `selected_model` is currently unused — the
# selectbox exists purely as a UI element.
model_options = ['BERT Fine-Tuned']

selected_model = st.selectbox("Select the fine-tuned model:", model_options)



input_tweet = st.text_input("Enter the text below:")
|
|
|
# Classify the entered text and append the result to the session table.
if st.button("Classify"):

    with st.spinner("Classifying..."):

        inputs = tokenizer(input_tweet, return_tensors='pt', padding=True, truncation=True, max_length=256)

        # Inference only: disable dropout and skip gradient tracking.
        model.eval()
        with torch.no_grad():
            # With no labels supplied, the custom forward returns raw logits.
            logits = model(**inputs)

        # BUGFIX: the model is trained with BCEWithLogitsLoss (independent
        # multi-label heads), so per-class probabilities come from a sigmoid
        # on each logit — not a softmax across classes, which would wrongly
        # force the labels to compete.
        probabilities = torch.sigmoid(logits).tolist()[0]

        # Report the single class with the highest independent probability.
        label_prob = max(zip(model.config.id2label.values(), probabilities), key=lambda x: x[1])

        # Map HuggingFace's generic LABEL_n ids to human-readable names.
        label_map = {
            "LABEL_0": "Toxic",
            "LABEL_1": "Severe Toxic",
            "LABEL_2": "Obscene",
            "LABEL_3": "Threat",
            "LABEL_4": "Insult",
            "LABEL_5": "Identity Hate"
        }

        print('Insert into table')
        st.write(input_tweet)
        st.write(label_map[label_prob[0]])
        st.write(label_prob[1])
        # Persist the new row in the session-scoped results table.
        st.session_state.df = append_to_dataframe(
            st.session_state.df,
            pd.DataFrame({'Tweet': [input_tweet],
                          'Toxicity Class': [label_map[label_prob[0]]],
                          'Probability': [label_prob[1]]}),
        )
        st.write(st.session_state.df)
|
|
|
|