sokonana commited on
Commit
22a5acd
1 Parent(s): aa5c73f

Update Code to Enable model selection

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -11,7 +11,7 @@ from huggingface_hub import snapshot_download
11
  import os
12
  import gradio as gr
13
 
14
- MODEL_REPO = 'sokonana/it107model'
15
  PATH_TO_LABELS = 'data/label_map.pbtxt'
16
  category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
17
 
@@ -29,25 +29,24 @@ def load_image_into_numpy_array(path):
29
  return pil_image_as_numpy_array(image)
30
 
31
  def load_model():
32
- # wget.download("https://nyp-aicourse.s3-ap-southeast-1.amazonaws.com/pretrained-models/balloon_model.tar.gz")
33
- # tarfile.open("balloon_model.tar.gz").extractall()
34
  model_path = snapshot_download(MODEL_REPO)
35
 
36
  model_dir = os.path.join(model_path, 'saved_model')
37
  detection_model = tf.saved_model.load(model_dir)
38
  return detection_model
39
 
40
- # samples_folder = 'test_samples
41
- # image_path = 'test_samples/sample_balloon.jpeg
42
- #
43
 
44
- def predict(pilimg):
 
 
 
45
 
46
  image_np = pil_image_as_numpy_array(pilimg)
47
  return predict2(image_np)
48
 
49
  def predict2(image_np):
50
 
 
51
  results = detection_model(image_np)
52
 
53
  # different object detection models have additional results
@@ -72,7 +71,7 @@ def predict2(image_np):
72
 
73
  return result_pil_img
74
 
75
- detection_model = load_model()
76
  # pil_image = Image.open(image_path)
77
  # image_arr = pil_image_as_numpy_array(pil_image)
78
 
@@ -80,6 +79,6 @@ detection_model = load_model()
80
  # predicted_img.save('predicted.jpg')
81
 
82
  gr.Interface(fn=predict,
83
- inputs=gr.Image(type="pil"),
84
  outputs=gr.Image(type="pil")
85
  ).launch(share=True)
 
11
  import os
12
  import gradio as gr
13
 
14
+ MODEL_REPO = 'sokonana/it107model' # default model selected
15
  PATH_TO_LABELS = 'data/label_map.pbtxt'
16
  category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
17
 
 
29
  return pil_image_as_numpy_array(image)
30
 
31
  def load_model():
 
 
32
  model_path = snapshot_download(MODEL_REPO)
33
 
34
  model_dir = os.path.join(model_path, 'saved_model')
35
  detection_model = tf.saved_model.load(model_dir)
36
  return detection_model
37
 
 
 
 
38
 
39
+ def predict(pilimg, model):
40
+
41
+ if model == 'ResNet':
42
+ MODEL_REPO = 'sokonana/it107model2'
43
 
44
  image_np = pil_image_as_numpy_array(pilimg)
45
  return predict2(image_np)
46
 
47
  def predict2(image_np):
48
 
49
+ detection_model = load_model()
50
  results = detection_model(image_np)
51
 
52
  # different object detection models have additional results
 
71
 
72
  return result_pil_img
73
 
74
+ # detection_model = load_model()
75
  # pil_image = Image.open(image_path)
76
  # image_arr = pil_image_as_numpy_array(pil_image)
77
 
 
79
  # predicted_img.save('predicted.jpg')
80
 
81
  gr.Interface(fn=predict,
82
+ inputs=[gr.Image(type="pil"), gr.Radio(['MobileNet', 'ResNet'],label='Model Selection')],
83
  outputs=gr.Image(type="pil")
84
  ).launch(share=True)