import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer

# CustomDistilBertClass must be importable so torch.load can unpickle the
# fine-tuned model produced by finetuning.py.
from finetuning import CustomDistilBertClass

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pretrained checkpoints selectable in the sidebar; only the tokenizer depends on
# this choice, since load_model always loads the same fine-tuned checkpoint.
model_map = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased'
}

# Dropdown options
model_options = list(model_map.keys())

# Sample of the training data (loaded here but not used elsewhere in the app)
train_df = pd.read_csv('train.csv')
train_df = train_df.sample(n=256)

# Label order must match the output head of the fine-tuned model
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
@st.cache_resource
def load_model(model_name):
    """Load the fine-tuned model and the tokenizer for the selected checkpoint."""
    path = "finetuned_model.pt"
    # map_location keeps loading working on CPU-only hosts (assumes a CPU deployment)
    model = torch.load(path, map_location=torch.device('cpu'))
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
    return model, tokenizer
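# Note: finetuned_model.pt is assumed to hold a fully pickled model object, which is
# why CustomDistilBertClass has to be importable at load time. A sketch of the more
# portable state_dict pattern (the file name and no-arg constructor are assumptions):
#
#   model = CustomDistilBertClass()
#   model.load_state_dict(torch.load("finetuned_model_state.pt", map_location='cpu'))
#   model.eval()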
def classify_text(model, tokenizer, text):
    """Classify a comment and return the highest-probability toxicity label."""
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])[0]
    probabilities = torch.softmax(logits, dim=1)[0]
    pred_class = torch.argmax(probabilities, dim=0).item()
    # Report the probability of the predicted class rather than of the first label
    return label_cols[pred_class], round(probabilities[pred_class].item(), 2)
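# Hedged alternative: the Jigsaw labels are multi-label (a comment can be both
# 'toxic' and 'insult'), so reading the logits with an independent sigmoid per label
# is another option. This sketch is not wired into the UI, and the 0.5 threshold is
# an assumption.
def classify_text_multilabel(model, tokenizer, text, threshold=0.5):
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])[0]
    # Per-label probabilities; each label is scored independently
    probs = torch.sigmoid(logits)[0]
    return {lbl: round(p.item(), 2) for lbl, p in zip(label_cols, probs) if p >= threshold}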
# Set up the Streamlit app
st.title('Toxic Comment Classifier')

model_name = st.sidebar.selectbox('Select a model', model_options)
st.sidebar.write('Selected:', model_name)
model, tokenizer = load_model(model_name)

# Input text area
st.subheader('Enter comment below:')
text_input = st.text_area(label='Comment', height=100, max_chars=500,
                          label_visibility='collapsed')
# Make a prediction when the user clicks the 'Classify Toxicity' button
if st.button('Classify Toxicity'):
    if not text_input:
        st.write('Please enter a comment')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Results')
        st.write('Tweet:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
# Display a running table of classification results
st.subheader('Toxic Classification Results')

if 'classification_results' not in st.session_state:
    st.session_state.classification_results = pd.DataFrame(
        columns=['tweet', 'toxicity_class', 'probability'])

if st.button('Add to Results'):
    if not text_input:
        st.write('Please enter a comment')
    else:
        class_label, class_prob = classify_text(model, tokenizer, text_input)
        st.subheader('Results')
        st.write('Tweet:', text_input)
        st.write('Highest Toxicity Class:', class_label)
        st.write('Probability:', class_prob)
        # DataFrame.append was removed in pandas 2.0, so build the row and concat
        new_row = pd.DataFrame([{
            'tweet': text_input,
            'toxicity_class': class_label,
            'probability': class_prob
        }])
        st.session_state.classification_results = pd.concat(
            [st.session_state.classification_results, new_row], ignore_index=True)

st.write(st.session_state.classification_results)