|
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw

from mmdet.datasets.coco_panoptic import CocoPanopticDataset
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer
from mmengine import Config, print_log
from mmengine.structures import InstanceData
|
IMG_SIZE = 1024

TITLE = "<center><strong><font size='8'>🚀RAP-SAM: Towards Real-Time All-Purpose Segment Anything</font></strong></center>"
CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
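# Build the RAP-SAM model once at start-up and initialise its weights as
# declared in the config (init_cfg); inference runs on GPU when available.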
model_cfg = Config.fromfile('app/configs/rap_sam_r50_12e_adaptor.py')

model = MODELS.build(model_cfg.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
model.init_weights()
|
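# ImageNet pixel mean/std in the 0-255 range, shaped (3, 1, 1) so they
# broadcast over CHW image tensors.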
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]

visualizer = DetLocalVisualizer()
|
examples = [
    ["assets/000000000139.jpg"],
    ["assets/000000000285.jpg"],
    ["assets/000000000632.jpg"],
    ["assets/000000000724.jpg"],
]
|
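# Per-session UI state: the resized input image and the latest clicked prompt point.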
class IMGState:
    def __init__(self):
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def set_img(self, img):
        self.img = img
        self.available_to_set = False

    def clear(self):
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def clean(self):
        self.selected_points = []

    @property
    def available(self):
        return self.available_to_set

    @classmethod
    def cls_clean(cls, state):
        state.clean()
        if state.img is None:
            # "Clean Prompts" was clicked before any image was uploaded.
            return None, None
        return Image.fromarray(state.img), None

    @classmethod
    def cls_clear(cls, state):
        state.clear()
        return None, None
|
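# Resize the uploaded image so its longer side equals IMG_SIZE and cache it
# as a numpy array for later inference.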
def store_img(img, img_state):
    w, h = img.size
    scale = IMG_SIZE / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
    img_numpy = np.array(img)
    img_state.set_img(img_numpy)
    print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')
    return img, None
|
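# Record a click as the prompt point (only the most recent click is kept)
# and draw a green marker on a fresh copy of the cached image.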
def get_points_with_draw(image, img_state, evt: gr.SelectData):
    x, y = evt.index[0], evt.index[1]
    print_log(f"Point: {x}_{y}", logger='current')
    point_radius, point_color = 10, (97, 217, 54)

    # Only the most recent click is kept as the active prompt.
    img_state.selected_points = [[x, y]]
    image = Image.fromarray(img_state.img)

    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return image
|
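# Run inference: with a click prompt, predict a single binary mask and blend
# it into the image; without one, fall back to COCO panoptic segmentation.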
def segment_point(image, img_state):
    if img_state.img is None:
        # Nothing has been uploaded yet.
        return image, None

    output_img = img_state.img
    h, w = output_img.shape[:2]

    img_tensor = torch.tensor(output_img, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
    img_tensor = (img_tensor - mean) / std

    # Pad the right/bottom edges so both spatial dims are multiples of 32.
    im_w = w if w % 32 == 0 else w // 32 * 32 + 32
    im_h = h if h % 32 == 0 else h // 32 * 32 + 32
    img_tensor = F.pad(img_tensor, (0, im_w - w, 0, im_h - h), 'constant', 0)

    if len(img_state.selected_points) > 0:
        input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
        batch_data_samples = [DetDataSample()]
        # Expand each clicked point into a small box prompt (x1, y1, x2, y2).
        selected_point = torch.cat([input_points - 3, input_points + 3], 1)
        gt_instances = InstanceData(
            point_coords=selected_point,
        )
        pb_labels = torch.ones(len(gt_instances), dtype=torch.long, device=device)
        gt_instances.pb_labels = pb_labels
        batch_data_samples[0].gt_instances_collected = gt_instances
        batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
        batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
        is_prompt = True
    else:
        batch_data_samples = [DetDataSample()]
        batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
        batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
        is_prompt = False

    with torch.no_grad():
        masks, cls_pred = model.predict_with_point(img_tensor, batch_data_samples)

    assert len(masks) == 1
    masks = masks[0]
    if is_prompt:
        # Crop away the padding, threshold the logits, and tint the region green.
        masks = (masks[0, :h, :w] > 0.).cpu().numpy()
        rgb_shape = tuple(list(masks.shape) + [3])
        color = np.zeros(rgb_shape, dtype=np.uint8)
        color[masks] = np.array([97, 217, 54])
        output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
        output_img = Image.fromarray(output_img)
    else:
        output_img = visualizer._draw_panoptic_seg(
            output_img,
            masks['pan_results'].to('cpu').numpy(),
            classes=CocoPanopticDataset.METAINFO['classes'],
            palette=CocoPanopticDataset.METAINFO['palette'],
        )
    return image, output_img
|
def register_title():
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(TITLE)
|
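# Build the "Point mode" tab: input/output image panels, the Segment and
# Clean Prompts buttons, clickable examples, and the event wiring.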
def register_point_mode():
    with gr.Tab("Point mode"):
        img_state = gr.State(IMGState())
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                img_p = gr.Image(label="Input Image", type="pil")
            with gr.Column(scale=1):
                segm_p = gr.Image(label="Segment", interactive=False, type="pil")

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        segment_btn = gr.Button("Segment", variant="primary")
                    with gr.Column():
                        clean_btn = gr.Button("Clean Prompts", variant="secondary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[img_p, img_state],
                    outputs=[img_p, segm_p],
                    examples_per_page=4,
                    fn=store_img,
                    run_on_click=True,
                )

        img_p.upload(
            store_img,
            [img_p, img_state],
            [img_p, segm_p],
        )
        img_p.select(
            get_points_with_draw,
            [img_p, img_state],
            img_p,
        )
        segment_btn.click(
            segment_point,
            [img_p, img_state],
            [img_p, segm_p],
        )
        clean_btn.click(
            IMGState.cls_clean,
            img_state,
            [img_p, segm_p],
        )
        img_p.clear(
            IMGState.cls_clear,
            img_state,
            [img_p, segm_p],
        )
|
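# Compose the full page: title banner plus the point-mode tab.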
def build_demo():
    with gr.Blocks(css=CSS, title="RAP-SAM") as _demo:
        register_title()
        register_point_mode()
    return _demo
|
if __name__ == '__main__':
    demo = build_demo()
    demo.queue(api_open=False)
    demo.launch(server_name='0.0.0.0')