"""Gradio demo for OMG-Seg.

The app lets a user load an image, optionally click a single point prompt on
it, and run the model either in prompt mode (one mask for the clicked point)
or in panoptic / instance segmentation mode over the whole image.
"""
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# mm libs
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer
from mmengine import Config, print_log
from mmengine.structures import InstanceData

from mmdet.datasets.coco_panoptic import CocoPanopticDataset

from PIL import ImageDraw

import spaces

# Longest image side (in pixels) after resizing in ``store_img``.
IMG_SIZE = 1024

# NOTE(review): the literal contains line breaks around the text in the stored
# file; they are kept here as "\n". Any surrounding markup may have been lost
# before this file was captured -- confirm against the deployed app.
TITLE = "\nOMG-Seg: Is One Model Good Enough For All Segmentation?\n"

CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# RGB colour used both for the click marker and for the prompt-mode mask tint.
PROMPT_COLOR = (97, 217, 54)

model_cfg = Config.fromfile('app/configs/m2_convl.py')

model = MODELS.build(model_cfg.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
# presumably this is where the checkpoint named in the config is loaded --
# TODO confirm against the model's init_cfg.
model.init_weights()

# Per-channel normalisation constants shaped (3, 1, 1) so they broadcast over
# a (1, 3, H, W) image tensor.
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]

visualizer = DetLocalVisualizer()

examples = [
    ["assets/000000000139.jpg"],
    ["assets/000000000285.jpg"],
    ["assets/000000000632.jpg"],
    ["assets/000000000724.jpg"],
]


class IMGState:
    """Per-session holder for the working image and the clicked prompt points."""

    def __init__(self):
        self.img = None  # resized image as a numpy array, or None when unset
        self.selected_points = []  # list of [x, y] click coordinates
        self.available_to_set = True

    def set_img(self, img):
        """Store ``img`` as the working image and mark the slot as taken."""
        self.img = img
        self.available_to_set = False

    def clear(self):
        """Forget the image and every selected point."""
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def clean(self):
        """Drop the selected points but keep the image."""
        self.selected_points = []

    @property
    def available(self):
        """Whether no image is currently stored."""
        return self.available_to_set

    @classmethod
    def cls_clean(cls, state):
        """Event handler: remove the prompts and redraw the unmarked image.

        Returns:
            ``(input_image, segment_image)``; the input is the stored image
            without any marker, the segment output is cleared. When no image
            has been stored yet both are ``None``.
        """
        state.clean()
        if state.img is None:
            # Nothing to redraw: "Clean Prompts" was clicked before an image
            # was loaded. ``Image.fromarray(None)`` would raise here.
            return None, None
        return Image.fromarray(state.img), None

    @classmethod
    def cls_clear(cls, state):
        """Event handler: reset the state and blank both image widgets."""
        state.clear()
        return None, None


def _round_up(value, multiple=32):
    """Return the smallest multiple of ``multiple`` that is >= ``value``."""
    remainder = value % multiple
    return value if remainder == 0 else value - remainder + multiple


def store_img(img, img_state):
    """Resize ``img`` so its longest side is ``IMG_SIZE`` and store it.

    Args:
        img: PIL image coming from the input widget.
        img_state: ``IMGState`` that receives the resized image as a numpy
            array.

    Returns:
        ``(resized_image, None)``: the resized PIL image for the input widget
        and ``None`` to clear the segment widget.
    """
    w, h = img.size
    scale = IMG_SIZE / max(w, h)
    # Clamp to 1 px: for extreme aspect ratios the short side would otherwise
    # truncate to 0 and make ``resize`` fail.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
    img_state.set_img(np.array(img))
    print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')

    return img, None


def get_points_with_draw(image, img_state, evt: gr.SelectData):
    """Record the clicked point as the (single) prompt and draw a marker.

    Only the most recent click is kept, so each new click replaces the
    previous prompt.

    Args:
        image: current content of the input widget.
        img_state: ``IMGState`` holding the stored image and the prompt.
        evt: select event; ``evt.index`` carries the clicked ``(x, y)``.

    Returns:
        A fresh copy of the stored image with a filled circle at the click,
        or ``image`` unchanged when no image has been stored.
    """
    x, y = evt.index[0], evt.index[1]
    print_log(f"Point: {x}_{y}", logger='current')

    if img_state.img is None:
        # No stored image to draw on; leave the widget as it is.
        return image

    point_radius = 10
    # Keep exactly one prompt point: the latest click.
    img_state.selected_points = [[x, y]]

    # Always redraw from the clean stored image so old markers disappear.
    marked = Image.fromarray(img_state.img)
    draw = ImageDraw.Draw(marked)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=PROMPT_COLOR,
    )
    return marked


@spaces.GPU()
def segment_point(image, img_state, mode):
    """Run the model on the stored image and render the result.

    With a selected point the model is queried in prompt mode (``data_tag``
    ``'sam'``) and the returned mask is blended onto the image. Without a
    point it runs in ``'coco'`` mode and ``mode`` picks between the panoptic
    and the instance visualisation.

    Args:
        image: current content of the input widget; passed through untouched.
        img_state: ``IMGState`` with the stored image and optional prompt.
        mode: ``'Panoptic Segmentation'`` or ``'Instance Segmentation'``;
            ignored when a prompt point is present.

    Returns:
        ``(image, output_img)`` for the input and segment widgets. The output
        is ``None`` when no image has been stored.
    """
    output_img = img_state.img
    if output_img is None:
        # "Segment" was clicked before any image was loaded.
        print_log("No image available for segmentation.", logger='current')
        return image, None

    h, w = output_img.shape[:2]

    img_tensor = torch.tensor(output_img, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
    img_tensor = (img_tensor - mean) / std

    # Pad right/bottom with zeros so both sides are multiples of 32.
    im_w = _round_up(w, 32)
    im_h = _round_up(h, 32)
    img_tensor = F.pad(img_tensor, (0, im_w - w, 0, im_h - h), 'constant', 0)

    is_prompt = len(img_state.selected_points) > 0

    data_sample = DetDataSample()
    if is_prompt:
        input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
        # Each point becomes a small box (x-3, y-3, x+3, y+3) around the click.
        selected_point = torch.cat([input_points - 3, input_points + 3], 1)
        gt_instances = InstanceData(
            point_coords=selected_point,
        )
        pb_labels = torch.zeros(len(gt_instances), dtype=torch.long, device=device)
        gt_instances.bp = pb_labels
        data_sample.gt_instances = gt_instances
        data_sample.data_tag = 'sam'
    else:
        data_sample.data_tag = 'coco'
    data_sample.set_metainfo(dict(batch_input_shape=(im_h, im_w)))
    data_sample.set_metainfo(dict(img_shape=(h, w)))
    batch_data_samples = [data_sample]

    with torch.no_grad():
        results = model.predict(img_tensor, batch_data_samples, rescale=False)

    masks = results[0]
    if is_prompt:
        # Crop away the padding and threshold the first mask.
        masks = masks[0, :h, :w]
        masks = masks > 0.  # no sigmoid
        rgb_shape = tuple(list(masks.shape) + [3])
        color = np.zeros(rgb_shape, dtype=np.uint8)
        # NOTE(review): this indexes a numpy array with ``masks``; assumes the
        # model returns something numpy can consume here -- confirm.
        color[masks] = np.array(PROMPT_COLOR)
        output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
        output_img = Image.fromarray(output_img)
    else:
        if mode == 'Panoptic Segmentation':
            output_img = visualizer._draw_panoptic_seg(
                output_img,
                masks['pan_results'].to('cpu').numpy(),
                classes=CocoPanopticDataset.METAINFO['classes'],
                palette=CocoPanopticDataset.METAINFO['palette']
            )
        elif mode == 'Instance Segmentation':
            # Keep only detections scoring above 0.2.
            masks['ins_results'] = masks['ins_results'][masks['ins_results'].scores > .2]
            output_img = visualizer._draw_instances(
                output_img,
                masks['ins_results'].to('cpu').numpy(),
                classes=CocoPanopticDataset.METAINFO['classes'],
                palette=CocoPanopticDataset.METAINFO['palette']
            )
    return image, output_img


def register_title():
    """Add the title row to the enclosing ``gr.Blocks`` context."""
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(TITLE)


def register_point_mode():
    """Build the "Point mode" tab and wire its widgets to the handlers."""
    with gr.Tab("Point mode"):
        img_state = gr.State(IMGState())
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                img_p = gr.Image(label="Input Image", type="pil")

            with gr.Column(scale=1):
                segm_p = gr.Image(label="Segment", interactive=False, type="pil")

        with gr.Row():
            with gr.Column():
                # NOTE(review): the line break inside ``info`` is kept as it
                # appears in the stored file; markup may have been lost.
                mode = gr.Radio(
                    ["Panoptic Segmentation", "Instance Segmentation"],
                    label="Mode",
                    value="Panoptic Segmentation",
                    info="Please select the segmentation mode. \n(Ignored if provided with prompt.)"
                )
            with gr.Row():
                with gr.Column():
                    segment_btn = gr.Button("Segment", variant="primary")
                with gr.Column():
                    clean_btn = gr.Button("Clean Prompts", variant="secondary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[img_p, img_state],
                    outputs=[img_p, segm_p],
                    examples_per_page=4,
                    fn=store_img,
                    run_on_click=True
                )

    # A new upload is resized and stored; the segment output is cleared.
    img_p.upload(
        store_img,
        [img_p, img_state],
        [img_p, segm_p]
    )

    # Clicking the input image sets the prompt point.
    img_p.select(
        get_points_with_draw,
        [img_p, img_state],
        img_p
    )

    segment_btn.click(
        segment_point,
        [img_p, img_state, mode],
        [img_p, segm_p]
    )

    clean_btn.click(
        IMGState.cls_clean,
        img_state,
        [img_p, segm_p]
    )

    img_p.clear(
        IMGState.cls_clear,
        img_state,
        [img_p, segm_p]
    )


def build_demo():
    """Assemble and return the full ``gr.Blocks`` application."""
    # NOTE(review): the page title says "RAP-SAM" while the app banner says
    # OMG-Seg; left unchanged, confirm which one is intended.
    with gr.Blocks(css=CSS, title="RAP-SAM") as _demo:
        register_title()
        register_point_mode()
    return _demo


if __name__ == '__main__':
    demo = build_demo()

    demo.queue(api_open=False)
    demo.launch(server_name='0.0.0.0')