# datasciencedojo — commit 164c6a0: "submit button hover color updated"
from collections import OrderedDict
from functools import lru_cache

from PIL import Image, ImageOps
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from keras.models import load_model
import gradio as gr
def create_plot(data):
    """Draw a horizontal bar chart of the per-label confidence scores.

    Two bar layers share one axes: a pale "Total" bar drawn first as a
    backdrop, then a saturated "Confidence Score" bar on top of it, so each
    row reads as a filled portion of the total.

    Args:
        data: table with "Labels", "Total" and "Confidence Score" columns
            (``predict_pneumonia`` passes a pandas DataFrame).

    Returns:
        The matplotlib Figure holding the chart.
    """
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(5, 5))

    # Pale palette for the backdrop ("Total") bars.
    sns.set_color_codes("pastel")
    sns.barplot(x="Total", y="Labels", data=data, label="Total", color="b", ax=ax)

    # Saturated palette for the actual scores, overlaid on the backdrop.
    sns.set_color_codes("muted")
    sns.barplot(
        x="Confidence Score",
        y="Labels",
        data=data,
        label="Confidence Score",
        color="b",
        ax=ax,
    )

    ax.legend(ncol=2, loc="lower right", frameon=True)
    sns.despine(left=True, bottom=True)

    # Drop the figure from pyplot's global registry: this runs once per
    # request, and open figures would otherwise accumulate for the life of
    # the server. The returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
@lru_cache(maxsize=1)
def _load_pneumonia_model():
    """Load the Keras classifier once and reuse it for every request."""
    # compile=False: the model is only used for inference here.
    return load_model('keras_model.h5', compile=False)


@lru_cache(maxsize=1)
def _load_class_names():
    """Return the class names from labels.txt, in file (= model output) order.

    Assumes each line looks like "<index> <name>" (e.g. "0 Normal"); the
    leading index token is dropped and surrounding whitespace, including a
    missing or present trailing newline, is stripped.
    """
    with open('labels.txt', 'r', encoding='utf-8') as fh:
        return tuple(line.strip().split(' ', 1)[-1] for line in fh)


def _preprocess(img, size=(224, 224)):
    """Turn an image array into a normalized float32 batch of shape (1, H, W, 3)."""
    # Force 3 channels so grayscale or RGBA inputs still match the model input.
    pil_image = Image.fromarray(img).convert('RGB')
    fitted = ImageOps.fit(pil_image, size, Image.LANCZOS)
    # Scale pixel values from [0, 255] to roughly [-1, 1].
    normalized = np.asarray(fitted, dtype=np.float32) / 127.0 - 1
    return normalized[np.newaxis, ...]


def predict_pneumonia(img):
    """Classify a chest X-ray as Normal or Pneumonia.

    Args:
        img: image as a numpy array (as delivered by the ``gr.Image`` input).
            Grayscale/RGBA arrays are converted to RGB first.

    Returns:
        (message, figure): a human-readable verdict string and a matplotlib
        Figure charting the confidence for both classes.

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError("No image provided; please upload a chest X-ray image.")

    model = _load_pneumonia_model()
    class_names = _load_class_names()

    prediction = model.predict(_preprocess(img))
    index = int(np.argmax(prediction))
    c_name = class_names[index]
    confidence_score = prediction[0][index]

    # Assumes a two-class model: the other label gets the remaining probability.
    if c_name == "Normal":
        pneumonia_prediction = "Chest XRay is normal no signs of pneumonia"
        other_class = "Pneumonia"
    else:
        other_class = "Normal"
        pneumonia_prediction = "Chest XRay shows signs of pneumonia"

    res = {
        "Labels": [c_name, other_class],
        "Confidence Score": [confidence_score * 100, (1 - confidence_score) * 100],
        "Total": 100,
    }
    data_for_plot = pd.DataFrame.from_dict(res)
    pneumonia_conf_plt = create_plot(data_for_plot)
    return pneumonia_prediction, pneumonia_conf_plt
# Custom CSS injected into the Gradio page: hides the footer and markdown
# output, recolors hover states for examples, and styles Gradio's large
# buttons (.gr-button-lg) plus their hover state.
# Raw string so the CSS escape "\:" in ".hover\:bg-orange-50" is passed
# through verbatim rather than being an invalid Python escape sequence
# (DeprecationWarning, and a SyntaxWarning from Python 3.12).
css = r"""
footer {display:none !important}
.output-markdown{display:none !important}
footer {visibility: hidden}
.hover\:bg-orange-50:hover {
--tw-bg-opacity: 1 !important;
background-color: rgb(229,225,255) !important;
}
img.gr-sample-image:hover, video.gr-sample-video:hover {
--tw-border-opacity: 1;
border-color: rgb(37, 56, 133) !important;
}
.gr-button-lg {
z-index: 14;
width: 113px;
height: 30px;
left: 0px;
top: 0px;
padding: 0px;
cursor: pointer !important;
background: none rgb(17, 20, 45) !important;
border: none !important;
text-align: center !important;
font-size: 14px !important;
font-weight: 500 !important;
color: rgb(255, 255, 255) !important;
line-height: 1 !important;
border-radius: 6px !important;
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
box-shadow: none !important;
}
.gr-button-lg:hover{
z-index: 14;
width: 113px;
height: 30px;
left: 0px;
top: 0px;
padding: 0px;
cursor: pointer !important;
background: none rgb(66, 133, 244) !important;
border: none !important;
text-align: center !important;
font-size: 14px !important;
font-weight: 500 !important;
color: rgb(255, 255, 255) !important;
line-height: 1 !important;
border-radius: 6px !important;
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
"""
# Page layout: the X-ray upload on the left; the verdict text, confidence
# chart and Submit button on the right.
with gr.Blocks(title="Pneumonia Detection | Data Science Dojo", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                imgInput = gr.Image()
        with gr.Column(scale=1):
            pneumonia = gr.Textbox(label='Presence of pneumonia')
            plot = gr.Plot(label="Plot")
            submit_button = gr.Button(value="Submit")

    # Run the classifier on the uploaded image when Submit is pressed.
    submit_button.click(
        fn=predict_pneumonia,
        inputs=[imgInput],
        outputs=[pneumonia, plot],
    )

    # Clickable sample images; cache_examples=True asks Gradio to cache
    # their predictions instead of recomputing on each click.
    gr.Examples(
        examples=["normal_Sample.jpg", "pneumonia_sample.jpg"],
        inputs=imgInput,
        outputs=[pneumonia, plot],
        fn=predict_pneumonia,
        cache_examples=True,
    )

demo.launch()