"""Gradio demo: predict 3D object interaction properties for a query point.

The user uploads an image, sketches a single query point on it, and the
model's predictions (physical properties, localization, affordance and depth)
are rendered as images.
"""
import numpy as np
import gradio as gr
import argparse
import pdb
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import cv2
from PIL import Image
import os
import subprocess
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use('agg')

from monoarti.model import build_demo_model
from monoarti.detr.misc import interpolate
from monoarti.vis_utils import draw_properties, draw_affordance, draw_localization
from monoarti.detr import box_ops
from monoarti import axis_ops, depth_ops


mask_source_draw = "draw a mask on input image"
mask_source_segment = "type what to detect below"

# Model input resolution as (height, width); shared by the preprocessing
# transform, the query-point rescaling and the prediction upsampling.
IMAGE_SIZE = (768, 1024)
# Number of keypoint slots passed to the model; only slot 0 is marked valid.
MAX_QUERIES = 15
# Directory and file-name stem for the intermediate visualization PNGs.
EXPORT_DIR = './temp'
EXPORT_NAME = 'temp'


def change_radio_display(task_type, mask_source_radio):
    """Return visibility updates for the prompt widgets of a task selector.

    Not referenced by the UI defined in this file.

    NOTE(review): the original indentation was lost; the ``mask_source_draw``
    check is assumed to be nested under the inpainting/remove branch --
    confirm against the upstream source.

    Args:
        task_type: Selected task name; "inpainting" and "remove" are the only
            values that change the defaults.
        mask_source_radio: Selected mask source, compared to
            ``mask_source_draw``.

    Returns:
        A 3-tuple of Gradio updates: (text prompt Textbox, inpaint prompt
        Textbox, mask-source Radio).
    """
    text_prompt_visible = True
    inpaint_prompt_visible = False
    mask_source_radio_visible = False

    if task_type == "inpainting":
        inpaint_prompt_visible = True
    if task_type == "inpainting" or task_type == "remove":
        mask_source_radio_visible = True
        if mask_source_radio == mask_source_draw:
            text_prompt_visible = False

    return (
        gr.Textbox.update(visible=text_prompt_visible),
        gr.Textbox.update(visible=inpaint_prompt_visible),
        gr.Radio.update(visible=mask_source_radio_visible),
    )


os.makedirs(EXPORT_DIR, exist_ok=True)

# initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_demo_model().to(device)
checkpoint_path = 'checkpoint_20230515.pth'
if not os.path.exists(checkpoint_path):
    print("get {}".format(checkpoint_path))
    result = subprocess.run(
        ['wget', 'https://fouheylab.eecs.umich.edu/~syqian/3DOI/{}'.format(checkpoint_path)],
        check=True,
    )
    print('wget {} result = {}'.format(checkpoint_path, result))
# SECURITY NOTE: torch.load unpickles the file; only load checkpoints obtained
# from a trusted source.
loaded_data = torch.load(checkpoint_path, map_location=device)
state_dict = loaded_data["model"]
model.load_state_dict(state_dict, strict=True)
# NOTE(review): model.eval() is never called here; confirm whether
# build_demo_model() already returns the model in inference mode.

data_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Class-index -> label lookups for the categorical prediction heads.
movable_imap = {
    0: 'one_hand',
    1: 'two_hands',
    2: 'fixture',
    -100: 'n/a',
}

rigid_imap = {
    1: 'yes',
    0: 'no',
    2: 'bad',
    -100: 'n/a',
}

kinematic_imap = {
    0: 'freeform',
    1: 'rotation',
    2: 'translation',
    -100: 'n/a'
}

action_imap = {
    0: 'free',
    1: 'pull',
    2: 'push',
    -100: 'n/a',
}


def _load_image(path):
    """Read the image at ``path`` fully into memory and close the file."""
    # Image.open is lazy and keeps the file handle open; the temp files are
    # overwritten by the next request, so detach the pixel data from the file.
    with Image.open(path) as img:
        return img.copy()


def _query_point_from_mask(mask_image, input_width, input_height):
    """Locate the sketched query point and rescale it to model-input pixels.

    Args:
        mask_image: Sketch mask (array-like with at least three channels).
        input_width: Width of the uploaded image in pixels.
        input_height: Height of the uploaded image in pixels.

    Returns:
        ``(x, y)`` integer coordinates in the ``IMAGE_SIZE`` frame.

    Raises:
        gr.Error: If the mask contains no usable stroke.
    """
    mask = np.array(mask_image)[:, :, :3].sum(axis=2)
    if mask.sum() == 0:
        raise gr.Error("No query point!")

    # The channel sum can reach 765; clip before the uint8 cast so bright
    # pixels do not wrap around to small values and fall below the threshold.
    mask = np.clip(mask, 0, 255).astype(np.uint8)
    _, thresh = cv2.threshold(mask, 50, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        raise gr.Error("No query point!")

    # Only the first contour is used, i.e. a single query point is supported.
    contour = contours[0]
    moments = cv2.moments(contour)
    if moments['m00'] > 0:
        cx = moments['m10'] / moments['m00']
        cy = moments['m01'] / moments['m00']
    else:
        # Zero-area contour (single pixel or a line): the moment-based
        # centroid is undefined, so use the mean of the contour points.
        cx, cy = contour.reshape(-1, 2).mean(axis=0)

    height, width = IMAGE_SIZE
    x = round(float(cx) / input_width * width)
    y = round(float(cy) / input_height * height)
    # Rounding can land exactly on the image extent; keep the point in range.
    x = min(max(x, 0), width - 1)
    y = min(max(y, 0), height - 1)
    return x, y


def _render_depth(depth_pred, depth_path):
    """Save ``depth_pred`` as a plasma-colormapped PNG and return it as an image."""
    # NOTE(review): the scale/shift presumably converts the network output to
    # metric depth (per the original variable name) -- confirm.
    depth_pred_metric = depth_pred[0] * 0.945 + 0.658
    depth_pred_metric = depth_pred_metric.detach().cpu().numpy()
    fig = plt.figure()
    try:
        plt.imshow(depth_pred_metric, cmap=mpl.colormaps['plasma'])
        plt.axis('off')
        plt.savefig(depth_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    return _load_image(depth_path)


def run_model(input_image):
    """Run the model on the sketched query point and render its predictions.

    Args:
        input_image: Value of the sketch-enabled image component: a mapping
            with an ``'image'`` entry (PIL image) and a ``'mask'`` entry (the
            sketch).

    Returns:
        A list of five entries for the valid query, in the order: image with
        query point, physical properties, localization, affordance, depth.
        The depth entry is ``None`` when the model returns no depth prediction.

    Raises:
        gr.Error: If no image was provided or no query point was drawn.
    """
    if not input_image or input_image.get('image') is None:
        raise gr.Error("No input image!")

    image = input_image['image']
    input_width, input_height = image.size
    x, y = _query_point_from_mask(input_image['mask'], input_width, input_height)

    image_tensor = data_transforms(image).unsqueeze(0).to(device)

    # Every keypoint slot holds the same point; only slot 0 is flagged valid.
    keypoints = torch.full((1, MAX_QUERIES, 2), -1, dtype=torch.long)
    keypoints[:, :, 0] = x
    keypoints[:, :, 1] = y
    keypoints = keypoints.to(device)

    valid = torch.zeros((1, MAX_QUERIES)).bool()
    valid[:, 0] = True
    valid = valid.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = model(
            image_tensor, valid, keypoints,
            bbox=None, masks=None, movable=None, rigid=None, kinematic=None,
            action=None, affordance=None, affordance_map=None, depth=None,
            axis=None, fov=None, backward=False,
        )

    # visualization
    height, width = IMAGE_SIZE
    rgb = np.array(image.resize((width, height)))
    bbox_preds = out['pred_boxes']
    mask_preds = out['pred_masks']
    mask_preds = interpolate(mask_preds, size=IMAGE_SIZE, mode='bilinear', align_corners=False)
    mask_preds = mask_preds.sigmoid() > 0.5
    movable_preds = out['pred_movable'].argmax(dim=-1)
    rigid_preds = out['pred_rigid'].argmax(dim=-1)
    kinematic_preds = out['pred_kinematic'].argmax(dim=-1)
    action_preds = out['pred_action'].argmax(dim=-1)
    axis_preds = out['pred_axis']
    depth_preds = out['pred_depth']
    affordance_preds = out['pred_affordance']
    affordance_preds = interpolate(affordance_preds, size=IMAGE_SIZE, mode='bilinear', align_corners=False)
    if depth_preds is not None:
        depth_preds = interpolate(depth_preds, size=IMAGE_SIZE, mode='bilinear', align_corners=False)

    i = 0  # batch index; the demo always runs a single image

    # Axis line segments do not depend on the query index, so compute them once.
    axis_center = box_ops.box_xyxy_to_cxcywh(bbox_preds[i]).clone()
    axis_center[:, 2:] = axis_center[:, :2]
    axis_pred = axis_preds[i]
    axis_pred_norm = F.normalize(axis_pred[:, :2])
    axis_pred = torch.cat((axis_pred_norm, axis_pred[:, 2:]), dim=-1)
    src_axis_xyxys = axis_ops.line_angle_to_xyxy(axis_pred, center=axis_center)

    predictions = []
    for j in range(MAX_QUERIES):
        if not valid[i, j]:
            break

        # original image + keypoint
        vis = rgb.copy()
        # OpenCV expects the centre as a tuple of Python ints.
        kp = tuple(int(v) for v in keypoints[i, j].tolist())
        vis = cv2.circle(vis, kp, 24, (255, 255, 255), -1)
        vis = cv2.circle(vis, kp, 20, (31, 73, 125), -1)
        predictions.append(Image.fromarray(vis))

        # physical properties
        movable_pred = movable_preds[i, j].item()
        rigid_pred = rigid_preds[i, j].item()
        kinematic_pred = kinematic_preds[i, j].item()
        action_pred = action_preds[i, j].item()
        output_path = os.path.join(EXPORT_DIR, '{}_kp_{:0>2}_02_phy.png'.format(EXPORT_NAME, j))
        draw_properties(output_path, movable_pred, rigid_pred, kinematic_pred, action_pred)
        predictions.append(_load_image(output_path))

        # box mask axis
        axis_pred = src_axis_xyxys[j]
        if kinematic_imap[kinematic_pred] != 'rotation':
            # No rotation axis to draw for non-rotational objects.
            axis_pred = [-1, -1, -1, -1]
        img_path = os.path.join(EXPORT_DIR, '{}_kp_{:0>2}_03_loc.png'.format(EXPORT_NAME, j))
        draw_localization(
            rgb,
            img_path,
            None,
            mask_preds[i, j].cpu().numpy(),
            axis_pred,
            colors=None,
            alpha=0.6,
        )
        predictions.append(_load_image(img_path))

        # affordance
        affordance_pred = affordance_preds[i, j].sigmoid()
        affordance_pred = affordance_pred.detach().cpu().numpy()
        aff_path = os.path.join(EXPORT_DIR, '{}_kp_{:0>2}_04_affordance.png'.format(EXPORT_NAME, j))
        aff_vis = draw_affordance(rgb, aff_path, affordance_pred)
        predictions.append(aff_vis)

        # depth
        if depth_preds is None:
            # Keep the output list aligned with the five UI components.
            predictions.append(None)
        else:
            depth_path = os.path.join(EXPORT_DIR, '{}_kp_{:0>2}_05_depth.png'.format(EXPORT_NAME, j))
            predictions.append(_render_depth(depth_preds[i], depth_path))

    return predictions


examples = [
    'examples/AR_4ftr44oANPU_34_900_35.jpg',
    'examples/AR_0Mi_dDnmF2Y_6_2610_15.jpg',
    'examples/EK_0037_P28_101_frame_0000031096.jpg',
    'examples/EK_0056_P04_121_frame_0000018401.jpg',
    'examples/taskonomy_bonfield_point_42_view_6_domain_rgb.png',
    'examples/taskonomy_wando_point_156_view_3_domain_rgb.png',
]

title = 'Understanding 3D Object Interaction from a Single Image'
description = """

Project Page | Paper | Code

Gradio demo for Understanding 3D Object Interaction from a Single Image. \n You may click on one of the examples or upload your own image. \n After having the image, you can click on the image to create a single query point. You can then hit Run.\n Our approach can predict 3D object interaction from a single image, including Movable (one hand or two hands), Rigid, Articulation type and axis, Action, Bounding box, Mask, Affordance and Depth. """ # noqa

with gr.Blocks().queue() as demo:
    # NOTE(review): the strings around the title contained raw line breaks
    # (a syntax error); any heading markup that was once here appears to have
    # been lost -- restore it if the title should render as a heading.
    gr.Markdown("\n\n" + title + "\n\n")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload", brush_radius=20)
            run_button = gr.Button(label="Run")

        with gr.Column():
            examples_handler = gr.Examples(
                examples=examples,
                inputs=input_image,
                examples_per_page=10,
            )

    with gr.Row():
        with gr.Column(scale=1):
            query_image = gr.outputs.Image(label="Image + Query", type="pil")

        with gr.Column(scale=1):
            pred_localization = gr.outputs.Image(label="Localization", type="pil")

        with gr.Column(scale=1):
            pred_properties = gr.outputs.Image(label="Properties", type="pil")

    with gr.Row():
        with gr.Column(scale=1):
            pred_affordance = gr.outputs.Image(label="Affordance", type="pil")

        with gr.Column(scale=1):
            pred_depth = gr.outputs.Image(label="Depth", type="pil")

        with gr.Column(scale=1):
            pass

    # Order must match the list returned by run_model.
    output_components = [query_image, pred_properties, pred_localization, pred_affordance, pred_depth]

    run_button.click(fn=run_model, inputs=[input_image], outputs=output_components)


if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')