Spaces:
Running
on
L4
Running
on
L4
# https://huggingface.co/DragGan/DragGan-Models
# https://arxiv.org/abs/2305.10973
import os | |
import os.path as osp | |
from argparse import ArgumentParser | |
from functools import partial | |
from pathlib import Path | |
import time | |
import psutil | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
import dnnlib | |
from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, | |
get_latest_points_pair, get_valid_mask, | |
on_change_single_global_state) | |
from viz.renderer import Renderer, add_watermark_np | |
# download models from Hugging Face hub | |
from huggingface_hub import snapshot_download | |
# Fetch the pretrained checkpoints from the Hugging Face hub at import time.
# NOTE(review): the download target is fixed to ./checkpoints and does not
# follow --cache-dir; confirm whether that is intended.
model_dir = Path('./checkpoints')
snapshot_download(
    'DragGan/DragGan-Models',
    repo_type='model',
    local_dir=model_dir,
)

# Command line options.
parser = ArgumentParser()
parser.add_argument('--share', action='store_true')
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
args = parser.parse_args()

# Directory that is scanned for ``*.pkl`` checkpoints further below.
cache_dir = args.cache_dir
device = 'cuda'
# True when the SPACE_ID environment variable contains "DragGan/DragGan".
IS_SPACE = "DragGan/DragGan" in os.environ.get('SPACE_ID', '')
# Upper bound (seconds) for one drag run; only enforced when IS_SPACE is set.
TIMEOUT = 80
def reverse_point_pairs(points):
    """Return a new list in which the two coordinates of each point are swapped.

    Args:
        points: Iterable of indexable points; only index 0 and 1 are read.

    Returns:
        list: ``[[p[1], p[0]], ...]`` in the same order as ``points``. The
        input is not modified.
    """
    return [[p[1], p[0]] for p in points]
def clear_state(global_state, target=None):
    """Clear the selected history from ``global_state`` in place.

    Args:
        global_state (dict): Application state. ``global_state['points']`` is
            replaced when points are cleared; when the mask is cleared,
            ``global_state['images']['image_raw'].size`` is read to size the
            new mask.
        target (str | list | tuple | set | None): ``'point'``, ``'mask'`` or a
            collection of them. ``None`` clears both.

    Returns:
        dict: The same ``global_state`` object.
    """
    if target is None:
        target = ['point', 'mask']
    elif isinstance(target, str):
        # Only a bare string is wrapped; tuples/sets of names are accepted
        # as-is instead of being wrapped into a single (never matching) item.
        target = [target]

    if 'point' in target:
        global_state['points'] = dict()
        print('Clear Points State!')

    if 'mask' in target:
        image_raw = global_state["images"]["image_raw"]
        # ``size`` is (width, height); the mask is stored as (height, width).
        width, height = image_raw.size[0], image_raw.size[1]
        global_state['mask'] = np.ones((height, width), dtype=np.uint8)
        print('Clear mask State!')

    return global_state
def init_images(global_state):
    """(Re-)initialise the renderer and the images stored in the state.

    Steps:
        0. Unwrap ``global_state`` when it is a ``gr.State``.
        1. Re-init the renderer network from the selected checkpoint and the
           current parameters.
        2. Run ``renderer._render_drag_impl`` with ``is_drag=False`` to
           generate a new image.
        3. Store the image as ``image_orig`` / ``image_raw``, a watermarked
           copy as ``image_show``, and reset the mask to all ones.

    Args:
        global_state (gr.State | dict): Application state.

    Returns:
        The object that was passed in (not the unwrapped dict).
    """
    state = (global_state.value
             if isinstance(global_state, gr.State) else global_state)

    renderer = state['renderer']
    generator_params = state['generator_params']
    params = state['params']

    renderer.init_network(
        generator_params,  # res
        valid_checkpoints_dict[state['pretrained_weight']],  # pkl
        params['seed'],  # w0_seed
        None,  # w_load
        params['latent_space'] == 'w+',  # w_plus
        'const',
        params['trunc_psi'],  # trunc_psi
        params['trunc_cutoff'],  # trunc_cutoff
        None,  # input_transform
        params['lr'],  # lr
    )
    renderer._render_drag_impl(generator_params, is_drag=False, to_pil=True)

    init_image = generator_params.image
    images = state['images']
    images['image_orig'] = init_image
    images['image_raw'] = init_image
    images['image_show'] = Image.fromarray(
        add_watermark_np(np.array(init_image)))

    # ``size`` is (width, height); the mask is stored as (height, width).
    mask_shape = (init_image.size[1], init_image.size[0])
    state['mask'] = np.ones(mask_shape, dtype=np.uint8)
    return global_state
def update_image_draw(image, points, mask, show_mask, global_state=None):
    """Draw points (and optionally the mask) on ``image`` and watermark it.

    The mask overlay is only drawn when ``show_mask`` is truthy and the mask
    is neither ``None``, all zeros, nor all ones.

    Args:
        image: Image handed to ``draw_points_on_image``.
        points: Points handed to ``draw_points_on_image``.
        mask: Mask array or ``None``.
        show_mask: Whether the mask overlay is requested.
        global_state (dict | None): When given, the result is also stored in
            ``global_state['images']['image_show']``.

    Returns:
        PIL.Image.Image: The decorated image.
    """
    canvas = draw_points_on_image(image, points)

    mask_is_partial = (show_mask and mask is not None
                       and not (mask == 0).all()
                       and not (mask == 1).all())
    if mask_is_partial:
        canvas = draw_mask_on_image(canvas, mask)

    watermarked = add_watermark_np(np.array(canvas))
    canvas = Image.fromarray(watermarked)

    if global_state is not None:
        global_state['images']['image_show'] = canvas
    return canvas
def preprocess_mask_info(global_state, image):
    """Merge the mask painted on the image widget into ``global_state['mask']``.

    1. No painted mask available (``image`` is not a dict, or carries no
       mask): the state is returned untouched.
    2. Otherwise the painted mask is combined with the stored one according
       to ``global_state['editing_state']``:
       2.1 ``'remove_mask'``: stored - painted, clipped to [0, 1]
       2.2 ``'add_mask'``: stored + painted, clipped to [0, 1]
       2.3 anything else: the stored mask is kept as is.
       A stored mask that is ``None`` or all ones counts as a placeholder and
       is replaced by the painted mask before the rules above are applied.

    Args:
        global_state (dict): Application state; ``'mask'`` and
            ``'editing_state'`` are read, ``'mask'`` is written.
        image: Value of the image component. Only a dict with a ``'mask'``
            entry carries mask information.

    Returns:
        dict: The same ``global_state`` object.
    """
    if not isinstance(image, dict):
        return global_state
    painted = image.get('mask')
    if painted is None:
        return global_state
    last_mask = get_valid_mask(painted)
    if last_mask is None:
        return global_state

    mask = global_state['mask']
    # mask in global state is a placeholder with all 1 (or not set yet).
    if mask is None or (mask == 1).all():
        mask = last_mask

    editing_mode = global_state['editing_state']
    if editing_mode == 'remove_mask':
        updated_mask = np.clip(mask - last_mask, 0, 1)
        print(f'Last editing_state is {editing_mode}, do remove.')
    elif editing_mode == 'add_mask':
        updated_mask = np.clip(mask + last_mask, 0, 1)
        print(f'Last editing_state is {editing_mode}, do add.')
    else:
        updated_mask = mask
        print(f'Last editing_state is {editing_mode}, '
              'do nothing to mask.')

    global_state['mask'] = updated_mask
    return global_state
def print_memory_usage():
    """Print system memory usage and, when CUDA is available, GPU usage.

    The "available" GPU figure is the device's total memory minus the peak
    amount allocated through torch so far.
    """
    # Print system memory usage
    print(f"System memory usage: {psutil.virtual_memory().percent}%")

    # Print GPU memory usage
    if not torch.cuda.is_available():
        print("No GPU available")
        return

    cuda_device = torch.device("cuda")
    print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
    print(
        f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
    total_memory = torch.cuda.get_device_properties(cuda_device).total_memory
    available_memory = total_memory - torch.cuda.max_memory_allocated()
    print(f"Available GPU memory: {available_memory / 1e9} GB")
# filter large models running on SPACES
allowed_checkpoints = []  # all checkpoints
if IS_SPACE:
    allowed_checkpoints = ["stylegan_human_v2_512.pkl",
                           "stylegan2_dogs_1024_pytorch.pkl"]

# Map "<file name up to the first dot>" -> checkpoint path for every *.pkl
# in the cache directory; on Spaces only the allow-listed files are kept.
valid_checkpoints_dict = {}
for ckpt_path in Path(cache_dir).glob('*.pkl'):
    if IS_SPACE and ckpt_path.name not in allowed_checkpoints:
        continue
    valid_checkpoints_dict[ckpt_path.name.split('.')[0]] = str(ckpt_path)

print('Valid checkpoint file:')
print(valid_checkpoints_dict)

# Checkpoint key selected when the app starts.
init_pkl = 'stylegan_human_v2_512'
# Build the UI: header, shared state, tool column (left), image (right) and
# the usage notes. Nesting of the layout containers was reconstructed from the
# statement order; the event listeners that follow belong to this same
# ``with`` body.
with gr.Blocks() as app:
    gr.Markdown("""
# DragGAN - Drag Your GAN
## Interactive Point-based Manipulation on the Generative Image Manifold
### Unofficial Gradio Demo

**Due to high demand, only one model can be run at a time, or you can duplicate the space and run your own copy.**

<a href="https://huggingface.co/spaces/radames/DragGan?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>

* Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
* Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) © [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
""")

    # renderer = Renderer()
    global_state = gr.State({
        "images": {
            # image_orig: the original image, change with seed/model is changed
            # image_raw: image with mask and points, change during optimization
            # image_show: image shown on screen
        },
        "temporal_params": {
            # stop
        },
        # mask for visualization, 1 for editing and 0 for unchanged
        'mask': None,
        'last_mask': None,  # last edited mask
        'show_mask': True,  # add button
        "generator_params": dnnlib.EasyDict(),
        "params": {
            "seed": int(np.random.randint(0, 2**32 - 1)),
            "motion_lambda": 20,
            "r1_in_pixels": 3,
            "r2_in_pixels": 12,
            "magnitude_direction_in_pixels": 1.0,
            "latent_space": "w+",
            "trunc_psi": 0.7,
            "trunc_cutoff": None,
            "lr": 0.001,
        },
        "device": device,
        "draw_interval": 1,
        "renderer": Renderer(disable_timing=True),
        "points": {},
        "curr_point": None,
        "curr_type_point": "start",
        'editing_state': 'add_points',
        'pretrained_weight': init_pkl,
    })

    # init image
    global_state = init_images(global_state)

    with gr.Row():

        with gr.Row():

            # Left --> tools
            with gr.Column(scale=3):

                # Pickle
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Pickle', show_label=False)

                    with gr.Column(scale=4, min_width=10):
                        form_pretrained_dropdown = gr.Dropdown(
                            choices=list(valid_checkpoints_dict.keys()),
                            label="Pretrained Model",
                            value=init_pkl,
                        )

                # Latent
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Latent', show_label=False)

                    with gr.Column(scale=4, min_width=10):
                        form_seed_number = gr.Slider(
                            # was misspelled ``mininium``, so the lower bound
                            # was never actually passed to the slider
                            minimum=0,
                            maximum=2**32 - 1,
                            step=1,
                            value=global_state.value['params']['seed'],
                            interactive=True,
                            # randomize=True,
                            label="Seed",
                        )
                        form_lr_number = gr.Number(
                            value=global_state.value["params"]["lr"],
                            interactive=True,
                            label="Step Size")

                        with gr.Row():
                            with gr.Column(scale=2, min_width=10):
                                form_reset_image = gr.Button("Reset Image")
                            with gr.Column(scale=3, min_width=10):
                                form_latent_space = gr.Radio(
                                    ['w', 'w+'],
                                    value=global_state.value['params']
                                    ['latent_space'],
                                    interactive=True,
                                    label='Latent space to optimize',
                                    show_label=False,
                                )

                # Drag
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Drag', show_label=False)
                    with gr.Column(scale=4, min_width=10):
                        with gr.Row():
                            with gr.Column(scale=1, min_width=10):
                                enable_add_points = gr.Button('Add Points')
                            with gr.Column(scale=1, min_width=10):
                                undo_points = gr.Button('Reset Points')
                        with gr.Row():
                            with gr.Column(scale=1, min_width=10):
                                form_start_btn = gr.Button("Start")
                            with gr.Column(scale=1, min_width=10):
                                form_stop_btn = gr.Button("Stop")

                        form_steps_number = gr.Number(value=0,
                                                      label="Steps",
                                                      interactive=False)

                # Mask
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Mask', show_label=False)
                    with gr.Column(scale=4, min_width=10):
                        enable_add_mask = gr.Button('Edit Flexible Area')
                        with gr.Row():
                            with gr.Column(scale=1, min_width=10):
                                form_reset_mask_btn = gr.Button("Reset mask")
                            with gr.Column(scale=1, min_width=10):
                                show_mask = gr.Checkbox(
                                    label='Show Mask',
                                    value=global_state.value['show_mask'],
                                    show_label=False)

                        with gr.Row():
                            form_lambda_number = gr.Number(
                                value=global_state.value["params"]
                                ["motion_lambda"],
                                interactive=True,
                                label="Lambda",
                            )

                form_draw_interval_number = gr.Number(
                    value=global_state.value["draw_interval"],
                    label="Draw Interval (steps)",
                    interactive=True,
                    visible=False)

            # Right --> Image
            with gr.Column(scale=8):
                form_image = ImageMask(
                    value=global_state.value['images']['image_show'],
                    brush_radius=20).style(
                        width=768,
                        height=768)  # NOTE: hard image size code here.

    gr.Markdown("""
## Quick Start

1. Select desired `Pretrained Model` and adjust `Seed` to generate an
   initial image.
2. Click on image to add control points.
3. Click `Start` and enjoy it!

## Advance Usage

1. Change `Step Size` to adjust learning rate in drag optimization.
2. Select `w` or `w+` to change latent space to optimize:
    * Optimize on `w` space may cause greater influence to the image.
    * Optimize on `w+` space may work slower than `w`, but usually achieve
      better results.
    * Note that changing the latent space will reset the image, points and
      mask (this has the same effect as `Reset Image` button).
3. Click `Edit Flexible Area` to create a mask and constrain the
   unmasked region to remain unchanged.
""")
    gr.HTML("""
<style>
    .container {
        position: absolute;
        height: 50px;
        text-align: center;
        line-height: 50px;
        width: 100%;
    }
</style>
<div class="container">
Gradio demo supported by
<img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
<a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
</div>
""")
# Network & latents tab listeners
def on_change_pretrained_dropdown(pretrained_value, global_state):
    """Switch to another pretrained model.

    1. Store the chosen checkpoint key in ``global_state``.
    2. Re-init the images and clear points and mask.

    Returns:
        tuple: ``(global_state, image_show)``.
    """
    global_state['pretrained_weight'] = pretrained_value
    init_images(global_state)
    clear_state(global_state)

    image_show = global_state["images"]['image_show']
    return global_state, image_show


form_pretrained_dropdown.change(
    on_change_pretrained_dropdown,
    inputs=[form_pretrained_dropdown, global_state],
    outputs=[global_state, form_image],
    queue=True,
)
def on_click_reset_image(global_state):
    """Regenerate the initial image and drop all points and the mask.

    1. Re-init images
    2. Clear all states

    Returns:
        tuple: ``(global_state, image_show)``.
    """
    init_images(global_state)
    clear_state(global_state)

    image_show = global_state['images']['image_show']
    return global_state, image_show


form_reset_image.click(
    on_click_reset_image,
    inputs=[global_state],
    outputs=[global_state, form_image],
    queue=False,
)
# Update parameters
def on_change_update_image_seed(seed, global_state):
    """React to a new generation seed.

    1. Store the seed (as ``int``) in ``global_state``.
    2. Re-init the images and clear points and mask.

    Returns:
        tuple: ``(global_state, image_show)``.
    """
    global_state["params"]["seed"] = int(seed)
    init_images(global_state)
    clear_state(global_state)

    image_show = global_state['images']['image_show']
    return global_state, image_show


form_seed_number.change(
    on_change_update_image_seed,
    inputs=[form_seed_number, global_state],
    outputs=[global_state, form_image],
)
def on_click_latent_space(latent_space, global_state):
    """Change the latent space that is optimized.

    NOTE: this resets the image and all controls.
    1. Store the latent space in ``global_state``.
    2. Re-init the images and clear points and mask.

    Returns:
        tuple: ``(global_state, image_show)``.
    """
    global_state['params']['latent_space'] = latent_space
    init_images(global_state)
    clear_state(global_state)

    image_show = global_state['images']['image_show']
    return global_state, image_show


form_latent_space.change(
    on_click_latent_space,
    inputs=[form_latent_space, global_state],
    outputs=[global_state, form_image],
)
# ==== Params
# Lambda field -> global_state["params"]["motion_lambda"]
_on_change_motion_lambda = partial(on_change_single_global_state,
                                   ["params", "motion_lambda"])
form_lambda_number.change(
    _on_change_motion_lambda,
    inputs=[form_lambda_number, global_state],
    outputs=[global_state],
)
def on_change_lr(lr, global_state):
    """Apply a new step size (learning rate) to the state and the renderer.

    A value of ``0`` or ``None`` (e.g. an emptied number field) is ignored so
    that an invalid learning rate never reaches the optimizer.

    Args:
        lr: New learning rate.
        global_state (dict): Application state; ``['params']['lr']`` is
            written and ``['renderer'].update_lr`` is called.

    Returns:
        dict: The same ``global_state`` object.
    """
    if lr is None or lr == 0:
        print('lr is 0, do nothing.')
        return global_state

    global_state["params"]["lr"] = lr
    renderer = global_state['renderer']
    renderer.update_lr(lr)
    print('New optimizer: ')
    print(renderer.w_optim)
    return global_state
# Step Size field -> on_change_lr (not queued)
form_lr_number.change(on_change_lr,
                      inputs=[form_lr_number, global_state],
                      outputs=[global_state],
                      queue=False)
def _drag_ui_updates(editable, stop_enabled):
    """Build the component updates that follow (state, steps, image) in a yield.

    The order matches the ``outputs`` list of ``form_start_btn.click`` after
    its first three entries, one update per output component.

    Args:
        editable (bool): ``interactive`` flag for every control except Stop.
        stop_enabled (bool): ``interactive`` flag for the Stop button.
    """
    return (
        gr.Button.update(interactive=editable),  # form_reset_image
        gr.Button.update(interactive=editable),  # enable_add_points
        gr.Button.update(interactive=editable),  # enable_add_mask
        gr.Button.update(interactive=editable),  # undo_points
        gr.Button.update(interactive=editable),  # form_reset_mask_btn
        gr.Radio.update(interactive=editable),  # form_latent_space
        gr.Button.update(interactive=editable),  # form_start_btn
        gr.Button.update(interactive=stop_enabled),  # form_stop_btn
        gr.Dropdown.update(interactive=editable),  # form_pretrained_dropdown
        gr.Number.update(interactive=editable),  # form_seed_number
        gr.Number.update(interactive=editable),  # form_lr_number
        gr.Checkbox.update(interactive=editable),  # show_mask
        gr.Number.update(interactive=editable),  # form_lambda_number
    )


def on_click_start(global_state, image):
    """Run the drag optimization and stream intermediate results to the UI.

    This is a generator. Every yielded tuple is
    ``(global_state, step index, image_show, *component updates)`` with one
    value per output of ``form_start_btn.click``.

    * Without any complete (start, target) point pair the current image is
      redrawn once and all controls stay enabled (Stop disabled).
    * Otherwise ``renderer._render_drag_impl`` is called in a loop until the
      stop flag is set in ``temporal_params`` or ``generator_params``, or, on
      Spaces, ``TIMEOUT`` seconds have passed. While looping, the controls are
      disabled and Stop is enabled; a final yield re-enables them and resets
      the step counter to 0.

    Args:
        global_state (dict): Application state.
        image: Current value of the image component (may carry a painted
            mask, see ``preprocess_mask_info``).
    """
    # handle of start drag in mask editing mode
    global_state = preprocess_mask_info(global_state, image)

    # Prepare the points for the inference: keep only complete pairs.
    p_in_pixels = []
    t_in_pixels = []
    valid_points = []
    for key_point, point in global_state["points"].items():
        try:
            p_start = point.get("start_temp", point["start"])
            p_end = point["target"]
        except KeyError:
            continue
        if p_start is None or p_end is None:
            continue
        p_in_pixels.append(p_start)
        t_in_pixels.append(p_end)
        valid_points.append(key_point)

    if not valid_points:
        # Nothing to drag (no points, or only a start without a target):
        # just redraw and leave the controls enabled.
        update_image_draw(
            global_state['images']['image_raw'],
            global_state['points'],
            global_state['mask'],
            global_state['show_mask'],
            global_state,
        )
        yield (
            global_state,
            0,
            global_state['images']['image_show'],
            # NOTE: disable stop button
            *_drag_ui_updates(editable=True, stop_enabled=False),
        )
        return

    mask = torch.tensor(global_state['mask']).float()
    # The renderer receives the complement of the editable-area mask.
    drag_mask = 1 - mask

    renderer: Renderer = global_state["renderer"]
    global_state['temporal_params']['stop'] = False
    global_state['editing_state'] = 'running'

    # reverse points order
    p_to_opt = reverse_point_pairs(p_in_pixels)
    t_to_opt = reverse_point_pairs(t_in_pixels)
    print('Running with:')
    print(f'    Source: {p_in_pixels}')
    print(f'    Target: {t_in_pixels}')

    step_idx = 0
    start_time = time.time()
    while True:
        print_memory_usage()
        # add a TIMEOUT break
        print(f'Running time: {time.time() - start_time}')
        if IS_SPACE and time.time() - start_time > TIMEOUT:
            print('Timeout break!')
            break
        # ``generator_params`` is an EasyDict; a missing "stop" entry is
        # treated as "keep running" instead of raising KeyError.
        if (global_state["temporal_params"]["stop"]
                or global_state['generator_params'].get("stop", False)):
            break

        # do drag here!
        renderer._render_drag_impl(
            global_state['generator_params'],
            p_to_opt,  # point
            t_to_opt,  # target
            drag_mask,  # mask,
            global_state['params']['motion_lambda'],  # lambda_mask
            reg=0,
            feature_idx=5,  # NOTE: do not support change for now
            r1=global_state['params']['r1_in_pixels'],  # r1
            r2=global_state['params']['r2_in_pixels'],  # r2
            trunc_psi=global_state['params']['trunc_psi'],
            is_drag=True,
            to_pil=True)

        if step_idx % global_state['draw_interval'] == 0:
            print('Current Source:')
            # Write the current point positions back (swapped to the order
            # used in ``global_state['points']``).
            for key_point, p_i, t_i in zip(valid_points, p_to_opt, t_to_opt):
                point_state = global_state["points"][key_point]
                point_state["start_temp"] = [p_i[1], p_i[0]]
                point_state["target"] = [t_i[1], t_i[0]]
                print(f'    {point_state["start_temp"]}')

            image_result = global_state['generator_params']['image']
            update_image_draw(
                image_result,
                global_state['points'],
                global_state['mask'],
                global_state['show_mask'],
                global_state,
            )
            global_state['images']['image_raw'] = image_result

        yield (
            global_state,
            step_idx,
            global_state['images']['image_show'],
            # enable stop button in loop
            *_drag_ui_updates(editable=False, stop_enabled=True),
        )

        # increase step
        step_idx += 1

    image_result = global_state['generator_params']['image']
    global_state['images']['image_raw'] = image_result
    update_image_draw(image_result, global_state['points'],
                      global_state['mask'], global_state['show_mask'],
                      global_state)

    global_state['editing_state'] = 'add_points'

    yield (
        global_state,
        0,  # reset step to 0 after stop.
        global_state['images']['image_show'],
        # NOTE: disable stop button with loop finish
        *_drag_ui_updates(editable=True, stop_enabled=False),
    )
# Outputs of the Start handler; every tuple yielded by ``on_click_start`` is
# mapped onto this list by position.
_start_outputs = [
    global_state,
    form_steps_number,
    form_image,
    # >>> buttons
    form_reset_image,
    enable_add_points,
    enable_add_mask,
    undo_points,
    form_reset_mask_btn,
    form_latent_space,
    form_start_btn,
    form_stop_btn,
    # <<< buttons
    # >>> input comps
    form_pretrained_dropdown,
    form_seed_number,
    form_lr_number,
    show_mask,
    form_lambda_number,
]
form_start_btn.click(
    on_click_start,
    inputs=[global_state, form_image],
    outputs=_start_outputs,
)
def on_click_stop(global_state):
    """Handle a click on the Stop button.

    1. Signal the running drag loop by setting
       ``global_state["temporal_params"]["stop"]`` to True.
    2. Disable the Stop button.

    Returns:
        tuple: ``(global_state, update for the Stop button)``.
    """
    global_state["temporal_params"]["stop"] = True
    stop_button_update = gr.Button.update(interactive=False)
    return global_state, stop_button_update


form_stop_btn.click(
    on_click_stop,
    inputs=[global_state],
    outputs=[global_state, form_stop_btn],
    queue=False,
)
# Draw Interval field -> global_state["draw_interval"], converted to int.
_on_change_draw_interval = partial(
    on_change_single_global_state,
    "draw_interval",
    map_transform=int,
)
form_draw_interval_number.change(
    _on_change_draw_interval,
    inputs=[form_draw_interval_number, global_state],
    outputs=[global_state],
    queue=False,
)
def on_click_remove_point(global_state):
    """Remove the currently selected point and select the next one, if any.

    Removing the last remaining point (or having no / a stale selection) no
    longer raises: the selection simply becomes ``None`` and the dropdown is
    emptied.

    Args:
        global_state (dict): Application state; ``'points'`` and
            ``'curr_point'`` are updated in place.

    Returns:
        tuple: ``(dropdown update, global_state)``.
    """
    points = global_state["points"]
    # ``curr_point`` may be None or refer to an already removed key.
    points.pop(global_state["curr_point"], None)

    choices = list(points.keys())
    next_choice = choices[0] if choices else None
    global_state["curr_point"] = next_choice
    return (
        gr.Dropdown.update(choices=choices, value=next_choice),
        global_state,
    )
# Mask
def on_click_reset_mask(global_state):
    """Reset the mask to all ones and redraw the current image.

    Returns:
        tuple: ``(global_state, redrawn image)``.
    """
    image_raw = global_state["images"]["image_raw"]
    # ``size`` is (width, height); the mask is stored as (height, width).
    mask_shape = (image_raw.size[1], image_raw.size[0])
    global_state['mask'] = np.ones(mask_shape, dtype=np.uint8)

    image_draw = update_image_draw(
        global_state['images']['image_raw'],
        global_state['points'],
        global_state['mask'],
        global_state['show_mask'],
        global_state,
    )
    return global_state, image_draw


form_reset_mask_btn.click(
    on_click_reset_mask,
    inputs=[global_state],
    outputs=[global_state, form_image],
)
# Image
def on_click_enable_draw(global_state, image):
    """Enter add-mask mode.

    1. Merge the mask painted so far into the state.
    2. Set the editing state to ``'add_mask'``.
    3. Redraw the raw image with points and with the mask always shown.

    Returns:
        tuple: ``(global_state, image update with interactive=True)``.
    """
    global_state = preprocess_mask_info(global_state, image)
    global_state['editing_state'] = 'add_mask'

    image_draw = update_image_draw(
        global_state['images']['image_raw'],
        global_state['points'],
        global_state['mask'],
        True,
        global_state,
    )
    image_update = gr.Image.update(value=image_draw, interactive=True)
    return global_state, image_update
def on_click_remove_draw(global_state, image):
    """Enter remove-mask mode.

    1. Merge the mask painted so far into the state.
    2. Set the editing state to ``'remove_mask'``.
    3. Redraw the raw image with points and with the mask always shown.

    Returns:
        tuple: ``(global_state, image update with interactive=True)``.
    """
    global_state = preprocess_mask_info(global_state, image)
    # The key was misspelled ('edinting_state'), so the mode never changed and
    # ``preprocess_mask_info`` could never take its 'remove_mask' branch.
    global_state['editing_state'] = 'remove_mask'
    image_raw = global_state['images']['image_raw']
    image_draw = update_image_draw(image_raw, global_state['points'],
                                   global_state['mask'], True, global_state)
    return (global_state,
            gr.Image.update(value=image_draw, interactive=True))


# "Edit Flexible Area" button -> add-mask mode.
enable_add_mask.click(on_click_enable_draw,
                      inputs=[global_state, form_image],
                      outputs=[
                          global_state,
                          form_image,
                      ],
                      queue=False)
def on_click_add_point(global_state, image: dict):
    """Switch from mask editing back to add-points mode.

    1. Merge the mask painted so far into the state, if needed.
    2. Set ``global_state['editing_state']`` to ``'add_points'``.
    3. Redraw the raw image with points and (if enabled) the mask.

    Returns:
        tuple: ``(global_state, image update with interactive=False)``.
    """
    global_state = preprocess_mask_info(global_state, image)
    global_state['editing_state'] = 'add_points'

    image_draw = update_image_draw(
        global_state['images']['image_raw'],
        global_state['points'],
        global_state['mask'],
        global_state['show_mask'],
        global_state,
    )
    image_update = gr.Image.update(value=image_draw, interactive=False)
    return global_state, image_update


enable_add_points.click(
    on_click_add_point,
    inputs=[global_state, form_image],
    outputs=[global_state, form_image],
    queue=False,
)
def on_click_image(global_state, evt: gr.SelectData):
    """Register a click on the image as a start or target point.

    Clicks are ignored unless the editing state is ``'add_points'``. The
    first click of a pair stores ``'start'``; the next one fills the pending
    ``'target'``; after that a new pair is opened under the next index.

    Returns:
        tuple: ``(global_state, image to display)``.
    """
    xy = evt.index
    editing_state = global_state['editing_state']
    if editing_state != 'add_points':
        print(f'In {global_state["editing_state"]} state. '
              'Do not add points.')
        return global_state, global_state['images']['image_show']

    points = global_state["points"]
    point_idx = get_latest_points_pair(points)
    if point_idx is None:
        # no pair yet: open the first one
        points[0] = {'start': xy, 'target': None}
        print(f'Click Image - Start - {xy}')
    elif points[point_idx].get('target', None) is None:
        # latest pair is still waiting for its target
        points[point_idx]['target'] = xy
        print(f'Click Image - Target - {xy}')
    else:
        # latest pair is complete: open a new one
        points[point_idx + 1] = {'start': xy, 'target': None}
        print(f'Click Image - Start - {xy}')

    image_draw = update_image_draw(
        global_state['images']['image_raw'],
        global_state['points'],
        global_state['mask'],
        global_state['show_mask'],
        global_state,
    )
    return global_state, image_draw


form_image.select(
    on_click_image,
    inputs=[global_state],
    outputs=[global_state, form_image],
    queue=False,
)
def on_click_clear_points(global_state):
    """Remove all control points.

    1. Clear ``global_state['points']`` (via ``clear_state``).
    2. Reset the renderer's ``feat_refs`` to ``None``.
    3. Redraw the raw image without points.

    Returns:
        tuple: ``(global_state, redrawn image)``.
    """
    clear_state(global_state, target='point')

    renderer: Renderer = global_state["renderer"]
    renderer.feat_refs = None

    image_draw = update_image_draw(
        global_state['images']['image_raw'],
        {},
        global_state['mask'],
        global_state['show_mask'],
        global_state,
    )
    return global_state, image_draw


undo_points.click(
    on_click_clear_points,
    inputs=[global_state],
    outputs=[global_state, form_image],
    queue=False,
)
def on_click_show_mask(global_state, show_mask):
    """Store the Show Mask setting and redraw the image accordingly.

    Returns:
        tuple: ``(global_state, redrawn image)``.
    """
    global_state['show_mask'] = show_mask

    image_draw = update_image_draw(global_state['images']['image_raw'],
                                   global_state['points'],
                                   global_state['mask'],
                                   global_state['show_mask'],
                                   global_state)
    return global_state, image_draw


show_mask.change(on_click_show_mask,
                 inputs=[global_state, show_mask],
                 outputs=[global_state, form_image],
                 queue=False)
# Log the parsed command line options, close any running Gradio apps, then
# serve with a single-worker queue.
print("SHAReD: Start app", args)
gr.close_all()
app.queue(concurrency_count=1, max_size=200, api_open=False)
app.launch(share=args.share, show_api=False)