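"""Gradio demo: box-prompted grasp detection built on a modified SAM.

Draw a box around an object in the uploaded image; the model predicts the
object's mask and the top-k grasp rectangles, and draws both on the image.
"""
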
import copy
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter

sys.path.append("./")
from models import sam_model_registry
from models.grasp_mods import add_inference_method, modify_forward
from models.utils.transforms import ResizeLongestSide
from structures.grasp_box import GraspCoder

# SAM-style preprocessing: resize so the longest image side is 1024 px
img_resize = ResizeLongestSide(1024)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"

# per-channel pixel statistics on the 0-255 scale (ImageNet values in BGR
# channel order), assumed to match the normalization used during training
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]

sam = sam_model_registry[model_type]()
sam.to(device=device)

# graft the grasp-detection forward pass and inference entry point onto SAM
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)

# trained checkpoint; set to "" to skip loading weights
pretrained_model_path = "./epoch_39_step_415131.pth"

if pretrained_model_path != "":
    sd = torch.load(pretrained_model_path, map_location='cpu')
    # strip the "module." prefix that (Distributed)DataParallel adds to keys
    new_sd = {}
    for k, v in sd.items():
        if k.startswith("module."):
            k = k[len("module."):]
        new_sd[k] = v
    sam.load_state_dict(new_sd)

sam.eval()

def predict(prompts, topk):
    """Run box-prompted grasp prediction and return the annotated image."""
    np_image = prompts["image"]   # HWC uint8 RGB image from ImagePrompter
    points = prompts["points"]    # flat point list: (x1, y1, type1, x2, y2, type2)
    orig_size = np_image.shape[:2]
    # HWC -> CHW, then normalize with the per-channel statistics above
    chw_image = np_image.transpose(2, 0, 1)

    image = (chw_image - mean) / std
    image = torch.tensor(image).float().to(device)
    image = image.unsqueeze(0)
    t_image = img_resize.apply_image_torch(image)
    t_orig_size = t_image.shape[-2:]
    # pad to 1024x1024
    pixel_mask = torch.ones(1, t_orig_size[0], t_orig_size[1], device=device)
    t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
    pixel_mask = torch.nn.functional.pad(pixel_mask, (0, 1024 - t_orig_size[1], 0, 1024 - t_orig_size[0]))

    # get box prompt
    valid_boxes = []
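    # ImagePrompter marks a drawn box as a point sextuple whose type codes are
    # 2 (first corner) and 3 (second corner); ignore other point types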
    for point in points:
        x1, y1, type1, x2, y2, type2 = point
        if type1 == 2 and type2 == 3:
            valid_boxes.append([x1, y1, x2, y2])
    if len(valid_boxes) == 0:
        # no box prompt drawn: return the input image unchanged
        return np_image
    t_boxes = np.array(valid_boxes)
    t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
    batched_inputs = [{"image": t_image[0], "boxes": box_torch, "pixel_mask": pixel_mask}]
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)
    # un-normalize the padded model input so predictions can be drawn on it
    recovered_img = batched_inputs[0]['image'].cpu().numpy()
    recovered_img = recovered_img * std + mean
    recovered_img = recovered_img.transpose(1, 2, 0).clip(0, 255).astype(np.uint8)

    for i in range(len(outputs.pred_masks)):
        # get predicted mask
        pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)

        # get predicted grasp
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
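        # decode the grasp encodings into pixel-space corner coordinates
        # (8 values per grasp: four (x, y) corner points)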
        coded_grasp = GraspCoder(t_orig_size[0], t_orig_size[1], None, grasp_annos_reformat=pred_grasp)
        _ = coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)

        # alpha-blend the predicted mask onto the image in green at 50% opacity
        mask_color = np.array([0, 255, 0])[None, None, :]
        recovered_img[pred_mask] = recovered_img[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5

        # draw the grasp rectangle: opposite edge pairs in red (thin) and blue (thick)
        recovered_img = np.ascontiguousarray(recovered_img)
        for grasp in decoded_grasp:
            grasp = grasp.astype(int)
            cv2.line(recovered_img, tuple(grasp[0:2]), tuple(grasp[2:4]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[4:6]), tuple(grasp[6:8]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[2:4]), tuple(grasp[4:6]), (0, 0, 255), 2)
            cv2.line(recovered_img, tuple(grasp[6:8]), tuple(grasp[0:2]), (0, 0, 255), 2)

    # crop off the padding, then resize back to the input resolution
    recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
    recovered_img = cv2.resize(recovered_img, (orig_size[1], orig_size[0]))
    return recovered_img
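
# Example (hypothetical file name and coordinates) of calling predict() without the UI:
#   demo = {
#       "image": cv2.cvtColor(cv2.imread("scene.jpg"), cv2.COLOR_BGR2RGB),
#       "points": [[100, 120, 2, 300, 340, 3]],  # one box: x1, y1, 2, x2, y2, 3
#   }
#   vis = predict(demo, topk=3)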

if __name__ == "__main__":
    app = gr.Blocks(title="GraspAnything")
    with app:
        gr.Markdown("""
        # GraspAnything
        Upload an image, then click and drag to draw a box around each object you want to grasp. Set "Top K Grasps" to the number of grasp candidates to predict per object.
        """)
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
        with gr.Column():
            image_output = gr.Image()
        btn = gr.Button("Generate!")
        btn.click(predict,
                  inputs=[prompter, top_k],
                  outputs=[image_output])
    app.launch()