atwang commited on
Commit
92d915f
1 Parent(s): eb85b0c

update app to improve error handling and allow for simultaneous usage

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +35 -21
.gitignore CHANGED
@@ -5,3 +5,4 @@ venv/
5
  __pycache__/
6
  .output/
7
  .data/
 
 
5
  __pycache__/
6
  .output/
7
  .data/
8
+ .vscode/
app.py CHANGED
@@ -26,22 +26,13 @@ ARGS = SimpleNamespace(
26
  )
27
  NUM_SAMPLES = 10
28
 
29
- outputs = []
30
 
31
 
32
  def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_samples: int) -> list[Any]:
33
  global outputs
34
 
35
- def find_gifs(path: str) -> list[str]:
36
- """Scrape folders for all generated gif files."""
37
- for file in os.listdir(path):
38
- sub_path = os.path.join(path, file)
39
- if os.path.isdir(sub_path):
40
- for image_file in os.listdir(sub_path):
41
- if re.match(r".*\.gif$", image_file):
42
- yield os.path.join(sub_path, image_file)
43
-
44
- def find_images(path: str) -> list[str]:
45
  """Scrape folders for all generated gif files."""
46
  images = {}
47
  for file in os.listdir(path):
@@ -62,6 +53,14 @@ def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_sample
62
  else:
63
  os.remove(full_path)
64
 
 
 
 
 
 
 
 
 
65
  cfg = setup_cfg(ARGS)
66
 
67
  engine.launch(
@@ -80,35 +79,42 @@ def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_sample
80
 
81
  # process output
82
  # TODO: may want to select these in decreasing order of score
 
83
  image_files = find_images(ARGS.output)
84
- outputs = []
85
  for count, part in enumerate(image_files):
86
  if count < MAX_PARTS:
87
- outputs.append([Image.open(im) for im in image_files[part]])
88
 
89
  return [
90
- *[gr.update(value=out[0], visible=True) for out in outputs],
91
  *[gr.update(visible=False) for _ in range(MAX_PARTS - len(outputs))],
92
  ]
93
 
94
 
95
  def get_trigger(idx: int, fps: int = 40, oscillate: bool = True):
96
- def iter_images(*args, **kwargs):
97
- if idx < len(outputs):
98
- for im in outputs[idx]:
 
 
 
99
  time.sleep(1.0 / fps)
100
  yield im
101
  if oscillate:
102
- for im in reversed(outputs[idx]):
103
  time.sleep(1.0 / fps)
104
  yield im
105
 
106
  else:
107
- raise ValueError("Could not find any images to load into this module.")
108
 
109
  return iter_images
110
 
111
 
 
 
 
 
112
  with gr.Blocks() as demo:
113
  gr.Markdown(
114
  """
@@ -176,12 +182,20 @@ with gr.Blocks() as demo:
176
  )
177
 
178
  submit_btn = gr.Button("Run model")
 
179
 
180
  # TODO: do we want to set a maximum limit on how many parts we render? We could also show the number of components
181
  # identified.
182
- images = [gr.Image(type="pil", label=f"Part {idx + 1}", visible=False) for idx in range(MAX_PARTS)]
 
 
 
183
  for idx, image_comp in enumerate(images):
184
- image_comp.select(get_trigger(idx), inputs=[], outputs=image_comp, api_name=False)
 
 
 
 
185
 
186
  submit_btn.click(
187
  fn=predict, inputs=[rgb_image, depth_image, intrinsics, num_samples], outputs=images, api_name=False
 
26
  )
27
  NUM_SAMPLES = 10
28
 
29
+ outputs = {}
30
 
31
 
32
  def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_samples: int) -> list[Any]:
33
  global outputs
34
 
35
+ def find_images(path: str) -> dict[str, list[str]]:
 
 
 
 
 
 
 
 
 
36
  """Scrape folders for all generated gif files."""
37
  images = {}
38
  for file in os.listdir(path):
 
53
  else:
54
  os.remove(full_path)
55
 
56
+ if not rgb_image:
57
+ gr.Error("You must provide an RGB image before running the model.")
58
+ return [None] * 5
59
+
60
+ if not depth_image:
61
+ gr.Error("You must provide a depth image before running the model.")
62
+ return [None] * 5
63
+
64
  cfg = setup_cfg(ARGS)
65
 
66
  engine.launch(
 
79
 
80
  # process output
81
  # TODO: may want to select these in decreasing order of score
82
+ outputs[rgb_image] = []
83
  image_files = find_images(ARGS.output)
 
84
  for count, part in enumerate(image_files):
85
  if count < MAX_PARTS:
86
+ outputs[rgb_image].append([Image.open(im) for im in image_files[part]])
87
 
88
  return [
89
+ *[gr.update(value=out[0], visible=True) for out in outputs[rgb_image]],
90
  *[gr.update(visible=False) for _ in range(MAX_PARTS - len(outputs))],
91
  ]
92
 
93
 
94
  def get_trigger(idx: int, fps: int = 40, oscillate: bool = True):
95
+ def iter_images(rgb_image: str):
96
+ if not rgb_image or rgb_image not in outputs:
97
+ gr.Warning("You must upload an image and run the model before you can view the output.")
98
+
99
+ elif idx < len(outputs[rgb_image]):
100
+ for im in outputs[rgb_image][idx]:
101
  time.sleep(1.0 / fps)
102
  yield im
103
  if oscillate:
104
+ for im in reversed(outputs[rgb_image][idx]):
105
  time.sleep(1.0 / fps)
106
  yield im
107
 
108
  else:
109
+ gr.Error("Could not find any images to load into this module.")
110
 
111
  return iter_images
112
 
113
 
114
+ def clear_outputs():
115
+ return [gr.update(value=None, visible=(idx == 0)) for idx in range(MAX_PARTS)]
116
+
117
+
118
  with gr.Blocks() as demo:
119
  gr.Markdown(
120
  """
 
182
  )
183
 
184
  submit_btn = gr.Button("Run model")
185
+ explanation = gr.Markdown(value="# Output\nClick on an image to see an animation of the part motion.")
186
 
187
  # TODO: do we want to set a maximum limit on how many parts we render? We could also show the number of components
188
  # identified.
189
+ images = [
190
+ gr.Image(type="pil", label=f"Part {idx + 1}", show_download_button=False, visible=(idx == 0))
191
+ for idx in range(MAX_PARTS)
192
+ ]
193
  for idx, image_comp in enumerate(images):
194
+ image_comp.select(get_trigger(idx), inputs=rgb_image, outputs=image_comp, api_name=False)
195
+
196
+ # if user changes input, clear output images
197
+ rgb_image.change(clear_outputs, inputs=[], outputs=images, api_name=False)
198
+ depth_image.change(clear_outputs, inputs=[], outputs=images, api_name=False)
199
 
200
  submit_btn.click(
201
  fn=predict, inputs=[rgb_image, depth_image, intrinsics, num_samples], outputs=images, api_name=False