Spaces:
Runtime error
Runtime error
File size: 2,418 Bytes
6f49966 b5baf02 2b6c2bd b5baf02 6f49966 b5baf02 2b6c2bd b5baf02 6f49966 b5baf02 6f49966 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
import os
from sam_segment import predict_masks_with_sam
from lama_inpaint import inpaint_img_with_lama
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
show_mask, show_points
def mkstemp(suffix, dir=None):
    """Create an empty temporary file and return its location as a ``Path``.

    Thin wrapper around ``tempfile.mkstemp`` that hands back only the path:
    the OS-level file descriptor opened by ``tempfile.mkstemp`` is closed
    before returning, so the caller is free to reopen the file by name.

    Args:
        suffix: Ending for the generated file name (e.g. ``".png"``). It is
            converted to text with ``format`` before being used.
        dir: Directory in which to create the file; ``None`` lets
            ``tempfile`` pick its default location.

    Returns:
        ``pathlib.Path`` of the newly created file. The file is not removed
        automatically; the caller owns its cleanup.
    """
    handle, filename = tempfile.mkstemp(suffix=format(suffix), dir=dir)
    # Only the path is needed; release the descriptor right away.
    os.close(handle)
    return Path(filename)
def get_masked_img(img, point_coords, dilate_kernel_size=15):
    """Segment the object at a single point with SAM and render every mask.

    The point is sent to ``predict_masks_with_sam`` together with the image;
    each returned mask is optionally dilated and then drawn on top of the
    image in its own matplotlib figure.

    Args:
        img: Image array. ``img.shape[:2]`` is read as ``(height, width)``
            and the array is passed unchanged to SAM and to ``imshow``.
        point_coords: The prompt point. It is wrapped in a one-element list
            and paired with the single label ``1`` (foreground in SAM's
            convention -- confirm against ``predict_masks_with_sam``).
            Presumably ``[x, y]`` in pixel coordinates -- confirm.
        dilate_kernel_size: Kernel size forwarded to ``dilate_mask``.
            Pass ``None`` to skip dilation. Defaults to 15, the value that
            used to be hard-coded.

    Returns:
        A list of matplotlib ``Figure`` objects, one per mask, in the order
        the masks were returned by ``predict_masks_with_sam``.
    """
    point_labels = [1]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    masks, _, _ = predict_masks_with_sam(
        img,
        [point_coords],
        point_labels,
        model_type="vit_h",
        ckpt_p="pretrained_models/sam_vit_h_4b8939.pth",
        device=device,
    )
    # Scale the masks to 0/255 uint8 before dilation and display.
    masks = masks.astype(np.uint8) * 255

    # Dilate mask to avoid unmasked edge effect.
    if dilate_kernel_size is not None:
        masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]

    # Figure geometry is the same for every mask, so compute it once.
    # NOTE(review): 0.77 is an undocumented scale factor kept as-is.
    dpi = plt.rcParams["figure.dpi"]
    height, width = img.shape[:2]
    figsize = (width / dpi / 0.77, height / dpi / 0.77)

    figs = []
    for mask in masks:
        # Use explicit figure/axes objects rather than pyplot's implicit
        # "current figure" so the handler does not depend on global state.
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        ax.imshow(img)
        ax.axis("off")
        show_mask(ax, mask, random_color=False)
        figs.append(fig)
        # Deregister from pyplot so figures do not accumulate across calls;
        # the Figure object itself stays valid for the caller.
        # Previously each figure was also saved to a temp PNG whose path was
        # never returned or deleted, leaking files; that save is dropped.
        plt.close(fig)
    return figs
# UI definition: an input image, three plots for the predicted masks, two
# number boxes holding the clicked point, and a button that runs SAM.
with gr.Blocks() as demo:
    with gr.Row():
        img = gr.Image(label="Image")
    # gr.Row has no documented ``label`` parameter, so the caption is
    # attached to the plots themselves.
    with gr.Row():
        img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask")
        img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask")
        img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask")
    with gr.Row():
        # Despite the names, these hold the clicked point: ``w`` receives
        # evt.index[0] and ``h`` receives evt.index[1].
        w = gr.Number()
        h = gr.Number()
    # NOTE(review): the original indentation was lost, so whether the button
    # sat inside the row above is unknown; behaviour does not depend on it.
    predict_mask = gr.Button("Predict Mask Using SAM")

    def get_select_coords(evt: gr.SelectData):
        """Return the two components of the selected position on the image."""
        return evt.index[0], evt.index[1]

    def _predict_from_point(image, x, y):
        """Adapt flat Gradio inputs to ``get_masked_img(img, point_coords)``.

        Event listeners accept a flat list of components, so the two
        coordinates arrive as separate arguments and are regrouped here.
        """
        if image is None or x is None or y is None:
            raise gr.Error("Upload an image and click a point on it first.")
        return get_masked_img(image, [x, y])

    img.select(get_select_coords, [], [w, h])
    # Inputs must be a flat list of components; the previous nested
    # ``[img, [w, h]]`` is not a valid inputs specification.
    # NOTE(review): three outputs are wired, so this assumes get_masked_img
    # returns exactly three figures -- confirm against the SAM helper.
    predict_mask.click(
        _predict_from_point,
        [img, w, h],
        [img_with_mask_0, img_with_mask_1, img_with_mask_2],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()