|
import io
import os

import cloudpickle
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from transformers import DistilBertTokenizerFast
from unidecode import unidecode
|
|
|
|
|
# Directory holding every serialized artifact the app needs.
_MODEL_DIR = "models"

# SECURITY: cloudpickle.load can execute arbitrary code embedded in the file,
# so only trusted, locally shipped artifacts may live at this path.
# NOTE(review): this runs at module level, so it presumably re-executes on
# every Streamlit rerun of the script — consider a resource cache.
_preprocessor_path = os.path.join(_MODEL_DIR, "toxic_comment_preprocessor_classnames.bin")
with open(_preprocessor_path, "rb") as fh:
    # The pickle stores a (text preprocessor, class names) pair.
    text_preprocessor, class_names = cloudpickle.load(fh)

# TFLite interpreter for the fine-tuned DistilBERT classifier.
interpreter = tf.lite.Interpreter(
    model_path=os.path.join(_MODEL_DIR, "toxic_comment_classifier_hf_distilbert.tflite")
)
|
|
|
def sigmoid(x):
    """Apply the logistic function element-wise.

    Evaluated as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp`` so that
    large negative inputs do not overflow ``np.exp``. The naive
    ``1 / (1 + np.exp(-x))`` overflows for, e.g., ``x = -1000``.

    Args:
        x: Scalar or numpy array of logits.

    Returns:
        Values in [0, 1] with the same shape (and, for float arrays, the same
        dtype) as ``x``.
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow; it is
    # always >= 0, so the outer exp never overflows either.
    return np.exp(-np.logaddexp(0, -x))
|
|
|
def _resolve_input_details(input_details):
    """Return the TFLite input details as an ``(attention_mask, input_ids)`` pair.

    TFLite does not promise a fixed input order, so the inputs are matched by
    tensor name when both names are recognizable. Otherwise this falls back
    to positional order: index 0 as attention_mask, index 1 as input_ids.
    """
    by_name = {}
    for detail in input_details:
        for key in ("attention_mask", "input_ids"):
            if key in detail["name"]:
                by_name.setdefault(key, detail)
    if (
        len(by_name) == 2
        and by_name["attention_mask"]["index"] != by_name["input_ids"]["index"]
    ):
        return by_name["attention_mask"], by_name["input_ids"]
    return input_details[0], input_details[1]


def inference(text):
    """Classify a comment with the TFLite DistilBERT model.

    Steps:
        1. Clean the raw text with ``text_preprocessor``.
        2. Tokenize it with the ``distilbert-base-uncased`` tokenizer,
           padded/truncated to 512 tokens.
        3. Run the module-level TFLite ``interpreter``.
        4. Map the logits to probabilities with ``sigmoid``.

    Args:
        text: The comment to classify.

    Returns:
        pandas.DataFrame with columns ``'class'`` (from ``class_names``) and
        ``'prob'``, sorted by ``'prob'`` in ascending order.
    """
    # NOTE(review): assumes preprocess() returns something indexable by label 0
    # for a one-element Series — confirm against the pickled preprocessor.
    cleaned = text_preprocessor.preprocess(pd.Series(text))[0]

    # NOTE(review): the tokenizer is rebuilt on every call; consider caching it.
    model_checkpoint = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    tokens = tokenizer(
        cleaned,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    interpreter.allocate_tensors()
    mask_detail, ids_detail = _resolve_input_details(interpreter.get_input_details())
    output_detail = interpreter.get_output_details()[0]

    feeds = (
        (mask_detail, tokens["attention_mask"]),
        (ids_detail, tokens["input_ids"]),
    )
    for detail, values in feeds:
        # Cast to the dtype the model declares; set_tensor rejects mismatches.
        interpreter.set_tensor(
            detail["index"], np.asarray(values).astype(detail["dtype"], copy=False)
        )
    interpreter.invoke()

    # The first (only) row of the batch holds one logit per class.
    logits = interpreter.get_tensor(output_detail["index"])[0]
    probs = sigmoid(logits)

    result_df = pd.DataFrame({"class": class_names, "prob": probs})
    return result_df.sort_values(by="prob", ascending=True)
|
|
|
|
|
def display_image(df):
    """Show a horizontal bar chart of per-class probabilities in Streamlit.

    Args:
        df: DataFrame with ``'class'`` and ``'prob'`` columns, e.g. the output
            of ``inference``. Each bar is drawn at the row's position in
            ``df``, and its probability (rounded to 3 decimals) is written
            beside it.

    Rendering behavior:
        - The chart goes into an in-memory PNG. It is no longer written to a
          fixed ``prediction.png`` in the working directory, which every
          caller shared.
        - The figure is always closed so pyplot does not accumulate open
          figures.
        - Layout is applied to this figure only; global ``rcParams`` are left
          untouched.
    """
    fig, ax = plt.subplots(figsize=(2, 1.8))
    try:
        df.plot(
            x='class', y='prob', kind='barh', ax=ax,
            color='black', ylabel='', legend=False,
        )
        ax.tick_params(axis='both', which='major', labelsize=8.5)
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_xticks([])
        ax.set_xlim(0, 1)

        # Annotate each bar with its rounded probability just past the bar end.
        for position, prob in enumerate(df['prob']):
            ax.text(prob + 0.015, position - 0.15, f'{str(np.round(prob, 3))} ', fontsize=7.5)

        fig.tight_layout()
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight', dpi=100)
    finally:
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    st.write('')
    st.image(image, output_format="PNG", caption="Prediction")
|
|
|
|
|
def main():
    """Build the Streamlit page.

    The page has a title, a comment box and a Submit button. After a
    non-empty submission, it shows the per-class probability chart.
    """
    st.title("Toxic Comment Classifier")
    comment_txt = st.text_area("Enter a comment:", "", height=100)

    if not st.button("Submit"):
        return

    # There is nothing to classify in a blank comment, so skip the model.
    if not (comment_txt or "").strip():
        st.warning("Please enter a comment before submitting.")
        return

    df = inference(comment_txt)
    display_image(df)
|
|
|
|
|
|
|
# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    main()