Spaces:
Sleeping
Sleeping
Update Code to Enable model selection
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ from huggingface_hub import snapshot_download
|
|
11 |
import os
|
12 |
import gradio as gr
|
13 |
|
14 |
-
MODEL_REPO = 'sokonana/it107model'
|
15 |
PATH_TO_LABELS = 'data/label_map.pbtxt'
|
16 |
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
|
17 |
|
@@ -29,25 +29,24 @@ def load_image_into_numpy_array(path):
|
|
29 |
return pil_image_as_numpy_array(image)
|
30 |
|
31 |
def load_model():
|
32 |
-
# wget.download("https://nyp-aicourse.s3-ap-southeast-1.amazonaws.com/pretrained-models/balloon_model.tar.gz")
|
33 |
-
# tarfile.open("balloon_model.tar.gz").extractall()
|
34 |
model_path = snapshot_download(MODEL_REPO)
|
35 |
|
36 |
model_dir = os.path.join(model_path, 'saved_model')
|
37 |
detection_model = tf.saved_model.load(model_dir)
|
38 |
return detection_model
|
39 |
|
40 |
-
# samples_folder = 'test_samples
|
41 |
-
# image_path = 'test_samples/sample_balloon.jpeg
|
42 |
-
#
|
43 |
|
44 |
-
def predict(pilimg):
|
|
|
|
|
|
|
45 |
|
46 |
image_np = pil_image_as_numpy_array(pilimg)
|
47 |
return predict2(image_np)
|
48 |
|
49 |
def predict2(image_np):
|
50 |
|
|
|
51 |
results = detection_model(image_np)
|
52 |
|
53 |
# different object detection models have additional results
|
@@ -72,7 +71,7 @@ def predict2(image_np):
|
|
72 |
|
73 |
return result_pil_img
|
74 |
|
75 |
-
detection_model = load_model()
|
76 |
# pil_image = Image.open(image_path)
|
77 |
# image_arr = pil_image_as_numpy_array(pil_image)
|
78 |
|
@@ -80,6 +79,6 @@ detection_model = load_model()
|
|
80 |
# predicted_img.save('predicted.jpg')
|
81 |
|
82 |
gr.Interface(fn=predict,
|
83 |
-
inputs=gr.Image(type="pil"),
|
84 |
outputs=gr.Image(type="pil")
|
85 |
).launch(share=True)
|
|
|
11 |
import os
|
12 |
import gradio as gr
|
13 |
|
14 |
+
MODEL_REPO = 'sokonana/it107model' # default model selected
|
15 |
PATH_TO_LABELS = 'data/label_map.pbtxt'
|
16 |
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
|
17 |
|
|
|
29 |
return pil_image_as_numpy_array(image)
|
30 |
|
31 |
def load_model():
|
|
|
|
|
32 |
model_path = snapshot_download(MODEL_REPO)
|
33 |
|
34 |
model_dir = os.path.join(model_path, 'saved_model')
|
35 |
detection_model = tf.saved_model.load(model_dir)
|
36 |
return detection_model
|
37 |
|
|
|
|
|
|
|
38 |
|
39 |
+
def predict(pilimg, model):
|
40 |
+
|
41 |
+
if model == 'ResNet':
|
42 |
+
MODEL_REPO = 'sokonana/it107model2'
|
43 |
|
44 |
image_np = pil_image_as_numpy_array(pilimg)
|
45 |
return predict2(image_np)
|
46 |
|
47 |
def predict2(image_np):
|
48 |
|
49 |
+
detection_model = load_model()
|
50 |
results = detection_model(image_np)
|
51 |
|
52 |
# different object detection models have additional results
|
|
|
71 |
|
72 |
return result_pil_img
|
73 |
|
74 |
+
# detection_model = load_model()
|
75 |
# pil_image = Image.open(image_path)
|
76 |
# image_arr = pil_image_as_numpy_array(pil_image)
|
77 |
|
|
|
79 |
# predicted_img.save('predicted.jpg')
|
80 |
|
81 |
gr.Interface(fn=predict,
|
82 |
+
inputs=[gr.Image(type="pil"), gr.Radio(['MobileNet', 'ResNet'],label='Model Selection')],
|
83 |
outputs=gr.Image(type="pil")
|
84 |
).launch(share=True)
|