Spaces:
Sleeping
Sleeping
import copy | |
import numpy as np | |
import torch | |
import sys | |
sys.path.append("./") | |
from models import sam_model_registry | |
from models.grasp_mods import modify_forward | |
from models.utils.transforms import ResizeLongestSide | |
from gradio_image_prompter import ImagePrompter | |
from structures.grasp_box import GraspCoder | |
img_resize = ResizeLongestSide(1024) | |
import cv2 | |
import gradio as gr | |
from models.grasp_mods import add_inference_method | |
# --- Model / preprocessing setup (runs once at import time) -----------------

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"

# Per-channel normalisation constants, reshaped to (3, 1, 1) so they broadcast
# over a channel-first (C, H, W) image.
# NOTE(review): these values are listed in the reverse channel order of the
# usual RGB constants (123.675, 116.28, 103.53) -- confirm that this matches
# the channel order the checkpoint was trained with.
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]

# Build the model and attach the project's replacement forward / inference
# entry points to the instance.
sam = sam_model_registry[model_type]()
sam.to(device=device)
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)

# Checkpoint to restore; leave empty to keep the freshly constructed weights.
pretrained_model_path = "E:/epoch_9_step_535390.pth"
if pretrained_model_path:
    # map_location keeps a checkpoint that was saved from GPU tensors loadable
    # on a CPU-only host (``device`` already falls back to "cpu" above).
    # NOTE(review): torch.load unpickles the file -- only point this at
    # checkpoints from a trusted source.
    sd = torch.load(pretrained_model_path, map_location=device)

    # Strip a leading "module." from parameter names (presumably left behind by
    # a DataParallel-style wrapper at training time) so the keys match ``sam``.
    prefix = "module."
    new_sd = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }
    sam.load_state_dict(new_sd)

# Inference only: switch the model to evaluation mode.
sam.eval()
def _box_prompts(points):
    """Return ``[x1, y1, x2, y2]`` for every prompt entry that encodes a box.

    Each entry is unpacked as ``(x1, y1, type1, x2, y2, type2)``. Only entries
    whose type pair is ``(2, 3)`` are kept; every other entry is ignored.
    """
    return [
        [x1, y1, x2, y2]
        for x1, y1, type1, x2, y2, type2 in points
        if type1 == 2 and type2 == 3
    ]


def _draw_grasps(canvas, grasps):
    """Draw every decoded grasp onto ``canvas`` in place.

    A grasp is read as eight numbers forming four (x, y) corners. Corners 0-1
    and 2-3 are joined with thin (255, 0, 0) lines, corners 1-2 and 3-0 with
    thicker (0, 0, 255) lines.
    """
    for grasp in grasps:
        # Plain Python ints: OpenCV point arguments must be integer tuples.
        coords = np.asarray(grasp).astype(int).tolist()
        p0, p1, p2, p3 = (tuple(coords[j:j + 2]) for j in range(0, 8, 2))
        cv2.line(canvas, p0, p1, (255, 0, 0), 1)
        cv2.line(canvas, p2, p3, (255, 0, 0), 1)
        cv2.line(canvas, p1, p2, (0, 0, 255), 2)
        cv2.line(canvas, p3, p0, (0, 0, 255), 2)


def predict(input, topk):
    """Predict masks and top-k grasps for the boxes drawn on an image.

    Args:
        input: Mapping with an ``"image"`` entry (indexed as an H x W x C
            array) and a ``"points"`` entry (iterable of 6-value prompt
            records, see ``_box_prompts``).
            NOTE(review): presumably the ImagePrompter payload -- confirm the
            exact schema against the component.
        topk: Number of highest-scoring grasps to draw per box. Coerced to
            ``int`` because UI sliders may deliver a float.

    Returns:
        The input image unchanged when no box prompt is present; otherwise a
        ``uint8`` image with the original height and width, with the predicted
        masks blended in and the grasps drawn on top.
    """
    original = input["image"]

    # Nothing to predict without a box: hand back the untouched input image
    # (not a transposed or normalised intermediate).
    boxes = _box_prompts(input["points"])
    if not boxes:
        return original

    orig_h, orig_w = original.shape[:2]
    k = int(topk)

    # Normalise in channel-first layout, then add a batch dimension.
    chw = original.transpose(2, 0, 1)
    image = torch.tensor((chw - mean) / std).float().to(device).unsqueeze(0)

    # Resize with the shared transform, then zero-pad on the right/bottom to
    # the fixed 1024 x 1024 model input.
    t_image = img_resize.apply_image_torch(image)
    resized_h, resized_w = t_image.shape[-2:]
    t_image = torch.nn.functional.pad(
        t_image, (0, 1024 - resized_w, 0, 1024 - resized_h)
    )

    # Map the boxes into the resized coordinate frame.
    t_boxes = img_resize.apply_boxes(np.array(boxes), (orig_h, orig_w))
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)

    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)

    # Undo the normalisation to get a drawable image. Round and clip BEFORE
    # the uint8 cast so out-of-range values saturate instead of wrapping.
    recovered = batched_inputs[0]["image"].cpu().numpy() * std + mean
    recovered = np.clip(np.rint(recovered), 0, 255).astype(np.uint8)
    recovered = np.ascontiguousarray(recovered.transpose(1, 2, 0))

    mask_color = np.array([0, 255, 0])[None, None, :]
    for i, mask_logits in enumerate(outputs.pred_masks):
        # Binary mask, expanded to three channels so it can index the image.
        pred_mask = mask_logits.detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)

        # Keep the k grasps with the highest score in logit column 0,
        # best first.
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[::-1][:k]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
        _ = coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)

        # Blend the mask colour 50/50 into the masked pixels.
        recovered[pred_mask] = (
            recovered[pred_mask] * 0.5
            + (pred_mask * mask_color)[pred_mask] * 0.5
        )

        _draw_grasps(recovered, decoded_grasp)

    # Drop the padding, then scale back to the input resolution.
    # cv2.resize takes dsize as (width, height).
    recovered = recovered[:resized_h, :resized_w]
    return cv2.resize(recovered, (orig_w, orig_h))
if __name__ == "__main__":
    # Assemble the demo page: prompt image + top-k slider as inputs, one
    # output image, and a button that runs ``predict`` on click.
    with gr.Blocks(title="GraspAnything") as demo:
        gr.Markdown("""
# GraspAnything <br>
Upload an image and draw a box around the object you want to grasp. Set top k to be the number of grasps you want to predict for each object.
""")

        # Input column: where the user uploads the image and draws boxes.
        with gr.Column():
            box_prompter = ImagePrompter(show_label=False)
            grasp_count = gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=3,
                label="Top K Grasps",
            )

        # Output column: the annotated result image.
        with gr.Column():
            result_view = gr.Image()

        run_button = gr.Button("Generate!")
        run_button.click(
            predict,
            inputs=[box_prompter, grasp_count],
            outputs=[result_view],
        )

    demo.launch()