# plants_disease / app.py
# Source: Hugging Face file view — author RandomCatLover, commit 12d8e70
# ("Update app.py"), 2.83 kB, virus scan: no virus.
# %%
import gradio as gr
import tensorflow as tf
import cv2
import os
import subprocess

# Local folder holding labels.txt and the Keras model; cloned from repo_url.
model_folder = 'model'
repo_url = "https://huggingface.co/RandomCatLover/plants_disease"
# Folder the explainer package is imported from further down in this file.
explainer_folder = 'explainer_tf_mobilenetv2'


def _clone_repo(url, destination):
    """Clone the git repository at *url* into *destination* unless it exists.

    Best-effort: failures are printed rather than raised, matching the
    original script. Anything that truly depends on the clone will fail
    later when it tries to load files from *destination*.

    Args:
        url: repository URL; may be None/empty (e.g. unset env var).
        destination: target directory for the clone.

    Returns:
        True if *destination* already existed or the clone succeeded,
        False otherwise.
    """
    if os.path.exists(destination):
        return True
    if not url:
        print(f'Error cloning repository: no URL configured for {destination!r}')
        return False
    try:
        # Argument list with shell=False: the URL (which may come from an
        # environment variable) is never parsed by a shell.
        subprocess.check_output(
            ['git', 'clone', url, destination],
            stderr=subprocess.STDOUT,
        )
    except subprocess.CalledProcessError as e:
        print(f'Error cloning repository: {e.output.decode(errors="replace")}')
        return False
    except OSError as e:
        # Raised when the git executable itself cannot be started.
        print(f'Error cloning repository: {e}')
        return False
    print('Repository cloned successfully.')
    return True


_clone_repo(repo_url, model_folder)
# The explainer repository URL is supplied via the GIT_CORE environment
# variable. Clone into an explicit folder so the import path is stable.
_clone_repo(os.getenv("GIT_CORE"), explainer_folder)
from explainer_tf_mobilenetv2.explainer import explainer
# %%
# One class name per line. classify_image pairs the i-th label with the
# model's i-th output. Blank lines (e.g. from a trailing newline) are dropped
# so they cannot create phantom labels that misalign with the outputs.
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    labels = [line for line in f.read().splitlines() if line.strip()]
# Keras model stored in HDF5 format inside the cloned model folder.
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
# %%
def classify_image(inp):
    """Classify an image and return a label -> score mapping.

    Args:
        inp: image array from the Gradio Image input (presumably H x W x 3
            RGB -- TODO confirm; the reshape below requires 3 channels).

    Returns:
        dict mapping each label to the model's output value for that class,
        as a Python float (the format gr.Label expects).
    """
    inp = cv2.resize(inp, (224, 224))
    # Add the batch dimension: (1, 224, 224, 3) for a 3-channel image.
    inp = inp.reshape((-1, 224, 224, 3))
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    print(prediction)
    # zip stops at the shorter sequence, so a stray extra entry in labels
    # (e.g. '' from a trailing newline in labels.txt) no longer raises
    # IndexError the way range(len(labels)) indexing did.
    return {label: float(score) for label, score in zip(labels, prediction)}
def explainer_wrapper(inp):
    """Gradio callback: run ``explainer`` on *inp* with the loaded model.

    The result is wired to the "Interpretation" gr.Plot output, so it is
    presumably a plot/figure object -- TODO confirm against the explainer.
    """
    result = explainer(inp, model)
    return result
# Two-column layout: image input and action buttons on the left; predicted
# labels and the interpretation plot on the right. Uses the top-level
# component classes (gr.Image / gr.Label) instead of the deprecated
# gr.inputs / gr.outputs namespaces, consistent with gr.Plot and gr.Button.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                image = gr.Image(shape=(224, 224))
                with gr.Row():
                    classify = gr.Button("Classify")
                    interpret = gr.Button("Interpret")
            with gr.Column():
                label = gr.Label(num_top_classes=3)
                interpretation = gr.Plot(label="Interpretation")
        gr.Examples(
            ["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"],
            inputs=[image],
        )
    # Event listeners are registered inside the Blocks context.
    classify.click(classify_image, image, label, queue=True)
    interpret.click(explainer_wrapper, image, interpretation, queue=True)
# Queue requests so up to 3 are processed concurrently.
demo.queue(concurrency_count=3).launch()
#%%
# gr.Interface(fn=classify_image,
# inputs=gr.Image(shape=(224, 224)),
# outputs=gr.Label(num_top_classes=3),
# examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch()