File size: 2,831 Bytes
1179283
 
 
 
a83d91e
 
32cb20d
a83d91e
8c62505
a83d91e
 
 
3da4461
a83d91e
 
 
 
 
 
e092e8d
d55ab02
e092e8d
 
 
 
 
 
 
 
 
 
 
a83d91e
1179283
 
 
 
5ea968c
 
1179283
bd4d9f6
1179283
 
 
 
 
2e6883a
1179283
 
e092e8d
 
 
 
 
 
 
 
 
 
 
12d8e70
 
 
e092e8d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# %%
import gradio as gr
import tensorflow as tf
import cv2
import os

model_folder = 'model'
repo_url = "https://huggingface.co/RandomCatLover/plants_disease"


def _clone_repo(url, destination=None):
    """Best-effort ``git clone`` of *url*, reporting the outcome on stdout.

    Args:
        url: Repository URL to clone. If falsy (e.g. an unset environment
            variable), nothing is run and an error message is printed.
        destination: Target directory. When ``None``, git chooses the
            directory name from the URL.

    Failures are printed rather than raised, as in the original script; a
    missing checkout will surface later as an import or file error.
    """
    import subprocess

    if not url:
        print('Error cloning repository: no repository URL configured.')
        return
    # Pass an argument list (no shell) so the URL, which may come from an
    # environment variable, is never interpreted by a shell. "--" stops git
    # from treating a URL that starts with "-" as an option.
    command = ['git', 'clone', '--', url]
    if destination is not None:
        command.append(destination)
    try:
        subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error cloning repository: {e.output.decode(errors="replace")}')
    except OSError as e:
        # Without a shell, a missing `git` executable raises instead of
        # returning a non-zero exit status.
        print(f'Error cloning repository: {e}')
    else:
        print('Repository cloned successfully.')


# Model weights and labels.txt live in the Hugging Face model repository.
if not os.path.exists(model_folder):
    _clone_repo(repo_url, model_folder)

# The explainer package's repository URL is supplied via the GIT_CORE
# environment variable; presumably that repo is named
# explainer_tf_mobilenetv2, since git derives the checkout directory from it.
if not os.path.exists('explainer_tf_mobilenetv2'):
    _clone_repo(os.getenv("GIT_CORE"))

from explainer_tf_mobilenetv2.explainer import explainer
# %%
# Class names, one per line. classify_image pairs them positionally with the
# model's output vector, so their order presumably matches the model's
# output units (TODO confirm against the training code).
# splitlines() (unlike split('\n')) does not produce a spurious empty label
# when the file ends with a newline, and it also handles CRLF line endings.
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    labels = f.read().splitlines()

# Keras HDF5 checkpoint cloned from the model repository.
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
# %%
def classify_image(inp):
    """Classify an image and return per-label confidences for ``gr.Label``.

    Args:
        inp: Image as a numpy array (as produced by the Gradio ``Image``
            input), or ``None`` when no image has been provided.

    Returns:
        A ``{label: confidence}`` dict, or ``None`` when *inp* is ``None``
        so the label output is simply cleared.
    """
    if inp is None:
        # Gradio passes None when "Classify" is clicked with no image loaded.
        return None
    inp = cv2.resize(inp, (224, 224))
    # Add the batch dimension; assumes a 3-channel image (TODO confirm for
    # RGBA / greyscale uploads).
    inp = inp.reshape((-1, 224, 224, 3))
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    print(prediction)
    # zip() stops at the shorter sequence, so a label/output count mismatch
    # cannot raise IndexError.
    return {label: float(score) for label, score in zip(labels, prediction)}

def explainer_wrapper(inp):
    """Run ``explainer`` on *inp* with the loaded model for the ``gr.Plot`` output.

    Returns ``None`` (clearing the plot) when no image was provided;
    otherwise returns whatever ``explainer(inp, model)`` returns.
    """
    if inp is None:
        # Gradio passes None when "Interpret" is clicked with no image loaded.
        return None
    return explainer(inp, model)

# UI layout: an image input with Classify / Interpret buttons on the left,
# the top-3 label output and the interpretation plot on the right.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # Use the unified component API; gr.inputs.* / gr.outputs.*
                # are deprecated aliases of these components.
                image = gr.Image(shape=(224, 224))
                with gr.Row():
                    classify = gr.Button("Classify")
                    interpret = gr.Button("Interpret")
            with gr.Column():
                label = gr.Label(num_top_classes=3)
                interpretation = gr.Plot(label="Interpretation")
    gr.Examples(
        ["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"],
        inputs=[image],
    )
    classify.click(classify_image, image, label, queue=True)
    interpret.click(explainer_wrapper, image, interpretation, queue=True)

# `concurrency_count` (and Image's `shape`) are Gradio 3.x APIs.
demo.queue(concurrency_count=3).launch()
#%%
# gr.Interface(fn=classify_image, 
#              inputs=gr.Image(shape=(224, 224)),
#              outputs=gr.Label(num_top_classes=3),
#              examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch()