# Hugging Face Spaces status: Sleeping
# %% | |
import os
import subprocess

import cv2
import gradio as gr
import tensorflow as tf
model_folder = 'model'


def _clone_repo(repo_url, destination):
    """Clone *repo_url* into *destination* unless that directory already exists.

    git is invoked with an argument list (no shell), so a URL coming from the
    environment cannot inject shell commands.

    Args:
        repo_url: URL to clone; may be None/empty when it comes from an unset
            environment variable.
        destination: target directory, passed to git explicitly so that the
            existence check matches the directory git actually creates.

    Raises:
        RuntimeError: if the directory is missing and *repo_url* is empty, if
            git cannot be executed, or if ``git clone`` exits non-zero.
    """
    if os.path.exists(destination):
        return
    if not repo_url:
        raise RuntimeError(f'No repository URL configured for {destination!r}.')
    try:
        subprocess.run(
            ['git', 'clone', repo_url, destination],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
    except subprocess.CalledProcessError as err:
        raise RuntimeError(
            f'Error cloning repository into {destination!r}: '
            f'{err.output.decode(errors="replace")}'
        ) from err
    except OSError as err:  # e.g. git is not installed
        raise RuntimeError(
            f'Could not run git to clone {destination!r}: {err}'
        ) from err
    print('Repository cloned successfully.')


# Public model repository; provides labels.txt and the Keras weights loaded below.
repo_url = "https://huggingface.co/RandomCatLover/plants_disease"
destination = model_folder
_clone_repo(repo_url, destination)

# Explainer package; its URL is read from the GIT_CORE environment variable
# (presumably a Space secret for a private repo — confirm).
destination = 'explainer_tf_mobilenetv2'
_clone_repo(os.getenv("GIT_CORE"), destination)

from explainer_tf_mobilenetv2.explainer import explainer
# %% | |
# Class names, one per line; labels[i] names the model's i-th output score.
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    # rstrip() keeps a trailing newline from becoming an extra, empty label;
    # splitlines() also copes with CRLF line endings. Interior lines are kept
    # as-is so label indices stay aligned with the model outputs.
    labels = f.read().rstrip().splitlines()
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
# %% | |
def classify_image(inp):
    """Classify an image with the loaded model and return per-label scores.

    Args:
        inp: image as a NumPy array (as delivered by the Gradio ``Image``
            input), or ``None`` when no image has been provided.

    Returns:
        dict mapping each label to the model's score for it (the format
        ``gr.Label`` consumes), or ``None`` if *inp* is ``None``.
    """
    if inp is None:
        # Gradio passes None when "Classify" is clicked with no image loaded.
        return None
    # The reshape below assumes exactly 3 channels; normalize grayscale/RGBA.
    if inp.ndim == 2 or inp.shape[-1] == 1:
        inp = cv2.cvtColor(inp, cv2.COLOR_GRAY2RGB)
    elif inp.shape[-1] == 4:
        inp = cv2.cvtColor(inp, cv2.COLOR_RGBA2RGB)
    inp = cv2.resize(inp, (224, 224))
    batch = inp.reshape((-1, 224, 224, 3))
    batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
    prediction = model.predict(batch).flatten()
    # zip() pairs labels with scores without an IndexError if the counts differ.
    return {label: float(score) for label, score in zip(labels, prediction)}
def explainer_wrapper(inp):
    """Run ``explainer`` on *inp* with the loaded model (Interpret button).

    Returns whatever ``explainer`` returns (shown in the ``gr.Plot``
    output), or ``None`` when no image has been provided.
    """
    if inp is None:
        # Gradio passes None when "Interpret" is clicked with no image loaded.
        return None
    return explainer(inp, model)
# UI: image input with Classify/Interpret buttons, a top-3 label output and an
# interpretation plot. gr.Image / gr.Label replace the deprecated
# gr.inputs / gr.outputs namespaces, which were removed in Gradio 4.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                image = gr.Image(shape=(224, 224))
                with gr.Row():
                    classify = gr.Button("Classify")
                    interpret = gr.Button("Interpret")
            label = gr.Label(num_top_classes=3)
            interpretation = gr.Plot(label="Interpretation")
        gr.Examples(
            ["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"],
            inputs=[image],
        )
    # Event listeners are registered inside the Blocks context.
    classify.click(classify_image, image, label, queue=True)
    interpret.click(explainer_wrapper, image, interpretation, queue=True)

# At most 3 queued jobs are processed concurrently.
demo.queue(concurrency_count=3).launch()
#%% | |
# gr.Interface(fn=classify_image, | |
# inputs=gr.Image(shape=(224, 224)), | |
# outputs=gr.Label(num_top_classes=3), | |
# examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch() | |