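# Page-header residue from the web file viewer this source was copied from.
# It is not part of the program and is kept here only as comments:
#   Shrikrishna's picture
#   Update app.py
#   12e508d
#   raw history blame
#   No virus
#   7.81 kB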
import streamlit as st
import pickle
import base64
import json
import numpy as np
import cv2
import pywt
import joblib
from PIL import Image
# Module-level state; all three names are (re)assigned by load_saved_artifacts().
# Class name -> class number, as read from the JSON class dictionary.
__class_name_to_number = {}
# Inverse mapping: class number -> class name.
__class_number_to_name = {}
# Trained classifier; stays None until load_saved_artifacts() loads it.
__model = None
st.header("Welcome to Indian Cricketers Classifier!")

# Reference gallery: one entry per column, each listing the
# (image file, caption) pairs shown in that column from top to bottom.
_GALLERY_COLUMNS = (
    (("dhoni.jpg", 'MS Dhoni'), ("ganguly.jpg", 'Saurav Ganguly')),
    (("rahul.jpg", 'Rahul Dravid'), ("virat.jpg", 'Virat Kohli')),
    (("sachin.jpg", 'Sachin Tendulkar'), ("sehwag.jpg", 'Virendra Sehwag')),
    (("sunil_gavaskar.jpg", 'Sunil Gavaskar'), ("kapil_dev.jpg", 'Kapil Dev')),
)

col1, col2, col3, col4 = st.columns(4)
for _column, _portraits in zip((col1, col2, col3, col4), _GALLERY_COLUMNS):
    with _column:
        for _image_file, _caption in _portraits:
            st.image(Image.open(_image_file), width=150, caption=_caption)
def classify_image(image_base64_data, file_path=None):
    """Classify every usable face found in an image.

    The faces come from ``get_cropped_image_if_2_eyes_new``, which reads the
    image from ``file_path`` when it is given and from the base64 string
    ``image_base64_data`` otherwise.  For each face one feature row is built:
    the 32x32 resized colour crop stacked on top of the 32x32 resized
    wavelet image produced by ``w2d(face, 'db1', 5)``.

    ``load_saved_artifacts`` has to run first; until then the module-level
    model is None and the class mappings are empty.

    :return: a list with one dict per face, holding the predicted class name
        ('class'), the class probabilities in percent rounded to two decimals
        ('class_probability') and the name-to-number mapping
        ('class_dictionary').
    """
    side = 32
    raw_len = side * side * 3
    wavelet_len = side * side
    feature_len = raw_len + wavelet_len

    predictions = []
    for face in get_cropped_image_if_2_eyes_new(file_path, image_base64_data):
        raw_column = cv2.resize(face, (side, side)).reshape(raw_len, 1)
        wavelet_face = w2d(face, 'db1', 5)
        wavelet_column = cv2.resize(wavelet_face, (side, side)).reshape(wavelet_len, 1)

        stacked = np.vstack((raw_column, wavelet_column))
        features = stacked.reshape(1, feature_len).astype(float)

        label = __model.predict(features)[0]
        predictions.append({
            'class': class_number_to_name(label),
            'class_probability': np.around(__model.predict_proba(features) * 100, 2).tolist()[0],
            'class_dictionary': __class_name_to_number,
        })
    return predictions
def get_cropped_image_if_2_eyes_new(file_path, image_base64_data):
    """Return the colour crops of all detected faces that show two eyes.

    The image is read from ``file_path`` when that is truthy, otherwise it
    is decoded from ``image_base64_data``.  Faces are located with the
    frontal-face Haar cascade; a face is kept only when the eye cascade
    finds at least two eyes inside it.

    :param file_path: path of the image file, or a falsy value
    :param image_base64_data: base64 image string used when no path is given
    :return: list of face regions cut from the colour image; empty when the
        image could not be read/decoded or no face qualifies
    """
    face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')
    eye_cascade = cv2.CascadeClassifier('haarcascade_eye.xml')

    if file_path:
        img = cv2.imread(file_path)
    else:
        img = get_cv2_image_from_base64_string(image_base64_data)

    # OpenCV reports an unreadable or undecodable image by returning None
    # instead of raising; cvtColor would fail on it, so treat it as
    # "no faces found".
    if img is None:
        return []

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)

    cropped_faces = []
    for (x, y, w, h) in faces:
        roi_gray = gray[y:y+h, x:x+w]
        roi_color = img[y:y+h, x:x+w]
        eyes = eye_cascade.detectMultiScale(roi_gray)
        if len(eyes) >= 2:
            cropped_faces.append(roi_color)
    return cropped_faces
def w2d(img, mode='haar', level=1):
    """Return a wavelet-filtered version of ``img`` as a uint8 array.

    Steps: convert to grayscale, cast to float32 and divide by 255, run a
    2-D wavelet decomposition, zero the first coefficient array (the
    approximation part in PyWavelets' ``wavedec2`` layout), reconstruct,
    multiply by 255 and cast to uint8.

    :param img: colour image array accepted by ``cv2.cvtColor``
    :param mode: wavelet name passed to PyWavelets
    :param level: decomposition level
    :return: the reconstructed image as ``np.uint8``
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    normalized = np.float32(gray) / 255

    # Drop the first coefficient array so only the remaining ones
    # contribute to the reconstruction.
    detail_coeffs = list(pywt.wavedec2(normalized, mode, level=level))
    detail_coeffs[0] *= 0

    reconstructed = pywt.waverec2(detail_coeffs, mode) * 255
    return np.uint8(reconstructed)
def get_cv2_image_from_base64_string(b64str):
    """Decode a base64-encoded image string into an OpenCV colour image.

    Credit: https://stackoverflow.com/questions/33754935/read-a-base-64-encoded-image-from-memory-using-opencv-python-library

    :param b64str: a data URI of the form ``<header>,<base64 payload>`` or
        the bare base64 payload without any header
    :return: the result of ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` on the
        decoded bytes
    """
    # Keep what follows the first comma (the payload of a data URI).  When
    # there is no comma the whole string is already the payload; the
    # previous ``split(',')[1]`` raised IndexError in that case.
    encoded_data = b64str.split(',', 1)[-1]
    nparr = np.frombuffer(base64.b64decode(encoded_data), np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    return img
def load_saved_artifacts():
    """Populate the module-level class mappings and model.

    Reads "class_cri_dictionary1.json" into the name-to-number mapping and
    derives the number-to-name mapping from it on every call.  The model
    file 'cri_saved_model1.pkl' is loaded only while the module-level model
    is still None.

    :return: always True
    """
    global __class_name_to_number, __class_number_to_name, __model

    with open("class_cri_dictionary1.json", "r") as dictionary_file:
        __class_name_to_number = json.load(dictionary_file)
    __class_number_to_name = dict(
        zip(__class_name_to_number.values(), __class_name_to_number.keys())
    )

    if __model is None:
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code; the model file must come from a trusted source.
        __model = joblib.load('cri_saved_model1.pkl')
    return True
def class_number_to_name(class_num):
    """Return the class name stored for ``class_num``.

    Looks the number up in the module-level number-to-name mapping and
    raises KeyError when it is not present.
    """
    class_name = __class_number_to_name[class_num]
    return class_name
def get_b64_test_image_for_virat():
    """Return the full text content of "b64.txt" from the working directory."""
    with open("b64.txt") as sample_file:
        sample_text = sample_file.read()
    return sample_text
def save_uploaded_image(uploaded_image, upload_dir="uploads"):
    """Write an uploaded file to disk and report where it was stored.

    The file name comes from the client, so it is untrusted: only its final
    path component is used, and the file is written inside ``upload_dir``
    rather than next to the application files.  Writing straight to the
    supplied name would let an upload escape the directory ("../x") or
    overwrite files the app depends on, such as the pickled model.

    :param uploaded_image: object with a ``name`` attribute and a
        ``getbuffer()`` method returning the file's bytes
    :param upload_dir: directory that receives uploads; created on demand
    :return: ``{"complete": True, "filename": <path of the saved file>}`` on
        success, ``{"complete": False, "filename": ""}`` when the name is
        unusable or the file cannot be written
    """
    import os  # local import: nothing else in this file needs os

    failure = {"complete": False, "filename": ""}

    raw_name = getattr(uploaded_image, "name", None)
    if not isinstance(raw_name, str):
        return failure
    safe_name = os.path.basename(raw_name)
    if safe_name in ("", ".", ".."):
        return failure

    target = os.path.join(upload_dir, safe_name)
    try:
        os.makedirs(upload_dir, exist_ok=True)
        with open(target, 'wb') as f:
            f.write(uploaded_image.getbuffer())
    except (OSError, ValueError):
        # OSError: directory/file cannot be created or written.
        # ValueError: e.g. a NUL byte in the name or a closed upload buffer.
        return failure
    return {"complete": True, "filename": target}
# Upload flow: save the uploaded file, classify it, then show the upload
# next to the reference portrait of the predicted player.

# Message shown whenever the upload cannot be saved, read or classified.
_CLASSIFICATION_FAILED = "Image Cannot be Classified!Please Try Again"

# Class label reported by the model -> (reference image file, caption).
# NOTE(review): "Saurav Ganguly" and "Virendra Sehwag" are display-style
# labels while the others are snake_case; confirm that they match the keys
# in class_cri_dictionary1.json, otherwise those two never show a portrait.
_PREDICTION_PORTRAITS = {
    "ms_dhoni": ("dhoni.jpg", 'MS Dhoni'),
    "rahul_dravid": ("rahul.jpg", 'Rahul Dravid'),
    "sachin_tendulkar": ("sachin.jpg", 'Sachin Tendulkar'),
    "Saurav Ganguly": ("ganguly.jpg", 'Saurav Ganguly'),
    "virat_kohli": ("virat.jpg", 'Virat Kohli'),
    "Virendra Sehwag": ("sehwag.jpg", 'Virendra Sehwag'),
    "sunil_gavaskar": ("sunil_gavaskar.jpg", 'Sunil Gavaskar'),
    "kapil_dev": ("kapil_dev.jpg", 'Kapil Dev'),
}

uploaded_image = st.file_uploader('Choose an image')
if uploaded_image is not None:
    # save the image in a directory
    image_dict = save_uploaded_image(uploaded_image)
    if not image_dict["complete"]:
        # Previously a failed save produced no feedback at all.
        st.header(_CLASSIFICATION_FAILED)
    else:
        display_image = image_dict["filename"]
        st.header("Image Uploded!, Processing...")
        if load_saved_artifacts():
            try:
                # A file path is supplied, so classify_image never looks at
                # the base64 argument; there is no need to read b64.txt.
                result = classify_image(None, display_image)

                col6, col7 = st.columns(2)
                with col6:
                    st.header("Uploded Image: ")
                    dis_img = Image.open(display_image)
                    st.image(dis_img, width=130, caption='Uploaded Image')

                if not result:
                    # No face with two detectable eyes was found.
                    st.header(_CLASSIFICATION_FAILED)
                else:
                    celeb = result[0]['class']
                    with col7:
                        st.header("Predicted Image: ")
                        portrait = _PREDICTION_PORTRAITS.get(celeb)
                        if portrait is None:
                            # Unknown label: show it rather than nothing.
                            st.write(celeb)
                        else:
                            portrait_file, portrait_caption = portrait
                            st.image(Image.open(portrait_file), width=150,
                                     caption=portrait_caption)
            except (cv2.error, OSError):
                # cv2.error: OpenCV could not process the upload (e.g. it
                # is not an image).  OSError: an image file is missing or
                # cannot be opened by PIL.
                st.header(_CLASSIFICATION_FAILED)