import os

import pandas as pd
import streamlit as st
import torch
from torch.nn import BCEWithLogitsLoss
from transformers import (BertConfig, BertTokenizer,
                          BertForSequenceClassification, Trainer,
                          TrainingArguments)

# The results table survives Streamlit reruns by living in session state.
if 'df' not in st.session_state:
    st.session_state.df = pd.DataFrame(columns=['Tweet', 'Toxicity Class', 'Probability'])
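
# Multi-label head: a tweet can be, say, both "Toxic" and "Obscene", so the
# stock single-label CrossEntropyLoss is swapped for BCEWithLogitsLoss. Recent
# transformers versions can get the same effect by setting
# problem_type="multi_label_classification" on the config; the subclass below
# keeps the loss and the return format explicit instead.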
class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        # One independent binary decision per label, instead of the
        # single-label loss the stock class defaults to.
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, position_ids=position_ids,
                            head_mask=head_mask, inputs_embeds=inputs_embeds)
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            # BCEWithLogitsLoss needs float targets of shape (batch, num_labels).
            labels = labels.to(dtype=torch.float32)
            loss = self.loss_fct(logits.view(-1, self.num_labels),
                                 labels.view(-1, self.num_labels))
            # Trainer reads the loss from position 0 of a tuple output.
            return (loss,) + outputs
        # Inference path: return the raw logits only.
        return logits
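
# Data loading assumes a CSV in the Jigsaw Toxic Comment Classification
# layout: an id column, a comment_text column, then six 0/1 label columns
# (toxic, severe_toxic, obscene, threat, insult, identity_hate). Only the
# comment_text name is required by the code; the labels are read positionally
# via iloc[:, 2:], so the exact label header names are an assumption here.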
def load_data(file_path):
    data = pd.read_csv(file_path)
    return data


def tokenize_data(data, tokenizer):
    return tokenizer(data['comment_text'].tolist(), padding=True, truncation=True,
                     max_length=256, return_tensors='pt')
class ToxicDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # The encodings are already tensors (return_tensors='pt'), so index
        # them directly rather than re-wrapping with torch.tensor().
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)
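
# Each item the dataset yields is a dict Trainer can pass straight to
# forward(), roughly:
#   {'input_ids': LongTensor[seq_len], 'token_type_ids': LongTensor[seq_len],
#    'attention_mask': LongTensor[seq_len], 'labels': FloatTensor[6]}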
def train_model(model, tokenizer, dataset):
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=48,
        logging_dir='./logs',
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
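
# Deliberately minimal training setup: no eval split, metrics, or checkpoint
# selection, so Trainer just minimizes the BCE loss for three epochs. If
# train.csv were split, passing eval_dataset= and a compute_metrics callback
# (e.g. per-label ROC AUC) to Trainer would add generalization tracking; that
# is a possible extension, not something this app does.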
def append_to_dataframe(df, append_row):
    # A plain concat; caching this with st.cache would risk serving a stale
    # frame, and the result is persisted in st.session_state anyway.
    return pd.concat([df, append_row], ignore_index=True)
@st.cache(allow_output_mutation=True)
def load_and_train_model():
    model_save_path = './fine_tuned_model'
    tokenizer_save_path = './fine_tuned_tokenizer'
    if os.path.exists(model_save_path):
        print('Loading existing fine-tuned model...')
        model = CustomBertForSequenceClassification.from_pretrained(model_save_path)
        tokenizer = BertTokenizer.from_pretrained(tokenizer_save_path)
        return model, tokenizer

    print('Loading dataset...')
    file_path = r'train.csv'
    data = load_data(file_path)
    # Columns 0-1 are the id and the text; the remaining six are the labels.
    labels = data.iloc[:, 2:].values.tolist()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Shrink the architecture so training from scratch stays cheap; only the
    # vocabulary/tokenizer comes from bert-base-uncased, not the weights.
    config = BertConfig.from_pretrained('bert-base-uncased', num_labels=6)
    config.hidden_size = 128
    config.num_attention_heads = 2
    config.intermediate_size = 512
    config.num_hidden_layers = 2
    model = CustomBertForSequenceClassification(config)

    print('Tokenizing...')
    encodings = tokenize_data(data, tokenizer)
    dataset = ToxicDataset(encodings, labels)
    print('Starting training...')
    train_model(model, tokenizer, dataset)
    print('Saving...')
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(tokenizer_save_path)
    return model, tokenizer
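
# The cache keeps one model/tokenizer pair alive across reruns, so training
# only happens on the first request after a cold start. st.cache is deprecated
# in newer Streamlit releases; st.cache_resource is the modern equivalent for
# unserializable objects like models.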
st.title("Toxic Tweet Classifier")

model, tokenizer = load_and_train_model()

model_options = ['BERT Fine-Tuned']
selected_model = st.selectbox("Select the fine-tuned model:", model_options)
input_tweet = st.text_input("Enter the text below:")

# Human-readable names for the six toxicity labels (id2label defaults to
# "LABEL_0".."LABEL_5" when num_labels=6 is set without explicit names).
label_map = {
    "LABEL_0": "Toxic",
    "LABEL_1": "Severe Toxic",
    "LABEL_2": "Obscene",
    "LABEL_3": "Threat",
    "LABEL_4": "Insult",
    "LABEL_5": "Identity Hate",
}

if st.button("Classify"):
    with st.spinner("Classifying..."):
        inputs = tokenizer(input_tweet, return_tensors='pt', padding=True,
                           truncation=True, max_length=256)
        model.eval()
        with torch.no_grad():
            logits = model(**inputs)
        # The model was trained with BCEWithLogitsLoss (multi-label), so a
        # per-label sigmoid is the matching probability; a softmax across the
        # six labels would not be. The argmax label is the same either way.
        probabilities = torch.sigmoid(logits).tolist()[0]
        label_prob = max(zip(model.config.id2label.values(), probabilities),
                         key=lambda x: x[1])
        st.write(input_tweet)
        st.write(label_map[label_prob[0]])
        st.write(label_prob[1])
        st.session_state.df = append_to_dataframe(
            st.session_state.df,
            pd.DataFrame({'Tweet': [input_tweet],
                          'Toxicity Class': [label_map[label_prob[0]]],
                          'Probability': [label_prob[1]]}),
        )
        st.write(st.session_state.df)
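
# To try the app locally (assuming train.csv sits next to app.py):
#   streamlit run app.py
# The first run fine-tunes and saves the model; subsequent runs load it from
# ./fine_tuned_model and ./fine_tuned_tokenizer.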