import os

import pandas as pd
import streamlit as st
import torch
from torch.nn import BCEWithLogitsLoss
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Results table, persisted across Streamlit reruns.
if 'df' not in st.session_state:
    st.session_state.df = pd.DataFrame(columns=['Tweet', 'Toxicity Class', 'Probability'])
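# The toxicity task here is multi-label: one comment can carry several labels
# at once (e.g. both "Toxic" and "Obscene"), so the stock single-label
# cross-entropy head is the wrong fit. The subclass below swaps in
# BCEWithLogitsLoss, which scores each of the six labels independently.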
class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, labels=None):
        # return_dict=False keeps the base model's output as a plain tuple so
        # the positional indexing below works across transformers versions.
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, position_ids=position_ids,
                            head_mask=head_mask, inputs_embeds=inputs_embeds,
                            return_dict=False)
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        if labels is not None:
            # BCEWithLogitsLoss needs float targets of shape (batch, num_labels).
            labels = labels.to(dtype=torch.float32)
            loss = self.loss_fct(logits.view(-1, self.num_labels),
                                 labels.view(-1, self.num_labels))
            return (loss, logits) + outputs[2:]
        return logits
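# Data helpers. These assume a Jigsaw-style train.csv layout: an id column,
# a 'comment_text' column, then one binary column per toxicity label.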
def load_data(file_path):
    return pd.read_csv(file_path)


def tokenize_data(data, tokenizer):
    return tokenizer(data['comment_text'].tolist(), padding=True, truncation=True,
                     max_length=256, return_tensors='pt')
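# Map-style Dataset wrapper so the Hugging Face Trainer can batch the
# tokenized encodings together with their float label vectors.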
class ToxicDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # The encodings already hold tensors (return_tensors='pt'), so index
        # them directly instead of re-wrapping with torch.tensor().
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)
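# Training is delegated to the Hugging Face Trainer; the loss comes from the
# custom forward() above, so no compute_loss override is needed.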
def train_model(model, dataset):
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=48,
        logging_dir='./logs',
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
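# Appends one classified tweet to the running results table.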
def append_to_dataframe(df, append_row):
    # A plain append; caching this with st.cache would just hash the growing
    # DataFrame on every call without saving any work.
    return pd.concat([df, append_row], ignore_index=True)
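# Cached so training runs at most once per server process. st.cache is the
# pre-1.18 Streamlit API; newer versions would use st.cache_resource. Note
# that despite the on-screen name, the model is trained from scratch with a
# shrunken BERT config, not fine-tuned from pretrained weights.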
@st.cache(allow_output_mutation=True)
def load_and_train_model():
    model_save_path = './fine_tuned_model'
    tokenizer_save_path = './fine_tuned_tokenizer'
    if os.path.exists(model_save_path):
        print('Loading existing model...')
        model = CustomBertForSequenceClassification.from_pretrained(model_save_path)
        tokenizer = BertTokenizer.from_pretrained(tokenizer_save_path)
        return model, tokenizer

    print('Loading dataset...')
    file_path = r'train.csv'
    data = load_data(file_path)
    # Every column after id and comment_text is a binary toxicity label.
    labels = data.iloc[:, 2:].values.tolist()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Shrink the architecture (2 layers, 2 heads, 128-dim hidden) so training
    # finishes quickly; only the tokenizer vocabulary comes from
    # bert-base-uncased, the weights are randomly initialized.
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=6)
    config.hidden_size = 128
    config.num_attention_heads = 2
    config.intermediate_size = 512
    config.num_hidden_layers = 2
    model = CustomBertForSequenceClassification(config)

    print('Tokenizing...')
    encodings = tokenize_data(data, tokenizer)
    dataset = ToxicDataset(encodings, labels)

    print('Training...')
    train_model(model, dataset)

    print('Saving...')
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(tokenizer_save_path)
    return model, tokenizer
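# Streamlit page: classify one tweet at a time and accumulate the results in
# the session-state table rendered at the bottom.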
st.title("Toxic Tweet Classifier")
model, tokenizer = load_and_train_model()
model_options = ['BERT Fine-Tuned']
selected_model = st.selectbox("Select the fine-tuned model:", model_options)
input_tweet = st.text_input("Enter the text below:")
if st.button("Classify"):
with st.spinner("Classifying..."):
inputs = tokenizer(input_tweet, return_tensors='pt', padding=True, truncation=True,max_length=256)
logits = model(**inputs)
probabilities = torch.softmax(logits, dim=1).tolist()[0]
label_prob = max(zip(model.config.id2label.values(), probabilities), key=lambda x: x[1])
label_map = {
"LABEL_0": "Toxic",
"LABEL_1": "Severe Toxic",
"LABEL_2": "Obscene",
"LABEL_3": "Threat",
"LABEL_4": "Insult",
"LABEL_5": "Identity Hate"
}
print('Insert into table')
st.write(input_tweet)
st.write(label_map[label_prob[0]])
st.write(label_prob[1])
st.session_state.df = append_to_dataframe(st.session_state.df,pd.DataFrame({'Tweet': [input_tweet], 'Toxicity Class': [label_map[label_prob[0]]], 'Probability': [label_prob[1]]}))
st.write(st.session_state.df)
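# Launch with `streamlit run <this_file>.py`. Streamlit re-executes the whole
# script on every interaction, which is why the model loader is cached and the
# results table lives in st.session_state.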