# Our3D/lib/gui.py
# (HuggingFace file-viewer residue removed from the top of this file:
#  "yansong1616 — Upload 384 files, b177539 verified, raw / history / blame, 16.1 kB")
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import cv2
import imageio
import time
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import torch
import dash
from dash import Dash, Input, Output, dcc, html, State
from dash.exceptions import PreventUpdate
from .self_prompting import grounding_dino_prompt
def mark_image(_img, points, radius=10, color=(255, 0, 0)):
    """Return a copy of *_img* with a filled square marker drawn at each point.

    Parameters
    ----------
    _img : np.ndarray
        HxWx3 image; it is copied, never modified in place.
    points : sequence of (x, y)
        Pixel coordinates (column, row) of the prompt points.
    radius : int, optional
        Half-width of the square marker (default 10, matching the old
        hard-coded value).
    color : tuple of int, optional
        RGB fill color of the marker (default red).

    Returns
    -------
    np.ndarray
        The marked copy of the image.

    Raises
    ------
    AssertionError
        If ``points`` is empty.
    """
    assert len(points) > 0
    img = _img.copy()
    h, w = img.shape[:2]
    mark_color = np.asarray(color).reshape(1, 1, 3)
    for x, y in points:
        # Clamp the square to the image: the original code used raw slices,
        # so a point within `radius` of the top/left edge produced a negative
        # slice index that wrapped around and the marker was silently lost.
        y0, y1 = max(0, y - radius), min(h, y + radius + 1)
        x0, x1 = max(0, x - radius), min(w, x + radius + 1)
        img[y0:y1, x0:x1] = mark_color
    return img
def draw_figure(fig, title, animation_frame=None):
    """Wrap an image (or stack of images) in a plotly ``imshow`` figure.

    Parameters
    ----------
    fig : np.ndarray
        A single HxWx3 image, or a stack of frames when ``animation_frame``
        selects the frame axis.
    title : str
        Title text shown above the figure.
    animation_frame : int or None, optional
        Axis index to animate over; when given, playback speed is set to
        roughly 30 fps.

    Returns
    -------
    plotly figure with tick labels hidden on both axes.
    """
    figure = px.imshow(fig, animation_frame=animation_frame)
    if animation_frame is not None:
        # ~33 ms per frame -> about 30 fps playback.
        # fig.update_layout(sliders = [{'visible': False}])
        figure.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 33
    figure.update_layout(title_text=title, showlegend=False)
    figure.update_yaxes(showticklabels=False)
    figure.update_xaxes(showticklabels=False)
    return figure
class Sam3dGUI:
    """Dash GUI for interactive 3D segmentation (SAM prompting + SA3D training).

    Workflow: the user places point prompts (or types a text prompt) on the
    first view; SAM proposes three candidate masks; the user selects one; and
    "Start Training" runs the cross-view training loop of ``Seg3d`` while the
    page polls for intermediate frames and finally shows rendered results.
    """

    def __init__(self, Seg3d, debug=False):
        # Mutable state shared by every Dash callback. Note that run_app also
        # receives this dict as its ``ctx`` argument, so ``ctx`` and
        # ``self.ctx`` alias the same object throughout.
        ctx = {
            'num_clicks': 0,         # number of point prompts placed so far
            'click': [],             # accumulated np.array([x, y]) click positions
            'cur_img': None,         # RGB image currently shown for prompting
            'btn_clear': 0,          # last handled n_clicks of "Clear Points"
            'btn_text': 0,           # last handled n_clicks of the text "Generate" button
            'prompt_type': 'point',  # active prompt mode: 'point' or 'text'
            'show_rgb': False        # True when new training frames are ready to display
        }
        self.ctx = ctx
        self.Seg3d = Seg3d           # 3D segmentation/training engine (duck-typed)
        self.debug = debug           # forwarded to app.run_server(debug=...)
        self.train_idx = 0           # monotonically increasing training-step counter

    def run(self):
        """Initialise the 3D model and launch the Dash app on its first view."""
        init_rgb = self.Seg3d.init_model()
        self.ctx['cur_img'] = init_rgb
        self.run_app(sam_pred=self.Seg3d.predictor, ctx=self.ctx, init_rgb=init_rgb)

    def run_app(self, sam_pred, ctx, init_rgb):
        '''
        Build the Dash layout and callbacks, then run the app (blocking).

        ``sam_pred`` is a SamPredictor whose image is already set to the
        current view; ``ctx`` is the shared state dict (same object as
        ``self.ctx``); ``init_rgb`` is the first RGB view used to size the
        placeholder figures.
        '''
        def query(points=None, text=None):
            # Run SAM on ctx['cur_img'] with either point prompts (``points``)
            # or a text prompt (``text``); exactly one of the two must be
            # given. Returns (masks, original-image figure, and one figure per
            # candidate mask).
            with torch.no_grad():
                if text is None:
                    # Point prompts: every point is labelled 1 (foreground).
                    input_point = points
                    input_label = np.ones(len(input_point))
                    masks, scores, logits = sam_pred.predict(
                        point_coords=input_point,
                        point_labels=input_label,
                        multimask_output=True,
                    )
                elif points is None:
                    # Text prompt: GroundingDINO proposes boxes; keep only the
                    # first (top) box and feed it to SAM as a box prompt.
                    input_boxes = grounding_dino_prompt(ctx['cur_img'], text)
                    boxes = torch.tensor(input_boxes)[0:1].cuda()
                    transformed_boxes = sam_pred.transform.apply_boxes_torch(boxes, ctx['cur_img'].shape[:2])
                    masks, scores, logits = sam_pred.predict_torch(
                        point_coords=None,
                        point_labels=None,
                        boxes=transformed_boxes,
                        multimask_output=True,
                    )
                    # Drop the batch dim so both branches yield (3, H, W) numpy masks.
                    masks = masks[0].cpu().numpy()
                else:
                    raise NotImplementedError
            # Alpha-blend each of the three candidate masks over the image.
            fig1 = (255*masks[0, :, :, None]*0.6 + ctx['cur_img']*0.4).astype(np.uint8)
            fig2 = (255*masks[1, :, :, None]*0.6 + ctx['cur_img']*0.4).astype(np.uint8)
            fig3 = (255*masks[2, :, :, None]*0.6 + ctx['cur_img']*0.4).astype(np.uint8)
            fig1 = draw_figure(fig1, 'mask0')
            fig2 = draw_figure(fig2, 'mask1')
            fig3 = draw_figure(fig3, 'mask2')
            if text is None:
                # Overlay the clicked points on the original image.
                fig0 = mark_image(ctx['cur_img'], points)
            else:
                fig0 = ctx['cur_img']
            fig0 = draw_figure(fig0, 'original_image')
            return masks, fig0, fig1, fig2, fig3
        # _, fig0, fig1, fig2, fig3, desc = query(np.array([[100, 100], [101, 101]]))

        # Placeholder figures shown before any prompting/training has happened.
        self.ctx['fig0'] = draw_figure(init_rgb, 'original_image')
        self.ctx['fig1'] = draw_figure(np.zeros_like(init_rgb), 'mask0')
        self.ctx['fig2'] = draw_figure(np.zeros_like(init_rgb), 'mask1')
        self.ctx['fig3'] = draw_figure(np.zeros_like(init_rgb), 'mask2')
        self.ctx['fig_seg_rgb'] = draw_figure(np.zeros_like(init_rgb), 'Masked image in Training')
        self.ctx['fig_sam_mask'] = draw_figure(np.zeros_like(init_rgb), 'SAM Mask with Prompts in Training')
        self.ctx['fig_masked_rgb'] = draw_figure(np.zeros_like(init_rgb), 'Masked RGB')
        self.ctx['fig_seged_rgb'] = draw_figure(np.zeros_like(init_rgb), 'Seged RGB')

        app = dash.Dash(
            __name__, meta_tags=[{"name": "viewport", "content": "width=device-width"}]
        )

        # Three stacked sections: SAM init (prompting), SA3D training
        # progress, and final rendering results.
        app.layout = html.Div(
            style={"height": "100%"},
            children=[
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        # Left column: prompt controls.
                        html.Div(className="two columns",style={"padding-bottom": "5%"},children=[
                            html.Div([html.H3(['SAM Init'])]),
                            html.Br(),
                            html.H5('Prompt Type:'),
                            html.Div([
                                dcc.Dropdown(
                                    id = 'prompt_type',
                                    options = [{'label': 'Points', 'value': 'point'},
                                               {'label': 'Text', 'value': 'text'},],
                                    value = 'point'),
                                html.Div(id = 'output-prompt_type')
                            ]),
                            html.Br(),
                            html.H5('Point Prompts:'),
                            html.Button('Clear Points', id='btn-nclicks-clear', n_clicks=0),
                            html.Br(),
                            html.H5('Text Prompt:'),
                            html.Div([
                                dcc.Input(id='input-text-state', type='text', value='none'),
                                html.Button(id='submit-button-state', n_clicks=0, children='Generate'),
                                html.Div(id='output-state-text')
                            ]),
                            html.Br(),
                            html.H5('Please select the mask:'),
                            html.Div([
                                dcc.RadioItems(['mask0', 'mask1', 'mask2'], id='sel_mask_id', value=None)
                            ], style={'display': 'flex'}),
                            html.Br(),
                            html.H5(id='container-sel-mask'),
                        ]),
                        # Right column: the prompted image plus the three candidate masks.
                        html.Div(className="ten columns",children=[
                            html.Div(children=[
                                dcc.Graph(id='main_image', figure=self.ctx['fig0'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='mask0', figure=self.ctx['fig1'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='mask1', figure=self.ctx['fig2'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='mask2', figure=self.ctx['fig3'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                        ])
                    ])
                ]),
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns",style={"padding-bottom": "5%"},children=[
                            html.Div([html.H3(['SA3D Training'])]),
                            html.Br(),
                            html.Button('Start Training', id='btn-nclicks-training', n_clicks=0),
                            html.Div(id='container-button-training', style={'display': 'inline-block'}),
                        ]),
                        html.Div(className="ten columns",children=[
                            html.Div(children=[
                                dcc.Graph(id='seg_rgb', figure=self.ctx['fig_seg_rgb'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='sam_mask', figure=self.ctx['fig_sam_mask'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                        ]),
                        # Poll once a second so training progress can be pushed to the page.
                        dcc.Interval(
                            id='interval-component',
                            interval=1*1000, # in milliseconds
                            n_intervals=0),
                    ])
                ]),
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns",style={"padding-bottom": "5%"},children=[
                            html.Div([html.H3(['SA3D Rendering Results'])]),
                            html.Br(),
                        ]),
                        html.Div(className="ten columns",children=[
                            html.Div(children=[
                                dcc.Graph(id='masked_rgb', figure=self.ctx['fig_masked_rgb'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                            html.Div(children=[
                                dcc.Graph(id='seged_rgb', figure=self.ctx['fig_seged_rgb'])
                            ], style={'display': 'inline-block', 'width': '40%'}),
                        ]),
                    ])
                ])
            ])

        @app.callback(Output('output-prompt_type', 'children'), [Input('prompt_type', 'value')])
        def update_prompt_type(value):
            # Switch prompt mode; leaving point mode discards accumulated clicks.
            self.ctx['prompt_type'] = value
            if value != 'point':
                ctx['click'] = []
                ctx['num_clicks'] = 0
            return f"Type {value} is chosen"

        @app.callback(
            Output('main_image', 'figure'),
            Output('mask0', 'figure'),
            Output('mask1', 'figure'),
            Output('mask2', 'figure'),
            Output('output-state-text', 'children'),
            Input('main_image', 'clickData'),
            Input('btn-nclicks-clear', 'n_clicks'),
            Input('submit-button-state', 'n_clicks'),
            State('input-text-state', 'value')
        )
        def update_prompt(clickData, btn_point, btn_text, text):
            '''
            Re-run SAM whenever a prompt changes: a new click on the main
            image, a press of "Clear Points", or a text "Generate" press.
            Which input fired is inferred by comparing the buttons' n_clicks
            against the counters stored in ctx (btn_clear / btn_text).
            '''
            if self.ctx['prompt_type'] == 'point':
                if clickData is None and btn_point == self.ctx['btn_clear']:
                    # Neither a click nor a clear happened (e.g. initial call).
                    raise PreventUpdate
                if btn_point > self.ctx['btn_clear']:
                    # "Clear Points" pressed: reset clicks and restore placeholders.
                    self.ctx['btn_clear'] += 1
                    ctx['click'] = []
                    ctx['num_clicks'] = 0
                    return self.ctx['fig0'], self.ctx['fig1'], self.ctx['fig2'], self.ctx['fig3'], 'none'
                # A new point was clicked: accumulate it and query SAM with
                # every point placed so far.
                ctx['num_clicks'] += 1
                ctx['click'].append(np.array([clickData['points'][0]['x'], clickData['points'][0]['y']]))
                ctx['saved_click'] = np.stack(ctx['click'])
                masks, fig0, fig1, fig2, fig3 = query(ctx['saved_click'])
                ctx['masks'] = masks
                return fig0, fig1, fig2, fig3, 'none'
            elif self.ctx['prompt_type'] == 'text':
                if btn_text > self.ctx['btn_text']:
                    # "Generate" pressed: run the text-prompt pipeline.
                    self.ctx['btn_text'] += 1
                    self.ctx['text'] = text
                    masks, fig0, fig1, fig2, fig3 = query(points=None, text=text)
                    ctx['masks'] = masks
                    return fig0, fig1, fig2, fig3, u'''
                        Input text is "{}"
                    '''.format(text)
                else:
                    raise PreventUpdate
            else:
                raise NotImplementedError

        @app.callback(
            Output("container-sel-mask", 'children'),
            Input("sel_mask_id", 'value')
        )
        def update_graph(radio_items):
            # Record which of the three SAM candidate masks will seed training.
            if radio_items == 'mask0':
                ctx['select_mask_id'] = 0
                return html.Div("you select mask0")
            elif radio_items == 'mask1':
                ctx['select_mask_id'] = 1
                return html.Div("you select mask1")
            elif radio_items == 'mask2':
                ctx['select_mask_id'] = 2
                return html.Div("you select mask2")
            else:
                raise PreventUpdate

        @app.callback(
            Output('seg_rgb', 'figure'),
            Output('sam_mask', 'figure'),
            Input('interval-component', 'n_intervals')
        )
        def displaySeg(n):
            # Periodic poll: when the training loop has stored fresh frames
            # (show_rgb flag), wrap them in figures and push them to the page.
            # NOTE(review): during training ctx['fig_seg_rgb']/'fig_sam_mask'
            # hold raw image arrays (set in start_training), not figures, and
            # are re-wrapped here; relies on the server handling this request
            # while start_training's callback is still running — confirm.
            if self.ctx['show_rgb']:
                self.ctx['show_rgb'] = False
                fig_seg_rgb = draw_figure(self.ctx['fig_seg_rgb'], 'Masked image in Training')
                fig_sam_mask = draw_figure(self.ctx['fig_sam_mask'], 'SAM Mask with Prompts in Training')
                return fig_seg_rgb, fig_sam_mask
            else:
                raise PreventUpdate

        @app.callback(
            Output('container-button-training', 'children'),
            Output('masked_rgb', 'figure'),
            Output('seged_rgb', 'figure'),
            Input('btn-nclicks-training', 'n_clicks')
        )
        def start_training(btn):
            # Runs the full SA3D training loop synchronously inside this
            # callback; the page only refreshes when it returns.
            if btn < 1:
                # Initial/no-press call: keep the placeholder figures.
                return html.Div("Press to start training"), self.ctx['fig_masked_rgb'], self.ctx['fig_seged_rgb']
            else:
                # optim in the first view: seed with the user-selected SAM mask
                # (requires that a mask was generated and selected beforehand).
                self.Seg3d.train_step(self.train_idx, sam_mask=ctx['masks'][ctx['select_mask_id']])
                self.train_idx += 1
                # cross-view training: iterate until the engine reports completion,
                # publishing intermediate frames for the interval callback.
                while True:
                    rgb, sam_prompt, is_finished = self.Seg3d.train_step(self.train_idx)
                    self.train_idx += 1
                    self.ctx['fig_seg_rgb'] = rgb
                    self.ctx['fig_sam_mask'] = sam_prompt
                    self.ctx['show_rgb'] = True
                    if is_finished:
                        break
                self.Seg3d.save_ckpt()
                # Render the final result as animated (frame-axis 0) figures.
                masked_rgb, seged_rgb = self.Seg3d.render_test()
                fig_masked_rgb = draw_figure(masked_rgb, 'Masked RGB', animation_frame=0)
                fig_seged_rgb = draw_figure(seged_rgb, 'Seged RGB', animation_frame=0)
                return html.Div("Train Stage Finished! Press Ctrl+C to Exit!"), fig_masked_rgb, fig_seged_rgb

        app.run_server(debug=self.debug)
if __name__ == '__main__':
    # Standalone debug harness: drives the GUI with a fixed LLFF "fern" image
    # and a pre-rendered video, without a real Seg3d training engine.
    from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
                                  sam_model_registry)

    class Sam_predictor():
        """Loads a ViT-H SAM checkpoint and exposes its SamPredictor as ``self.predictor``."""

        def __init__(self, device):
            # NOTE(review): checkpoint path is hard-coded relative to the CWD —
            # running from another directory will fail to load.
            sam_checkpoint = "./dependencies/sam_ckpt/sam_vit_h_4b8939.pth"
            model_type = "vit_h"
            self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
            self.predictor = SamPredictor(self.sam)
            print('sam inited!')
            # pass

        def forward(self, points, multimask_output=True, return_logits=False):
            # Predict candidate masks for foreground point prompts; assumes
            # self.predictor.set_image(...) was already called with the image.
            # self.predictor.set_image(image)
            # input_point = np.array([[x, y], [x + 1, y + 1]]) # TODO, add interactive mode
            input_point = points
            input_label = np.ones(len(input_point))
            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=multimask_output,
                return_logits=return_logits
            )
            return masks

    # Load the demo image (BGR->RGB) and hand it to a freshly-initialised SAM.
    image = cv2.cvtColor(cv2.imread('data/nerf_llff_data(NVOS)/fern/images_4/image000.png'), cv2.COLOR_BGR2RGB)
    sam_pred = Sam_predictor(torch.device('cuda'))
    sam_pred.predictor.set_image(image)
    # Pre-rendered segmentation video, stacked to an (N, H, W, 3) array.
    video = np.stack(imageio.mimread('logs/llff/fern/render_train_coarse_segmentation_gui/video.rgbseg_gui.mp4'))
    # Seg3d is None here, so only the SAM-prompting panels are functional;
    # pressing "Start Training" would raise on the None engine.
    gui = Sam3dGUI(None, debug=True)
    gui.ctx['cur_img'] = image
    gui.ctx['video'] = video
    gui.run_app(sam_pred.predictor, gui.ctx, image)