import pandas as pd
import os
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.nn import BCEWithLogitsLoss
import streamlit as st

# Persist the results table across Streamlit reruns.
if 'df' not in st.session_state:
    st.session_state.df = pd.DataFrame(columns=['Tweet', 'Toxicity Class', 'Probability'])


class CustomBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier with a multi-label (BCE) loss instead of cross-entropy."""

    def __init__(self, config):
        super().__init__(config)
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                            position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds)
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            # Multi-label targets must be floats for BCEWithLogitsLoss.
            labels = labels.to(dtype=torch.float32)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            outputs = (loss,) + outputs
            return outputs
        else:
            return logits


def load_data(file_path):
    data = pd.read_csv(file_path)
    return data


def tokenize_data(data, tokenizer):
    return tokenizer(data['comment_text'].tolist(), padding=True, truncation=True,
                     max_length=256, return_tensors='pt')


class ToxicDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Encodings are already tensors (return_tensors='pt'), so slice them directly.
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)


def train_model(model, tokenizer, dataset):
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=48,
        logging_dir='./logs',
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()


@st.cache(allow_output_mutation=True)
def append_to_dataframe(df, append_row):
    df = pd.concat([df, append_row], ignore_index=True)
    return df


@st.cache(allow_output_mutation=True)
def load_and_train_model():
    model_save_path = './fine_tuned_model'
    tokenizer_save_path = './fine_tuned_tokenizer'

    # Reuse a previously fine-tuned model if one has been saved.
    if os.path.exists(model_save_path):
        print('Loading existing fine-tuned model...')
        model = CustomBertForSequenceClassification.from_pretrained(model_save_path)
        tokenizer = BertTokenizer.from_pretrained(tokenizer_save_path)
        return model, tokenizer

    print("Loading dataset...")
    file_path = r'train.csv'
    data = load_data(file_path)
    labels = data.iloc[:, 2:].values.tolist()

    print("Tokenizing...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=6)
    # Shrink the architecture so training from scratch stays tractable.
    config.hidden_size = 128
    config.num_attention_heads = 2
    config.intermediate_size = 512
    config.num_hidden_layers = 2
    model = CustomBertForSequenceClassification(config)

    print("Fine-tuning BERT model...")
    encodings = tokenize_data(data, tokenizer)
    dataset = ToxicDataset(encodings, labels)
    train_model(model, tokenizer, dataset)

    print('Saving model and tokenizer...')
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(tokenizer_save_path)
    return model, tokenizer


st.title("Toxic Tweet Classifier")

model, tokenizer = load_and_train_model()

model_options = ['BERT Fine-Tuned']
selected_model = st.selectbox("Select the fine-tuned model:", model_options)

input_tweet = st.text_input("Enter the text below:")
st.text_input("Enter the text below:") if st.button("Classify"): with st.spinner("Classifying..."): inputs = tokenizer(input_tweet, return_tensors='pt', padding=True, truncation=True,max_length=256) logits = model(**inputs) probabilities = torch.softmax(logits, dim=1).tolist()[0] label_prob = max(zip(model.config.id2label.values(), probabilities), key=lambda x: x[1]) label_map = { "LABEL_0": "Toxic", "LABEL_1": "Severe Toxic", "LABEL_2": "Obscene", "LABEL_3": "Threat", "LABEL_4": "Insult", "LABEL_5": "Identity Hate" } print('Insert into table') st.write(input_tweet) st.write(label_map[label_prob[0]]) st.write(label_prob[1]) st.session_state.df = append_to_dataframe(st.session_state.df,pd.DataFrame({'Tweet': [input_tweet], 'Toxicity Class': [label_map[label_prob[0]]], 'Probability': [label_prob[1]]})) st.write(st.session_state.df)