opdmulti-demo / app.py
atwang's picture
attempt to fix hf space setup
3f76c42
import os
import re
import shutil
import time
from types import SimpleNamespace
from typing import Any, Callable, Generator
import gradio as gr
import numpy as np
from detectron2 import engine
from huggingface_hub import hf_hub_download
from natsort import natsorted
from PIL import Image
from inference import main, setup_cfg
# internal settings
NUM_PROCESSES = 1  # worker-process count passed to detectron2's engine.launch in predict()
CROP = False  # forwarded as-is to inference.main
SCORE_THRESHOLD = 0.8  # forwarded as-is to inference.main
# TODO: we can replace this by having a slider and a single image visualization component rather than multiple components
MAX_PARTS = 5  # number of part image components in the UI; further detected parts are not shown
HF_MODEL_PATH = dict(
    repo_id="3dlg-hcvc/opdmulti-motion-state-rgb-model",
    filename="pytorch_model.pth",
)
# options handed to setup_cfg(); ``model`` stays None until predict() fills in the downloaded checkpoint path
ARGS = SimpleNamespace(
    config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
    model=None,
    input_format="RGB",
    output=".output",  # directory scanned by predict() for generated part images; cleared before each run
    cpu=True,
)
NUM_SAMPLES = 10  # default value of the "Number of samples" input
# Cache of results held by this process, keyed by the temporary path of the uploaded RGB image. Each value holds one
# entry per visualized part, and each entry is that part's sequence of frames. The click handlers built by
# get_trigger() read frames from here every time a part image is clicked. That is why results must outlive the call
# to predict(): the user needs to be able to "reload" them to view the animation again.
# Because entries are keyed by temp path rather than by filename, multiple users should be able to run the demo at
# the same time. Gradio should also handle distinct images uploaded under the same name, as I believe its temp-path
# caching is based on a base64 encoding rather than on the filename itself.
# TODO: right now there is no garbage collection for outputs. If there is enough traffic per unit time that all
# outputs are generated by the same running instance of the app, RAM could max out, keeping in mind that this is not
# designed to run on GPU, so the model also lives in CPU memory. Solutions could include:
# 1. a caching design that periodically removes old results, especially when the image is reset;
# 2. caching results on disk rather than in memory, since the capacity is higher; or
# 3. finding some way to cache results in the browser instead of in the backend (no way to do this was found earlier).
outputs: dict[str, list[list[Image.Image]]] = {}
def _find_images(path: str) -> dict[str, list[str]]:
    """
    Scrape the sub-folders of ``path`` for generated PNG files.

    :param path: directory whose immediate sub-folders each hold the frames generated for one part
    :return: mapping from sub-folder name to the naturally-sorted paths of the ``.png`` files inside it. Sub-folders
        are visited in natural-sort order so parts come back in a deterministic order (``os.listdir`` order is
        arbitrary); plain files directly under ``path`` are ignored.
    """
    images = {}
    for folder in natsorted(os.listdir(path)):
        sub_path = os.path.join(path, folder)
        if os.path.isdir(sub_path):
            images[folder] = [
                os.path.join(sub_path, image_file)
                for image_file in natsorted(os.listdir(sub_path))
                if image_file.endswith(".png")
            ]
    return images


def _clear_directory(path: str) -> None:
    """Create ``path`` if it does not exist, then delete everything inside it."""
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        # shutil.rmtree refuses symlinks, so only real directories are removed recursively; links are just unlinked
        if os.path.isdir(full_path) and not os.path.islink(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)


def _load_image(path: str) -> "Image.Image":
    """
    Read an image fully into memory and close its file.

    ``Image.open`` is lazy and keeps the file open until the pixel data is loaded. Frames cached in ``outputs`` may
    never be loaded (the user may never click them), and their files are deleted on the next run. Copying the data
    out eagerly avoids holding one open handle per frame.
    """
    with Image.open(path) as img:
        return img.copy()


def predict(rgb_image: str, depth_image: str, intrinsic: np.ndarray, num_samples: int) -> list[Any]:
    """
    Run model on input image and generate output visualizations.

    :param rgb_image: local path to RGB image file, used for model prediction and visualization
    :param depth_image: local path to depth image file, used for visualization
    :param intrinsic: array of dimension (3, 3) representing the intrinsic matrix of the camera
    :param num_samples: number of visualization states to generate.
    :return: exactly ``MAX_PARTS`` updates, one per part image component. The first frame of each visualized part is
        shown (made visible), and the remaining components are hidden.
    :raises gr.Error: if either input image is missing, so that Gradio displays the message to the user.
    """
    # validate inputs before clearing old results or downloading the model; gr.Error only reaches the user when raised
    if not rgb_image:
        raise gr.Error("You must provide an RGB image before running the model.")
    if not depth_image:
        raise gr.Error("You must provide a depth image before running the model.")

    # clear old predictions so that only this run's images are picked up below
    # TODO: might be a better place for this than at the beginning of every invocation
    _clear_directory(ARGS.output)

    # run model
    ARGS.model = hf_hub_download(repo_id=HF_MODEL_PATH["repo_id"], filename=HF_MODEL_PATH["filename"])
    cfg = setup_cfg(ARGS)
    engine.launch(
        main,
        NUM_PROCESSES,
        args=(
            cfg,
            rgb_image,
            depth_image,
            intrinsic,
            num_samples,
            CROP,
            SCORE_THRESHOLD,
        ),
    )

    # process output: cache every frame of up to MAX_PARTS parts (skipping folders without frames) for the
    # animation handlers, keyed by the temp RGB path
    # TODO: may want to select these in decreasing order of score
    part_files = [files for files in _find_images(ARGS.output).values() if files][:MAX_PARTS]
    part_frames = [[_load_image(path) for path in files] for files in part_files]
    outputs[rgb_image] = part_frames
    if not part_frames:
        gr.Warning("No openable parts were found to visualize for this image.")

    # pad with hidden updates so the count always matches the MAX_PARTS output components
    return [
        *[gr.update(value=frames[0], visible=True) for frames in part_frames],
        *[gr.update(visible=False) for _ in range(MAX_PARTS - len(part_frames))],
    ]
def get_trigger(
    idx: int, fps: int = 15, oscillate: bool = True
) -> Callable[[str], Generator[Image.Image, None, None]]:
    """
    Return event listener trigger function for image component to animate image sequence.

    :param idx: index of part to animate from output
    :param fps: approximate rate at which images should be cycled through in frames per second. Note that the fps cannot
        be higher than the rate at which images can be returned and rendered. Must be positive. Defaults to 15
    :param oscillate: if True, animates part in reverse after running from start to end. Defaults to True
    :return: generator function that takes the temporary path of the RGB image and yields the frames of part ``idx``
        from ``outputs``, paced to roughly ``fps`` frames per second
    :raises ValueError: if ``fps`` is not positive (frame timing divides by it)
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    def iter_images(rgb_image: str) -> Generator[Image.Image, None, None]:
        """
        Yield the frames of part ``idx`` for the given temp RGB image path, paced to the requested fps.

        Shows a Gradio warning (and yields nothing) when the model has not been run for this image. Raises
        ``gr.Error`` when the image has fewer than ``idx + 1`` visualized parts.
        """
        if not rgb_image or rgb_image not in outputs:
            gr.Warning("You must upload an image and run the model before you can view the output.")
            return
        if idx >= len(outputs[rgb_image]):
            # gr.Error only reaches the user when raised
            raise gr.Error("Could not find any images to load into this module.")

        frames = outputs[rgb_image][idx]
        # forward pass, then (optionally) the same frames in reverse
        sequence = [*frames, *reversed(frames)] if oscillate else list(frames)
        start_time = time.time()
        for frame_count, im in enumerate(sequence):
            # frame N is due N / fps seconds after the animation started
            delay = frame_count / fps - (time.time() - start_time)
            if delay > 0:
                time.sleep(delay)
            elif frame_count > 0:
                # the first frame is always due immediately, so only later frames can genuinely be late
                print("[WARNING] frames cannot be rendered at the specified FPS due to processing/rendering time.")
            yield im

    return iter_images
def clear_outputs():
    """
    Reset every part image component: drop its image and hide all but the first component.

    :return: one Gradio update per part image component (``MAX_PARTS`` in total).
    """
    updates = []
    for part_idx in range(MAX_PARTS):
        keep_visible = part_idx == 0
        updates.append(gr.update(value=None, visible=keep_visible))
    return updates
def run() -> None:
    """
    Build the OPDMulti Gradio interface, wire up its event handlers, and launch the server on 0.0.0.0:7860.

    Inside ``gr.Blocks`` the page layout follows the order in which components are created, so the statement order
    below is significant.
    """
    with gr.Blocks() as demo:
        # page header / usage instructions
        gr.Markdown(
            """
# OPDMulti Demo
We tackle the openable-part-detection (OPD) problem where we identify in a single-view image parts that are openable and their motion parameters. Our OPDFORMER architecture outputs segmentations for openable parts on potentially multiple objects, along with each part’s motion parameters: motion type (translation or rotation, indicated by blue or purple mask), motion axis and origin (see green arrows and points). For each openable part, we predict the motion parameters (axis and origin) in object coordinates along with an object pose prediction to convert to camera coordinates.
More information about the project, including code, can be found [here](https://3dlg-hcvc.github.io/OPDMulti/).
Upload an image to see a visualization of its range of motion below. Only the RGB image is needed for the model itself, but the depth image is required as of now for the visualization of motion.
If you know the intrinsic matrix of your camera, you can specify that here or otherwise use the default matrix which will work with any of the provided examples.
You can also change the number of samples to define the number of states in the visualization generated.
"""
        )
        # inputs
        # NOTE(review): nesting reconstructed from an unindented copy; confirm whether intrinsic/num_samples sit in the Row
        with gr.Row():
            rgb_image = gr.Image(
                image_mode="RGB", source="upload", type="filepath", label="RGB Image", show_label=True, interactive=True
            )
            # "I;16" is PIL's 16-bit integer mode, used here for the depth map
            depth_image = gr.Image(
                image_mode="I;16", source="upload", type="filepath", label="Depth Image", show_label=True, interactive=True
            )
        # editable 3x3 camera intrinsic matrix; the default works with the bundled examples
        intrinsic = gr.Dataframe(
            value=[
                [
                    214.85935872395834,
                    0.0,
                    125.90160319010417,
                ],
                [
                    0.0,
                    214.85935872395834,
                    95.13726399739583,
                ],
                [
                    0.0,
                    0.0,
                    1.0,
                ],
            ],
            row_count=(3, "fixed"),
            col_count=(3, "fixed"),
            datatype="number",
            type="numpy",
            label="Intrinsic matrix",
            show_label=True,
            interactive=True,
        )
        # number of motion states to render: a whole number (precision=0) between 3 and 20
        num_samples = gr.Number(
            value=NUM_SAMPLES,
            label="Number of samples",
            show_label=True,
            interactive=True,
            precision=0,
            minimum=3,
            maximum=20,
        )
        # specify examples which can be used to start
        examples = gr.Examples(
            examples=[
                ["examples/59-4860.png", "examples/59-4860_d.png"],
                ["examples/174-8460.png", "examples/174-8460_d.png"],
                ["examples/187-0.png", "examples/187-0_d.png"],
                ["examples/187-23040.png", "examples/187-23040_d.png"],
            ],
            inputs=[rgb_image, depth_image],
            api_name=False,
            examples_per_page=2,
        )
        submit_btn = gr.Button("Run model")
        # output
        explanation = gr.Markdown(
            value=f"# Output\nClick on an image to see an animation of the part motion. As of now, only up to {MAX_PARTS} parts can be visualized due to limitations of the visualizer."
        )
        # one image component per visualizable part; only the first is visible until results arrive
        images = [
            gr.Image(type="pil", label=f"Part {idx + 1}", show_download_button=False, visible=(idx == 0))
            for idx in range(MAX_PARTS)
        ]
        # clicking a part image streams that part's cached animation frames back into the same component
        for idx, image_comp in enumerate(images):
            image_comp.select(get_trigger(idx), inputs=rgb_image, outputs=image_comp, api_name=False)
        # if user changes input, clear output images
        rgb_image.change(clear_outputs, inputs=[], outputs=images, api_name=False)
        depth_image.change(clear_outputs, inputs=[], outputs=images, api_name=False)
        # predict returns one update per component in `images`
        submit_btn.click(
            fn=predict, inputs=[rgb_image, depth_image, intrinsic, num_samples], outputs=images, api_name=False
        )
    # queueing is needed for generator event handlers such as the animation triggers
    demo.queue(api_open=False)
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # start the demo server only when executed as a script, not when imported
    print("Starting up app...")
    run()