# app.py -- Hugging Face file-viewer metadata: uploaded by ksvmuralidhar, commit 8ae18de ("Create app.py"), 2.69 kB.
import streamlit as st
import pandas as pd
import numpy as np
from unidecode import unidecode
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
from matplotlib import pyplot as plt
from PIL import Image
# Locations of the serialized artifacts, relative to the working directory.
_MODEL_DIR = "models"
_PREPROCESSOR_FILE = os.path.join(_MODEL_DIR, "toxic_comment_preprocessor_classnames.bin")
_TFLITE_FILE = os.path.join(_MODEL_DIR, "toxic_comment_classifier_hf_distilbert.tflite")

# The pickle holds a 2-item sequence: the text preprocessor (its ``preprocess``
# method is used by ``inference``) and the class names used to label predictions.
# NOTE(review): unpickling runs code embedded in the file -- only ship a
# preprocessor file that comes from a trusted source.
with open(_PREPROCESSOR_FILE, "rb") as model_file_obj:
    text_preprocessor, class_names = cloudpickle.load(model_file_obj)

# TFLite interpreter used by ``inference``; tensors are allocated there.
interpreter = tf.lite.Interpreter(model_path=_TFLITE_FILE)
def _select_input_details(input_details):
    """Return the ``(attention_mask, input_ids)`` input descriptors of the model.

    The descriptors are matched by tensor name first, because the position of
    an input in ``get_input_details()`` is not tied to its name. When the
    names are not conclusive, the historical positional binding is used
    (index 0 -> attention mask, index 1 -> input ids).
    """
    matched = {}
    for detail in input_details:
        tensor_name = str(detail.get("name", ""))
        for key in ("attention_mask", "input_ids"):
            if key in tensor_name:
                matched.setdefault(key, detail)
    if len(matched) == 2 and matched["attention_mask"] is not matched["input_ids"]:
        return matched["attention_mask"], matched["input_ids"]
    # Names did not identify both inputs unambiguously: keep the original order.
    return input_details[0], input_details[1]


def inference(text):
    """Classify a comment with the TFLite model and return per-class scores.

    Args:
        text: The raw comment to classify.

    Returns:
        A DataFrame with a 'class' column (taken from ``class_names``) and a
        'prob' column (the first row of the model output), sorted by 'prob'
        in ascending order.
    """
    # The preprocessor works on a Series; wrap the single comment and take
    # the element labelled 0 back out.
    text = text_preprocessor.preprocess(pd.Series(text))[0]

    model_checkpoint = "distilbert-base-uncased"
    # NOTE(review): the tokenizer is re-created on every call; consider
    # caching it if submit latency becomes a problem.
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    tokens = tokenizer(text, max_length=512, padding="max_length", truncation=True, return_tensors="tf")

    # tflite model inference
    interpreter.allocate_tensors()
    mask_detail, ids_detail = _select_input_details(interpreter.get_input_details())
    output_details = interpreter.get_output_details()[0]

    for detail, key in ((mask_detail, "attention_mask"), (ids_detail, "input_ids")):
        # set_tensor expects a NumPy array of the dtype the model declares,
        # so convert the TensorFlow tensor and cast it explicitly.
        value = np.asarray(tokens[key]).astype(detail["dtype"], copy=False)
        interpreter.set_tensor(detail["index"], value)
    interpreter.invoke()
    tflite_pred = interpreter.get_tensor(output_details["index"])[0]

    result_df = pd.DataFrame({'class': class_names, 'prob': tflite_pred})
    # Ascending order so the highest score is the last row of the frame.
    return result_df.sort_values(by='prob', ascending=True)
def display_image(df):
    """Draw the class scores as a horizontal bar chart and show it in the app.

    Args:
        df: DataFrame with a 'class' column (bar labels) and a 'prob' column
            (bar lengths). Rows are drawn in the order they appear.
    """
    import io  # local import so this function does not rely on new file-level imports

    fig, ax = plt.subplots(figsize=(2, 1.8))
    try:
        df.plot(x='class', y='prob', kind='barh', ax=ax, color='black', ylabel='')
        ax.tick_params(axis='both', which='major', labelsize=8.5)
        ax.get_legend().remove()
        # Strip the frame and the x ticks: only bars, labels and values remain.
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.get_xaxis().set_ticks([])
        ax.set_xlim(0, 1)
        # Print each score just to the right of its bar.
        for row, prob in enumerate(df['prob']):
            ax.text(prob + 0.015, row - 0.15, f'{str(np.round(prob, 3))} ', fontsize=7.5)
        # Apply the tight layout to this figure directly instead of flipping
        # the global "figure.autolayout" rcParam, which only affected figures
        # created afterwards.
        fig.tight_layout()
        # Render into memory: a shared "prediction.png" on disk could be
        # overwritten by another session between the write and the read.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches='tight', dpi=100)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    st.write('')
    st.image(image, output_format="PNG", caption="Prediction")
############## ENTRY POINT START #######################
def main():
    """Build the page: title, comment box, and the prediction chart on submit."""
    st.title("Toxic Comment Classifier")
    comment_txt = st.text_area("Enter a comment:", "", height=100)
    if not st.button("Submit"):
        return
    # Do not run the model on an empty or whitespace-only comment.
    if not comment_txt.strip():
        st.warning("Please enter a comment before submitting.")
        return
    df = inference(comment_txt)
    display_image(df)
############## ENTRY POINT END #######################
# Build the page only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()