|
import warnings |
|
warnings.filterwarnings("ignore") |
|
import streamlit as st |
|
import pandas as pd |
|
import plotly.express as px |
|
from src import * |
|
import os |
|
|
|
global model |
|
global prediction |
|
|
|
@st.cache_resource |
|
def model_obj(): |
|
model = ModelLoader() |
|
prediction = PredictionServices(model.Model, model.Tokenizer) |
|
st.image(os.path.join("img","toxic.jpg")) |
|
return prediction |
|
|
|
prediction = model_obj() |
|
|
|
def single_predict(text): |
|
preds = prediction.single_predict(text) |
|
if preds < 0.5: |
|
st.success(f'Non Toxic Comment!!! :thumbsup:') |
|
else: |
|
st.error(f'Toxic Comment!!! :thumbsdown:') |
|
prediction.plot(preds) |
|
|
|
def batch_predict(data): |
|
preds = prediction.batch_predict(data) |
|
return preds.to_csv(index=False).encode('utf-8') |
|
|
|
st.title('Toxic Comment Classifier') |
|
st.write("This application will help to classify any comment or text in any language into 'TOXIC' or 'NON-TOXIC'") |
|
tab1, tab2 = st.tabs(["Single Value Prediciton","Batch Prediction"]) |
|
|
|
with tab1: |
|
st.subheader("Prediction") |
|
with st.form("comment_form", clear_on_submit=True): |
|
comment = st.text_area(label="Enter your comment") |
|
button = st.form_submit_button(label="Predict") |
|
if button: |
|
with st.spinner('Please Wait!!! Prediction in process....'): |
|
single_predict(comment) |
|
|
|
with tab2: |
|
st.subheader("Batch Prediction") |
|
csv_file = st.file_uploader("Upload File",type=['csv']) |
|
|
|
if csv_file is not None: |
|
csv = batch_predict(csv_file) |
|
st.download_button( |
|
label="Download", |
|
data=csv, |
|
file_name='prediction.csv', |
|
mime='text/csv', |
|
) |
|
|
|
|
|
|
|
|