import streamlit as st
import pandas as pd
import torch
from torch import cuda
from transformers import AutoTokenizer

# CustomDistilBertClass must be importable so that torch.load can unpickle the fine-tuned model.
from finetuning import CustomDistilBertClass

# Map UI model names to Hugging Face checkpoint names (used to pick the matching tokenizer).
model_map = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased'
}
model_options = list(model_map.keys())

# Jigsaw toxic-comment training data; only the label column names are used by the app.
train_df = pd.read_csv('train.csv')
train_df = train_df.sample(n=256)
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

@st.cache_resource
def load_model(model_name):
    """Load the fine-tuned classifier and the matching pretrained tokenizer."""
    # Note: the same fine-tuned checkpoint is loaded for every model choice;
    # only the tokenizer follows the sidebar selection.
    path = "finetuned_model.pt"
    device = 'cuda' if cuda.is_available() else 'cpu'
    model = torch.load(path, map_location=torch.device(device))
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer
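
# Assumption: finetuned_model.pt was produced with torch.save(model), i.e. the whole
# CustomDistilBertClass instance was pickled, which is why the class import above is needed.
# If only a state_dict had been saved, the model would instead be rebuilt along these lines
# (a sketch; the constructor arguments are assumptions):
#
#     model = CustomDistilBertClass()
#     model.load_state_dict(torch.load("finetuned_model.pt", map_location=torch.device(device)))
#     model.eval()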

def classify_text(model, tokenizer, text):
    """Classify text with the fine-tuned model; return the top class and its probability."""
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])[0]
    probabilities = torch.softmax(logits, dim=1)[0]
    pred_class = torch.argmax(probabilities, dim=0).item()
    # Return the predicted label together with the probability of that predicted class.
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)
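
# Note: the six Jigsaw labels are multi-label, so a comment can belong to several classes.
# If the fine-tuned head was trained with a per-label (BCE-style) loss, independent sigmoid
# scores would be the natural read-out instead of a softmax over the labels (an assumption
# about the training setup, shown only as an alternative):
#
#     probabilities = torch.sigmoid(logits)[0]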

st.title('Toxic Comment Classifier')
model_name = st.sidebar.selectbox('Select a model', model_options)
st.sidebar.write('Selected:', model_name)
model, tokenizer = load_model(model_name)

st.subheader('Enter comment below:')
text_input = st.text_area(label='', height=100, max_chars=500)

if st.button('Classify Toxicity'):
    if not text_input:
        st.write('Please enter a comment')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Results')
        st.write('Tweet:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)

st.subheader('Toxic Classification Results')
if 'classification_results' not in st.session_state:
    st.session_state.classification_results = pd.DataFrame(columns=['tweet', 'toxicity_class', 'probability'])
if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter a comment')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Results')
        st.write('Tweet:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
        # DataFrame.append has been removed in recent pandas; build the row and concat instead.
        new_row = pd.DataFrame([{
            'tweet': text_input,
            'toxicity_class': class_label,
            'probability': class_prob
        }])
        st.session_state.classification_results = pd.concat(
            [st.session_state.classification_results, new_row], ignore_index=True
        )
st.write(st.session_state.classification_results)
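
# To try the app locally (assuming this file is saved as app.py next to train.csv and
# finetuned_model.pt): streamlit run app.py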