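# Dash GUI for SA3D: prompt SAM on a reference view with point clicks or a
# GroundingDINO text prompt, pick one of three candidate masks, run the
# cross-view 3D segmentation training loop, and preview the rendered results.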
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import cv2
import imageio
import time
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import torch

import dash
from dash import Dash, Input, Output, dcc, html, State
from dash.exceptions import PreventUpdate

from .self_prompting import grounding_dino_prompt
def mark_image(_img, points):
    '''Return a copy of the image with a small red square drawn at each (x, y) point.'''
    assert len(points) > 0
    img = _img.copy()
    r = 10
    mark_color = np.array([255, 0, 0]).reshape(1, 1, 3)
    for i in range(len(points)):
        point = points[i]
        img[point[1]-r:point[1]+r+1, point[0]-r:point[0]+r+1] = mark_color
    return img
def draw_figure(fig, title, animation_frame=None):
    '''Wrap an image (or an image stack, when animation_frame is given) in a plotly figure.'''
    fig = px.imshow(fig, animation_frame=animation_frame)
    if animation_frame is not None:
        # fig.update_layout(sliders = [{'visible': False}])
        fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 33  # ~30 fps playback
    fig.update_layout(title_text=title, showlegend=False)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    return fig
class Sam3dGUI:
    '''Dash GUI wrapper around a Seg3d model: SAM prompting, training, and rendering preview.'''

    def __init__(self, Seg3d, debug=False):
        # Shared mutable state used by the Dash callbacks.
        ctx = {
            'num_clicks': 0,         # number of point prompts collected so far
            'click': [],             # list of (x, y) click coordinates
            'cur_img': None,         # reference image the prompts refer to
            'btn_clear': 0,          # last seen n_clicks of the "Clear Points" button
            'btn_text': 0,           # last seen n_clicks of the "Generate" button
            'prompt_type': 'point',  # 'point' or 'text'
            'show_rgb': False        # set by the training loop when fresh figures are ready
        }
        self.ctx = ctx
        self.Seg3d = Seg3d
        self.debug = debug
        self.train_idx = 0

    def run(self):
        init_rgb = self.Seg3d.init_model()
        self.ctx['cur_img'] = init_rgb
        self.run_app(sam_pred=self.Seg3d.predictor, ctx=self.ctx, init_rgb=init_rgb)
    def run_app(self, sam_pred, ctx, init_rgb):
        '''
        Build and serve the Dash app: SAM prompting, SA3D training, and rendering preview.
        '''
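        # query() runs SAM once on the current reference image: either with the
        # accumulated point clicks, or with a box obtained from a GroundingDINO
        # text prompt. It returns the three candidate masks plus the overlay
        # figures used by the layout below.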
        def query(points=None, text=None):
            with torch.no_grad():
                if text is None:
                    # Point prompts: every click is treated as a foreground point.
                    input_point = points
                    input_label = np.ones(len(input_point))
                    masks, scores, logits = sam_pred.predict(
                        point_coords=input_point,
                        point_labels=input_label,
                        multimask_output=True,
                    )
                elif points is None:
                    # Text prompt: get boxes from GroundingDINO, keep the first one,
                    # and feed it to SAM as a box prompt.
                    input_boxes = grounding_dino_prompt(ctx['cur_img'], text)
                    boxes = torch.tensor(input_boxes)[0:1].cuda()
                    transformed_boxes = sam_pred.transform.apply_boxes_torch(boxes, ctx['cur_img'].shape[:2])
                    masks, scores, logits = sam_pred.predict_torch(
                        point_coords=None,
                        point_labels=None,
                        boxes=transformed_boxes,
                        multimask_output=True,
                    )
                    masks = masks[0].cpu().numpy()
                else:
                    raise NotImplementedError

            # Blend each candidate mask with the reference image for display.
            fig1 = (255*masks[0, :, :, None]*0.6 + ctx['cur_img']*0.4).astype(np.uint8)
            fig2 = (255*masks[1, :, :, None]*0.6 + ctx['cur_img']*0.4).astype(np.uint8)
            fig3 = (255*masks[2, :, :, None]*0.6 + ctx['cur_img']*0.4).astype(np.uint8)
            fig1 = draw_figure(fig1, 'mask0')
            fig2 = draw_figure(fig2, 'mask1')
            fig3 = draw_figure(fig3, 'mask2')
            if text is None:
                fig0 = mark_image(ctx['cur_img'], points)
            else:
                fig0 = ctx['cur_img']
            fig0 = draw_figure(fig0, 'original_image')
            return masks, fig0, fig1, fig2, fig3
        # _, fig0, fig1, fig2, fig3, desc = query(np.array([[100, 100], [101, 101]]))

        # Placeholder figures shown before any prompt is given and before training starts.
        self.ctx['fig0'] = draw_figure(init_rgb, 'original_image')
        self.ctx['fig1'] = draw_figure(np.zeros_like(init_rgb), 'mask0')
        self.ctx['fig2'] = draw_figure(np.zeros_like(init_rgb), 'mask1')
        self.ctx['fig3'] = draw_figure(np.zeros_like(init_rgb), 'mask2')
        self.ctx['fig_seg_rgb'] = draw_figure(np.zeros_like(init_rgb), 'Masked image in Training')
        self.ctx['fig_sam_mask'] = draw_figure(np.zeros_like(init_rgb), 'SAM Mask with Prompts in Training')
        self.ctx['fig_masked_rgb'] = draw_figure(np.zeros_like(init_rgb), 'Masked RGB')
        self.ctx['fig_seged_rgb'] = draw_figure(np.zeros_like(init_rgb), 'Seged RGB')
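        # The page is organised as three stacked panels:
        #   1. "SAM Init"               -- choose a prompt type, click points or enter text,
        #                                  then pick one of the three candidate masks.
        #   2. "SA3D Training"          -- start training and watch the per-view mask/prompt
        #                                  figures refresh via the dcc.Interval component.
        #   3. "SA3D Rendering Results" -- animated masked/segmented renderings after training.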
        app = dash.Dash(
            __name__, meta_tags=[{"name": "viewport", "content": "width=device-width"}]
        )

        app.layout = html.Div(
            style={"height": "100%"},
            children=[
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns", style={"padding-bottom": "5%"}, children=[
                            html.Div([html.H3(['SAM Init'])]),
                            html.Br(),
                            html.H5('Prompt Type:'),
                            html.Div([
                                dcc.Dropdown(
                                    id='prompt_type',
                                    options=[{'label': 'Points', 'value': 'point'},
                                             {'label': 'Text', 'value': 'text'}],
                                    value='point'),
                                html.Div(id='output-prompt_type')
                            ]),
                            html.Br(),
                            html.H5('Point Prompts:'),
                            html.Button('Clear Points', id='btn-nclicks-clear', n_clicks=0),
                            html.Br(),
                            html.H5('Text Prompt:'),
                            html.Div([
                                dcc.Input(id='input-text-state', type='text', value='none'),
                                html.Button(id='submit-button-state', n_clicks=0, children='Generate'),
                                html.Div(id='output-state-text')
                            ]),
                            html.Br(),
                            html.H5('Please select the mask:'),
                            html.Div([
                                dcc.RadioItems(['mask0', 'mask1', 'mask2'], id='sel_mask_id', value=None)
                            ], style={'display': 'flex'}),
                            html.Br(),
                            html.H5(id='container-sel-mask'),
                        ]),
                        html.Div(className="ten columns", children=[
                            html.Div(children=[
                                dcc.Graph(id='main_image', figure=self.ctx['fig0'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='mask0', figure=self.ctx['fig1'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='mask1', figure=self.ctx['fig2'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='mask2', figure=self.ctx['fig3'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                        ])
                    ])
                ]),
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns", style={"padding-bottom": "5%"}, children=[
                            html.Div([html.H3(['SA3D Training'])]),
                            html.Br(),
                            html.Button('Start Training', id='btn-nclicks-training', n_clicks=0),
                            html.Div(id='container-button-training', style={'display': 'inline-block'}),
                        ]),
                        html.Div(className="ten columns", children=[
                            html.Div(children=[
                                dcc.Graph(id='seg_rgb', figure=self.ctx['fig_seg_rgb'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='sam_mask', figure=self.ctx['fig_sam_mask'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                        ]),
                        dcc.Interval(
                            id='interval-component',
                            interval=1*1000,  # in milliseconds
                            n_intervals=0),
                    ])
                ]),
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns", style={"padding-bottom": "5%"}, children=[
                            html.Div([html.H3(['SA3D Rendering Results'])]),
                            html.Br(),
                        ]),
                        html.Div(className="ten columns", children=[
                            html.Div(children=[
                                dcc.Graph(id='masked_rgb', figure=self.ctx['fig_masked_rgb'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='seged_rgb', figure=self.ctx['fig_seged_rgb'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                        ]),
                    ])
                ])
            ])
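        # ------------------------------------------------------------------
        # Dash callbacks. The Output/Input/State wiring in the decorators below
        # is reconstructed from the component ids used in the layout above.
        # ------------------------------------------------------------------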
        @app.callback(
            Output('output-prompt_type', 'children'),
            Input('prompt_type', 'value'))
        def update_prompt_type(value):
            self.ctx['prompt_type'] = value
            if value != 'point':
                # Switching away from point prompts discards any stored clicks.
                ctx['click'] = []
                ctx['num_clicks'] = 0
            return f"Type {value} is chosen"
        @app.callback(
            Output('main_image', 'figure'),
            Output('mask0', 'figure'),
            Output('mask1', 'figure'),
            Output('mask2', 'figure'),
            Output('output-state-text', 'children'),
            Input('main_image', 'clickData'),
            Input('btn-nclicks-clear', 'n_clicks'),
            Input('submit-button-state', 'n_clicks'),
            State('input-text-state', 'value'))
        def update_prompt(clickData, btn_point, btn_text, text):
            '''
            Update the candidate masks from the current point or text prompt.
            '''
            if self.ctx['prompt_type'] == 'point':
                if clickData is None and btn_point == self.ctx['btn_clear']:
                    raise PreventUpdate
                if btn_point > self.ctx['btn_clear']:
                    # "Clear Points" was pressed: reset the clicks and restore the placeholders.
                    self.ctx['btn_clear'] += 1
                    ctx['click'] = []
                    ctx['num_clicks'] = 0
                    return self.ctx['fig0'], self.ctx['fig1'], self.ctx['fig2'], self.ctx['fig3'], 'none'
                ctx['num_clicks'] += 1
                ctx['click'].append(np.array([clickData['points'][0]['x'], clickData['points'][0]['y']]))
                ctx['saved_click'] = np.stack(ctx['click'])
                masks, fig0, fig1, fig2, fig3 = query(ctx['saved_click'])
                ctx['masks'] = masks
                return fig0, fig1, fig2, fig3, 'none'
            elif self.ctx['prompt_type'] == 'text':
                if btn_text > self.ctx['btn_text']:
                    self.ctx['btn_text'] += 1
                    self.ctx['text'] = text
                    masks, fig0, fig1, fig2, fig3 = query(points=None, text=text)
                    ctx['masks'] = masks
                    return fig0, fig1, fig2, fig3, u'''
                        Input text is "{}"
                    '''.format(text)
                else:
                    raise PreventUpdate
            else:
                raise NotImplementedError
        @app.callback(
            Output('container-sel-mask', 'children'),
            Input('sel_mask_id', 'value'))
        def update_graph(radio_items):
            if radio_items == 'mask0':
                ctx['select_mask_id'] = 0
                return html.Div("you select mask0")
            elif radio_items == 'mask1':
                ctx['select_mask_id'] = 1
                return html.Div("you select mask1")
            elif radio_items == 'mask2':
                ctx['select_mask_id'] = 2
                return html.Div("you select mask2")
            else:
                raise PreventUpdate
        @app.callback(
            Output('seg_rgb', 'figure'),
            Output('sam_mask', 'figure'),
            Input('interval-component', 'n_intervals'))
        def displaySeg(n):
            # Polled by dcc.Interval: refresh the training figures whenever the
            # training loop has flagged new content via ctx['show_rgb'].
            if self.ctx['show_rgb']:
                self.ctx['show_rgb'] = False
                fig_seg_rgb = draw_figure(self.ctx['fig_seg_rgb'], 'Masked image in Training')
                fig_sam_mask = draw_figure(self.ctx['fig_sam_mask'], 'SAM Mask with Prompts in Training')
                return fig_seg_rgb, fig_sam_mask
            else:
                raise PreventUpdate
        @app.callback(
            Output('container-button-training', 'children'),
            Output('masked_rgb', 'figure'),
            Output('seged_rgb', 'figure'),
            Input('btn-nclicks-training', 'n_clicks'))
        def start_training(btn):
            if btn < 1:
                return html.Div("Press to start training"), self.ctx['fig_masked_rgb'], self.ctx['fig_seged_rgb']
            else:
                # Optimize in the first view with the user-selected SAM mask.
                self.Seg3d.train_step(self.train_idx, sam_mask=ctx['masks'][ctx['select_mask_id']])
                self.train_idx += 1
                # Cross-view training.
                while True:
                    rgb, sam_prompt, is_finished = self.Seg3d.train_step(self.train_idx)
                    self.train_idx += 1
                    self.ctx['fig_seg_rgb'] = rgb
                    self.ctx['fig_sam_mask'] = sam_prompt
                    self.ctx['show_rgb'] = True
                    if is_finished:
                        break
                self.Seg3d.save_ckpt()
                masked_rgb, seged_rgb = self.Seg3d.render_test()
                fig_masked_rgb = draw_figure(masked_rgb, 'Masked RGB', animation_frame=0)
                fig_seged_rgb = draw_figure(seged_rgb, 'Seged RGB', animation_frame=0)
                return html.Div("Train Stage Finished! Press Ctrl+C to Exit!"), fig_masked_rgb, fig_seged_rgb

        app.run_server(debug=self.debug)
if __name__ == '__main__':
    from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
                                  sam_model_registry)
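    # Standalone debug harness. It assumes a local SAM ViT-H checkpoint under
    # ./dependencies/sam_ckpt/ and the example LLFF "fern" image/render paths below,
    # and it has to be launched as a module (python -m ...) because of the relative
    # import of grounding_dino_prompt at the top of this file.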
    class Sam_predictor():
        def __init__(self, device):
            sam_checkpoint = "./dependencies/sam_ckpt/sam_vit_h_4b8939.pth"
            model_type = "vit_h"
            self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
            self.predictor = SamPredictor(self.sam)
            print('sam inited!')
            # pass

        def forward(self, points, multimask_output=True, return_logits=False):
            # self.predictor.set_image(image)
            # input_point = np.array([[x, y], [x + 1, y + 1]])  # TODO, add interactive mode
            input_point = points
            input_label = np.ones(len(input_point))
            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=multimask_output,
                return_logits=return_logits
            )
            return masks
    image = cv2.cvtColor(cv2.imread('data/nerf_llff_data(NVOS)/fern/images_4/image000.png'), cv2.COLOR_BGR2RGB)
    sam_pred = Sam_predictor(torch.device('cuda'))
    sam_pred.predictor.set_image(image)
    video = np.stack(imageio.mimread('logs/llff/fern/render_train_coarse_segmentation_gui/video.rgbseg_gui.mp4'))

    gui = Sam3dGUI(None, debug=True)
    gui.ctx['cur_img'] = image
    gui.ctx['video'] = video
    gui.run_app(sam_pred.predictor, gui.ctx, image)
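# Example (sketch) of launching the GUI from a driver script, assuming a `seg3d`
# object that provides init_model(), a SAM `predictor`, train_step(), save_ckpt()
# and render_test() as used above (`build_seg3d` is a hypothetical constructor):
#
#     seg3d = build_seg3d(cfg)
#     gui = Sam3dGUI(seg3d, debug=False)
#     gui.run()  # blocks and serves the Dash app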