Spaces:
Runtime error
Runtime error
import io
import os
from tempfile import NamedTemporaryFile

import easyocr
import joblib
import numpy as np
import streamlit as st
from keras.utils import img_to_array
from keras.utils import load_img
from PIL import Image
from streamlit.errors import StreamlitAPIException
from transformers import pipeline
# 'deprecation.showfileUploaderEncoding' only exists in older Streamlit
# releases; newer releases reject unknown config keys with
# StreamlitAPIException, which would stop the app at startup. Apply the
# option best-effort so the app runs on both old and new versions.
try:
    st.set_option('deprecation.showfileUploaderEncoding', False)
except StreamlitAPIException:
    pass
def load_image(image_file):
    """Open ``image_file`` with ``PIL.Image.open`` and hand back the image object."""
    return Image.open(image_file)
def get_img_prediction(imgpath):
    """Return the image model's class probabilities for the file at ``imgpath``.

    The image is loaded (RGB, keras' default ``color_mode``), resized to
    128x128, divided by 255 and flattened into one row of 128*128*3 features,
    which is passed to the module-level ``loaded_lgbm.predict_proba``.

    Returns:
        list of float: the probabilities for this single image, one per class,
        in the model's class order. NOTE(review): the app treats index 0 as
        non-hateful and index 1 as hateful — confirm against the model.
    """
    # target_size is (height, width); the channel count is set by color_mode,
    # not by target_size.
    img = load_img(imgpath, target_size=(128, 128))
    # img_to_array already returns a NumPy array, so no extra copy is needed.
    features = img_to_array(img) / 255
    # A single sample; the explicit size makes a shape mismatch fail loudly.
    features = features.reshape(1, 128 * 128 * 3)
    y_pred_pro = loaded_lgbm.predict_proba(features)
    return y_pred_pro[0].tolist()
def get_text_prediction(imgpath):
    """OCR the image at ``imgpath`` and score the recognised text.

    The recognised text is echoed to the page with ``st.write`` and then
    classified by ``get_inference``.

    Returns:
        [POSITIVE score, NEGATIVE score] from the sentiment model, in the
        order the app averages with the image model's probabilities.
        NOTE(review): this treats positive sentiment as "non-hateful"
        (index 0) and negative as "hateful" (index 1) — confirm intended.
    """
    # Use a real boolean. Paragraph mode is what the app has effectively been
    # running with, since the string "False" is truthy.
    result = reader.readtext(imgpath, paragraph=True)
    # The second element of each OCR entry is the recognised text.
    text = " ".join(entry[1] for entry in result)
    st.write(text)
    return _sentiment_to_class_scores(get_inference(text))


def _sentiment_to_class_scores(t_pred):
    """Extract [POSITIVE, NEGATIVE] scores from a text-classification result."""
    # With return_all_scores=True a single input string comes back wrapped in
    # an extra list ([[{...}, {...}]]); accept both nested and flat shapes.
    if t_pred and isinstance(t_pred[0], list):
        t_pred = t_pred[0]
    # Look scores up by label instead of relying on the order of the entries.
    scores = {entry['label']: entry['score'] for entry in t_pred}
    return [scores['POSITIVE'], scores['NEGATIVE']]
def pred_label_mean(i_pred, t_pred):
    """Average two per-class probability sequences element by element.

    Values are paired with ``zip``, so when the inputs differ in length the
    result stops at the shorter one.
    """
    averaged = []
    for img_p, txt_p in zip(i_pred, t_pred):
        averaged.append((img_p + txt_p) / 2)
    return averaged
def get_inference(input_text):
    """Classify ``input_text`` with the module-level ``bert`` pipeline."""
    scores = bert(input_text)
    return scores
# Models used by the prediction helpers above, loaded at module level.
# joblib.load unpickles the file, so only point it at model files you trust.
loaded_lgbm = joblib.load(filename='lgbm_v (2).sav')
# return_all_scores=True makes the pipeline report a score for every label.
bert = pipeline(task="text-classification", return_all_scores=True)
reader = easyocr.Reader(lang_list=['en'])
st.title('Hateful Memes Classification')
image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
if image_file is not None:
    # To View Uploaded Image
    st.write('Meme Image:')
    # The loaders below take a filesystem path, so persist the upload to a
    # temporary file. It is created only when there is an upload, closed by
    # the `with` block (flushing buffered bytes) before anything reads it by
    # name, and removed afterwards so reruns do not leave stray files.
    with NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(image_file.getvalue())
    try:
        imgu = load_img(temp_file.name)
        st.image(imgu)
        with st.spinner('Predicting Label..'):
            i_pred = get_img_prediction(temp_file.name)
            t_pred = get_text_prediction(temp_file.name)
            y_pred_both = pred_label_mean(i_pred, t_pred)
        # Index of the highest averaged probability is the predicted class.
        y_pred = y_pred_both.index(max(y_pred_both))
        st.write(np.round(np.array(y_pred_both), 4))
        if y_pred == 0:
            st.success('Predicted Label: non-hateful meme')
        elif y_pred == 1:
            st.success('Predicted Label: hateful meme')
    finally:
        os.remove(temp_file.name)