import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# mm libs
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer
from mmengine import Config, print_log
from mmengine.structures import InstanceData
from mmdet.datasets.coco_panoptic import CocoPanopticDataset
from PIL import ImageDraw
import spaces
# Longest side (in pixels) that input images are resized to before inference.
IMG_SIZE = 1024
# Page heading rendered with gr.Markdown. Kept on a single line: a line break
# inside a plain "..." string literal is a SyntaxError.
TITLE = "OMG-Seg: Is One Model Good Enough For All Segmentation?"
CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# Build the model from its mmengine config and move it to the GPU if present.
model_cfg = Config.fromfile('app/configs/m2_convl.py')
model = MODELS.build(model_cfg.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
# NOTE(review): presumably loads the pretrained checkpoint named by the
# config's init_cfg — confirm against app/configs/m2_convl.py.
model.init_weights()

# Per-channel normalisation constants shaped (3, 1, 1) so they broadcast over
# a (N, 3, H, W) tensor; numerically the ImageNet mean/std scaled to 0-255
# (e.g. 0.485 * 255 = 123.675).
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]

# mmdet visualizer used to draw panoptic / instance results.
visualizer = DetLocalVisualizer()

# Clickable example images for the UI (one input value per row).
examples = [
    ["assets/000000000139.jpg"],
    ["assets/000000000285.jpg"],
    ["assets/000000000632.jpg"],
    ["assets/000000000724.jpg"],
]
class IMGState:
    """Per-session UI state for the point-prompt tab.

    Holds the currently loaded (resized) image and the point prompts the user
    has clicked on it.
    """

    def __init__(self):
        self.img = None  # image array stored by set_img, or None when unset
        self.selected_points = []  # list of [x, y] click coordinates
        self.available_to_set = True  # True while no image is stored

    def set_img(self, img):
        """Store ``img`` and mark the state as holding an image."""
        self.img = img
        self.available_to_set = False

    def clear(self):
        """Full reset: drop the image and all point prompts."""
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def clean(self):
        """Drop the point prompts but keep the current image."""
        self.selected_points = []

    @property
    def available(self):
        """Whether no image is currently stored (a new one may be set)."""
        return self.available_to_set

    @classmethod
    def cls_clean(cls, state):
        """Callback for "Clean Prompts": remove prompts and redraw the image.

        Returns:
            ``(input_image, segment_output)``. The input image is the stored
            image without point markers, and the segmentation output is
            cleared. If no image has been loaded yet, both are ``None``
            instead of crashing on ``Image.fromarray(None)``.
        """
        state.clean()
        if state.img is None:
            return None, None
        return Image.fromarray(state.img), None

    @classmethod
    def cls_clear(cls, state):
        """Callback for clearing the input image: reset all state.

        Returns:
            ``(None, None)`` to blank both the input and output images.
        """
        state.clear()
        return None, None
def store_img(img, img_state):
    """Resize an incoming image so its longer side is IMG_SIZE and store it.

    Args:
        img: PIL image from the Gradio ``Image`` component (``type="pil"``)
            or from a clicked example.
        img_state: the session's IMGState. It receives the resized image as a
            numpy array, and any point prompts left over from a previous
            image are discarded.

    Returns:
        ``(resized_pil_image, None)``. The resized image is shown in the input
        component and the segmentation output is cleared.
    """
    # segment_point normalises with 3-channel mean/std and permutes HWC -> CHW,
    # so make sure the stored array is always 3-channel RGB.
    img = img.convert("RGB")
    w, h = img.size
    long_side = max(w, h)
    # Exact integer arithmetic: the float form int(w * (IMG_SIZE / long_side))
    # can land one pixel short (e.g. a 49 px side gives 1023). Clamp to >= 1 px
    # so very thin images do not produce a zero-sized resize.
    new_w = max(1, w * IMG_SIZE // long_side)
    new_h = max(1, h * IMG_SIZE // long_side)
    img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
    img_state.set_img(np.array(img))
    # Points clicked on a previous image are meaningless for this one.
    img_state.clean()
    print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')
    return img, None
def get_points_with_draw(image, img_state, evt: gr.SelectData):
    """Gradio ``select`` handler: record the clicked point and mark it.

    Only one point prompt is kept. Each click replaces the previous one.

    Args:
        image: current value of the input component. It is returned unchanged
            when no image has been stored yet.
        img_state: the session's IMGState.
        evt: Gradio selection event; ``evt.index`` holds the clicked (x, y).

    Returns:
        A PIL image: the stored image with a filled circle at the click.
    """
    x, y = evt.index[0], evt.index[1]
    print_log(f"Point: {x}_{y}", logger='current')
    if img_state.img is None:
        # Nothing stored via store_img yet, so there is nothing to prompt on.
        return image
    point_radius, point_color = 10, (97, 217, 54)
    # Single-point prompting: replace rather than accumulate.
    img_state.selected_points = [[x, y]]
    # Draw on a fresh copy of the stored image so earlier markers disappear.
    canvas = Image.fromarray(img_state.img)
    ImageDraw.Draw(canvas).ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return canvas
def _preprocess(img):
    """Normalise an HxWx3 image array and zero-pad it to multiples of 32.

    Returns:
        ``(tensor, (h, w), (pad_h, pad_w))``. ``tensor`` is a float32 tensor
        of shape (1, 3, pad_h, pad_w) on ``device``, and (h, w) is the
        original size.
    """
    h, w = img.shape[:2]
    tensor = torch.tensor(img, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
    tensor = (tensor - mean) / std
    # Round each side up to the next multiple of 32 (unchanged if aligned).
    pad_w = (w + 31) // 32 * 32
    pad_h = (h + 31) // 32 * 32
    # Pad right/bottom only; 0 after normalisation corresponds to the mean colour.
    tensor = F.pad(tensor, (0, pad_w - w, 0, pad_h - h), 'constant', 0)
    return tensor, (h, w), (pad_h, pad_w)


def _build_data_sample(points, h, w, pad_h, pad_w):
    """Build the DetDataSample passed to ``model.predict``.

    With point prompts the sample is tagged 'sam' and carries a +/-3 px box
    around each point. Without prompts it is tagged 'coco' (prompt-free
    panoptic/instance prediction).
    """
    sample = DetDataSample()
    if points:
        input_points = torch.tensor(points, dtype=torch.float32, device=device)
        # (N, 2) points -> (N, 4) boxes: (x - 3, y - 3, x + 3, y + 3).
        boxes = torch.cat([input_points - 3, input_points + 3], 1)
        gt_instances = InstanceData(point_coords=boxes)
        # NOTE(review): all-zero `bp` labels; their meaning is defined by the
        # model's prompt handling (not visible here) — confirm.
        gt_instances.bp = torch.zeros(len(gt_instances), dtype=torch.long, device=device)
        sample.gt_instances = gt_instances
        sample.data_tag = 'sam'
    else:
        sample.data_tag = 'coco'
    sample.set_metainfo(dict(batch_input_shape=(pad_h, pad_w)))
    sample.set_metainfo(dict(img_shape=(h, w)))
    return sample


def _overlay_mask(img, mask_logits, h, w):
    """Blend the first predicted mask (logits > 0) onto ``img`` in green.

    Pixels outside the mask are also scaled by 0.7 because the colour layer is
    zero there, so the whole image is dimmed slightly.
    """
    # Crop away the padding; threshold raw logits ("no sigmoid").
    mask = mask_logits[0, :h, :w] > 0.
    if isinstance(mask, torch.Tensor):
        # Ensure a host numpy bool array for numpy boolean indexing (a CUDA
        # tensor cannot be converted implicitly).
        mask = mask.cpu().numpy()
    color = np.zeros(tuple(mask.shape) + (3,), dtype=np.uint8)
    color[mask] = np.array([97, 217, 54])
    return Image.fromarray((img * 0.7 + color * 0.3).astype(np.uint8))


def _draw_auto(img, results, mode):
    """Draw prompt-free results for the selected ``mode``.

    Returns ``img`` unchanged when ``mode`` matches neither option.
    """
    classes = CocoPanopticDataset.METAINFO['classes']
    palette = CocoPanopticDataset.METAINFO['palette']
    if mode == 'Panoptic Segmentation':
        return visualizer._draw_panoptic_seg(
            img,
            results['pan_results'].to('cpu').numpy(),
            classes=classes,
            palette=palette
        )
    if mode == 'Instance Segmentation':
        instances = results['ins_results']
        # Keep only reasonably confident instances.
        instances = instances[instances.scores > .2]
        return visualizer._draw_instances(
            img,
            instances.to('cpu').numpy(),
            classes=classes,
            palette=palette
        )
    return img


@spaces.GPU()
def segment_point(image, img_state, mode):
    """Run OMG-Seg on the stored image and render the result.

    With a point prompt, the single predicted mask is overlaid on the image.
    Otherwise the panoptic or instance result selected by ``mode`` is drawn.

    Args:
        image: current value of the input component; passed straight back.
        img_state: the session's IMGState (stored image + point prompts).
        mode: "Panoptic Segmentation" or "Instance Segmentation"; ignored
            when a point prompt is present.

    Returns:
        ``(image, output)`` for the input and segmentation components.
        ``output`` is ``None`` if no image has been loaded yet.
    """
    output_img = img_state.img
    if output_img is None:
        # "Segment" clicked before an image was loaded.
        return image, None
    img_tensor, (h, w), (pad_h, pad_w) = _preprocess(output_img)
    points = img_state.selected_points
    is_prompt = len(points) > 0
    data_sample = _build_data_sample(points, h, w, pad_h, pad_w)
    with torch.no_grad():
        results = model.predict(img_tensor, [data_sample], rescale=False)
    if is_prompt:
        return image, _overlay_mask(output_img, results[0], h, w)
    return image, _draw_auto(output_img, results[0], mode)
def register_title():
    """Render the page heading (``TITLE``) as Markdown in a full-width row."""
    with gr.Row(), gr.Column(scale=1):
        gr.Markdown(TITLE)
def register_point_mode():
    """Build the "Point mode" tab and wire up its event handlers.

    Layout: input and output images side by side, a segmentation-mode
    selector, "Segment" / "Clean Prompts" buttons, and clickable examples.
    Clicking the input image places a point prompt, and "Segment" runs the
    model.
    """
    with gr.Tab("Point mode"):
        # Per-session state holding the stored image and the point prompts.
        state = gr.State(IMGState())

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column(scale=1):
                output_image = gr.Image(label="Segment", interactive=False, type="pil")

        with gr.Row(), gr.Column():
            mode_radio = gr.Radio(
                ["Panoptic Segmentation", "Instance Segmentation"],
                label="Mode",
                value="Panoptic Segmentation",
                info="Please select the segmentation mode. (Ignored if provided with prompt.)",
            )

        with gr.Row():
            with gr.Column():
                segment_button = gr.Button("Segment", variant="primary")
            with gr.Column():
                clean_button = gr.Button("Clean Prompts", variant="secondary")

        with gr.Row(), gr.Column():
            gr.Markdown("Try some of the examples below ⬇️")
            gr.Examples(
                examples=examples,
                inputs=[input_image, state],
                outputs=[input_image, output_image],
                examples_per_page=4,
                fn=store_img,
                run_on_click=True,
            )

        # Loading an image resizes it and stores it in the session state.
        input_image.upload(
            fn=store_img,
            inputs=[input_image, state],
            outputs=[input_image, output_image],
        )
        # Clicking on the image records a point prompt and redraws the input.
        input_image.select(
            fn=get_points_with_draw,
            inputs=[input_image, state],
            outputs=input_image,
        )
        segment_button.click(
            fn=segment_point,
            inputs=[input_image, state, mode_radio],
            outputs=[input_image, output_image],
        )
        clean_button.click(
            fn=IMGState.cls_clean,
            inputs=state,
            outputs=[input_image, output_image],
        )
        input_image.clear(
            fn=IMGState.cls_clear,
            inputs=state,
            outputs=[input_image, output_image],
        )
def build_demo():
    """Assemble the Gradio Blocks app: the title row plus the point-mode tab.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    # Browser-tab title matches the OMG-Seg page heading ("RAP-SAM" was a
    # leftover from a different demo).
    with gr.Blocks(css=CSS, title="OMG-Seg") as _demo:
        register_title()
        register_point_mode()
    return _demo
if __name__ == '__main__':
    demo = build_demo()
    # Enable request queuing; api_open=False keeps the backend REST routes
    # closed so requests cannot bypass the queue (per Gradio's queue() docs).
    demo.queue(api_open=False)
    # Listen on all network interfaces.
    demo.launch(server_name='0.0.0.0')