DragGAN / gradio_app.py
fffiloni's picture
Update gradio_app.py
55ee553
import os
import gradio as gr
import torch
import numpy as np
import imageio
from PIL import Image
import uuid
from drag_gan import drag_gan, stylegan2
# Prefer the GPU, but fall back to the CPU when no CUDA device is visible so
# that model loading does not fail outright on CUDA-less hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Half-width (in pixels) of the square marker painted for a clicked point,
# keyed by the `size` value of the currently selected model.
SIZE_TO_CLICK_SIZE = {
    512: 5,
    256: 2,
}

# `size` value passed to `stylegan2` for each selectable checkpoint file; the
# same value is used as the key into SIZE_TO_CLICK_SIZE.
CKPT_SIZE = {
    'stylegan2-ffhq-config-f.pt': 512,
    'stylegan2-cat-config-f.pt': 256,
    'stylegan2-church-config-f.pt': 256,
    'stylegan2-horse-config-f.pt': 256,
}
class ImageMask(gr.components.Image):
    """Image component preconfigured for drawing a mask on an image.

    Sets: source="upload", tool="sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``gr.components.Image`` with the sketch presets applied."""
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        """Normalise the incoming payload to the ``{'image', 'mask'}`` form.

        ``None`` is passed through untouched. When the component is in sketch
        mode and the payload is not already a dict, a mask of the same
        width/height as the image is synthesised (colour channels all zero,
        last channel 255), encoded with ``self.postprocess`` and paired with
        the image before delegating to the parent implementation.

        Args:
            x: ``None``, a dict, or a bare image payload (presumably a
                base64-encoded image, since it is decoded with
                ``decode_base64_to_image`` -- confirm against the frontend).

        Returns:
            Whatever ``gr.components.Image.preprocess`` returns for the
            (possibly wrapped) payload, or ``None`` when ``x`` is ``None``.
        """
        if x is None:
            return x
        if self.tool == "sketch" and self.source in ["upload", "webcam"] and not isinstance(x, dict):
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            # Four channels: everything zero except the last one, which is 255.
            mask = np.zeros((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            # Encode the array the same way the component encodes its outputs
            # so it matches the format of the image payload.
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)
class ModelWrapper:
    """Thin holder for a generator built by ``stylegan2``.

    The generator is created from the given keyword arguments, moved to the
    module-level ``device`` and exposed as the ``g_ema`` attribute.
    """

    def __init__(self, **kwargs):
        generator = stylegan2(**kwargs)
        self.g_ema = generator.to(device)
def to_image(tensor):
    """Convert a single-image tensor into a displayable ``uint8`` array.

    The leading axis is squeezed away, the remaining three axes are permuted
    from ``(0, 1, 2)`` to ``(1, 2, 0)``, and the values are min-max rescaled
    to the range 0..255.

    Args:
        tensor: torch tensor that is 3-D after ``squeeze(0)`` (presumably the
            ``(1, C, H, W)`` output of the generator -- confirm against
            ``g_ema.generate``).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``. A constant-valued input has no
        range to rescale and yields an all-zero array.
    """
    arr = tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    lowest = arr.min()
    span = arr.max() - lowest
    if span == 0:
        # Every value is identical: dividing by the zero range would produce
        # NaNs, whose cast to uint8 is undefined. Return a black image instead.
        return np.zeros(arr.shape, dtype='uint8')
    arr = (arr - lowest) / span
    arr = arr * 255
    return arr.astype('uint8')
def add_points_to_image(image, points, size=5):
    """Paint square markers for the selected points onto ``image`` in place.

    Target points are painted ``[255, 0, 0]`` first, then handle points are
    painted ``[0, 0, 255]``, so a handle overpaints a target at the same spot.

    Args:
        image: array indexed as ``image[row, col, channel]`` with three
            channels; it is modified in place.
        points: dict with ``'target'`` and ``'handle'`` entries, each a
            sequence of ``(row, col)`` pairs.
        size: half-width of each marker; a marker covers
            ``[row - size, row + size)`` by ``[col - size, col + size)``,
            clipped to the image.

    Returns:
        The same ``image`` object that was passed in.
    """
    h, w = image.shape[:2]

    def _paint(center, color):
        """Fill the clipped square around ``center`` with ``color``."""
        row, col = center
        # Slice ends are exclusive, so the upper clamp is the full dimension
        # (h / w), not dimension - 1. The extra max(0, ...) on the end keeps a
        # point far outside the top/left edge from producing a negative end
        # index, which would wrap around and paint most of the image.
        row_start, row_end = max(0, row - size), max(0, min(row + size, h))
        col_start, col_end = max(0, col - size), max(0, min(col + size, w))
        image[row_start:row_end, col_start:col_end, :] = color

    for point in points['target']:
        _paint(point, [255, 0, 0])
    for point in points['handle']:
        _paint(point, [0, 0, 255])
    return image
def on_click(image, target_point, points, size, evt: gr.SelectData):
    """Record a clicked point and redraw the markers.

    The click is appended to ``points['target']`` when ``target_point`` is
    truthy and to ``points['handle']`` otherwise. The coordinate is stored as
    ``[evt.index[1], evt.index[0]]``, i.e. with the two event components
    swapped (presumably converting (x, y) to (row, col) -- confirm against
    ``gr.SelectData``).

    Args:
        image: image array to draw the markers on (modified in place).
        target_point: whether this click selects a target point.
        points: dict with ``'target'`` and ``'handle'`` lists; mutated in place.
        size: key into ``SIZE_TO_CLICK_SIZE`` selecting the marker size.
        evt: selection event carrying the clicked index.

    Returns:
        Tuple ``(image, str(evt.index), not target_point)``; the last item
        flips the flag so the following click goes to the other list.
    """
    clicked = [evt.index[1], evt.index[0]]
    bucket = 'target' if target_point else 'handle'
    points[bucket].append(clicked)
    marked = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
    return marked, str(evt.index), not target_point
def on_drag(model, points, max_iters, state, size, mask):
    """Run ``drag_gan`` and stream every intermediate frame to the UI.

    This is a generator: each iteration produced by ``drag_gan`` updates
    ``state`` and ``points`` in place and yields the frame with the current
    markers painted on it.

    Args:
        model: object exposing the generator as ``model.g_ema``.
        points: dict with ``'handle'`` and ``'target'`` point lists;
            ``points['handle']`` is overwritten with the tracked positions on
            every step.
        max_iters: iteration limit; converted with ``int`` and forwarded to
            ``drag_gan``.
        state: dict with ``'latent'``, ``'noise'``, ``'F'``, ``'sample'`` and
            ``'history'`` entries; updated in place.
        size: key into ``SIZE_TO_CLICK_SIZE`` selecting the marker size.
        mask: dict whose ``'mask'`` entry is an image array; pixels that are
            exactly 255 after conversion to greyscale become 1.0, all others 0.0.

    Yields:
        Tuple ``(image, state, step)`` where ``step`` counts frames from 1.

    Raises:
        gr.Error: if no handle point is selected, or if the number of handle
            and target points differ.
    """
    handle_count = len(points['handle'])
    if handle_count == 0:
        raise gr.Error('You must select at least one handle point and target point.')
    if handle_count != len(points['target']):
        raise gr.Error('You have uncompleted handle points, try to selct a target point or undo the handle point.')

    iteration_cap = int(max_iters)
    latent, noise, F = state['latent'], state['noise'], state['F']

    handles = [torch.tensor(p).float() for p in points['handle']]
    targets = [torch.tensor(p).float() for p in points['target']]

    # Binarise the drawn mask, then add two leading singleton axes.
    grey = Image.fromarray(mask['mask']).convert('L')
    selected = np.array(grey) == 255
    region = torch.from_numpy(selected).float().to(device).unsqueeze(0).unsqueeze(0)

    frames = drag_gan(model.g_ema, latent, noise, F,
                      handles, targets, region,
                      max_iters=iteration_cap)
    for step, (sample2, latent, F, handles) in enumerate(frames, start=1):
        frame = to_image(sample2)

        # Persist the newest optimisation results so later callbacks can
        # continue from, or re-render, this frame.
        state['F'] = F
        state['latent'] = latent
        state['sample'] = sample2

        # Replace the handle points with their tracked positions before
        # drawing the markers.
        points['handle'] = [p.cpu().numpy().astype('int') for p in handles]
        add_points_to_image(frame, points, size=SIZE_TO_CLICK_SIZE[size])

        state['history'].append(frame)
        yield frame, state, step
def on_reset(points, image, state):
    """Discard all selected points and return an unmarked image.

    Args:
        points: current points dict (unused; a fresh dict is returned instead).
        image: currently displayed image (unused).
        state: dict whose ``'sample'`` entry is rendered with ``to_image``.

    Returns:
        Tuple ``(points, image)``: an empty points dict and the freshly
        rendered image.
    """
    cleared = {'target': [], 'handle': []}
    clean_image = to_image(state['sample'])
    return cleared, clean_image
def on_undo(points, image, state, size):
    """Remove the most recent selection and redraw the remaining markers.

    The last handle point is always dropped. The last target point is dropped
    too unless there are fewer targets than handles, i.e. unless the last
    handle had not been paired with a target yet.

    Args:
        points: dict with ``'target'`` and ``'handle'`` lists; updated in place.
        image: currently displayed image (unused; it is re-rendered from
            ``state['sample']``).
        state: dict whose ``'sample'`` entry is rendered with ``to_image``.
        size: key into ``SIZE_TO_CLICK_SIZE`` selecting the marker size.

    Returns:
        Tuple ``(points, image)`` with the updated points and redrawn image.
    """
    image = to_image(state['sample'])

    targets = points['target']
    handles = points['handle']
    # Decide before trimming: the comparison uses the pre-undo lengths.
    drop_target = len(targets) >= len(handles)

    points['handle'] = handles[:-1]
    if drop_target:
        points['target'] = targets[:-1]

    add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
    return points, image
def on_change_model(selected, model):
    """Load the selected checkpoint and generate a first image with it.

    Args:
        selected: checkpoint name; must be a key of ``CKPT_SIZE``.
        model: previous model wrapper (ignored; a new one is created).

    Returns:
        Tuple ``(model, state, image, size)``: the new ``ModelWrapper``, a
        fresh state dict with an empty history, the rendered sample and the
        ``CKPT_SIZE`` entry for the checkpoint.
    """
    size = CKPT_SIZE[selected]
    model = ModelWrapper(size=size, ckpt=selected)
    generator = model.g_ema

    # Draw a random latent code and render it.
    z = torch.randn([1, 512], device=device)
    latent, noise = generator.prepare([z])
    sample, F = generator.generate(latent, noise)

    state = dict(latent=latent, noise=noise, F=F, sample=sample, history=[])
    preview = to_image(sample)
    return model, state, preview, size
def on_new_image(model):
    """Generate a new random image with the current model and reset the UI state.

    Args:
        model: object exposing the generator as ``model.g_ema``.

    Returns:
        Tuple ``(image, mask_image, state, points, target_point)``: two
        separately rendered copies of the new sample, a fresh state dict with
        an empty history, an empty points dict and ``False``.
    """
    generator = model.g_ema

    # Draw a random latent code and render it.
    z = torch.randn([1, 512], device=device)
    latent, noise = generator.prepare([z])
    sample, F = generator.generate(latent, noise)

    fresh_state = dict(latent=latent, noise=noise, F=F, sample=sample, history=[])
    no_points = {'target': [], 'handle': []}

    # Render twice so the two outputs do not share one array object.
    first_view = to_image(sample)
    second_view = to_image(sample)
    return first_view, second_view, fresh_state, no_points, False
def on_max_iter_change(max_iters):
    """Return a component update that sets ``maximum`` to ``max_iters``.

    Args:
        max_iters: new upper bound.

    Returns:
        The result of ``gr.update(maximum=max_iters)``.
    """
    new_bound = gr.update(maximum=max_iters)
    return new_bound
def on_save_files(image, state):
    """Write the current image and the frame history to files under ``tmp/``.

    Args:
        image: image array saved as a PNG.
        state: dict whose ``'history'`` entry (a list of frames) is saved as
            an MP4.

    Returns:
        List ``[image_name, video_name]`` with the two file paths.
    """
    os.makedirs('tmp', exist_ok=True)

    # Each file gets its own random UUID in its name.
    image_name = f'tmp/image_{uuid.uuid4()}.png'
    video_name = f'tmp/video_{uuid.uuid4()}.mp4'

    frames = state['history']
    imageio.imsave(image_name, image)
    imageio.mimsave(video_name, frames)

    saved_paths = [image_name, video_name]
    return saved_paths
def on_show_save():
    """Return a component update that makes its target visible.

    Returns:
        The result of ``gr.update(visible=True)``.
    """
    reveal = gr.update(visible=True)
    return reveal
def main():
    """Build the DragGAN Gradio interface.

    A default ``ModelWrapper`` is created and an initial image is generated
    from it; the layout is then declared and the callbacks are wired to the
    components.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    torch.cuda.manual_seed(25)

    def _awaiting_target(pts):
        """Return True when the last handle point has no target point yet."""
        return len(pts['target']) < len(pts['handle'])

    with gr.Blocks() as demo:
        wrapped_model = ModelWrapper()
        model = gr.State(wrapped_model)
        sample_z = torch.randn([1, 512], device=device)
        latent, noise = wrapped_model.g_ema.prepare([sample_z])
        sample, F = wrapped_model.g_ema.generate(latent, noise)

        # The text below is kept flush-left so the string passed to
        # gr.Markdown is exactly the original one.
        gr.Markdown(
            """
# DragGAN (Unofficial)
Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
[Github](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)
## Tutorial
1. (Optional) Draw a mask indicate the movable region.
2. Setup a least one pair of handle point and target point.
3. Click "Drag it".
""",
        )

        # Per-session state shared between the callbacks.
        state = gr.State({
            'latent': latent,
            'noise': noise,
            'F': F,
            'sample': sample,
            'history': []
        })
        points = gr.State({'target': [], 'handle': []})
        size = gr.State(512)

        with gr.Row():
            # Left column: controls.
            with gr.Column(scale=0.3):
                with gr.Accordion("Model"):
                    model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value='stylegan2-ffhq-config-f.pt',
                                                 label='StyleGAN2 model')
                    max_iters = gr.Slider(1, 20, 20, step=1, label='Max Iterations')
                    new_btn = gr.Button('New Image')
                with gr.Accordion('Drag'):
                    with gr.Row():
                        with gr.Column(min_width=100):
                            text = gr.Textbox(label='Selected Point', interactive=False)
                        with gr.Column(min_width=100):
                            # Read-only indicator: True means the next click
                            # will be stored as a target point.
                            target_point = gr.Checkbox(label='Target Point', interactive=False)
                    with gr.Row():
                        with gr.Column(min_width=100):
                            reset_btn = gr.Button('Reset All')
                        with gr.Column(min_width=100):
                            undo_btn = gr.Button('Undo Last')
                    with gr.Row():
                        btn = gr.Button('Drag it', variant='primary')

                with gr.Accordion('Save', visible=False) as save_panel:
                    files = gr.Files(value=[])

                progress = gr.Slider(value=0, maximum=20, label='Progress', interactive=False)

            # Right column: the mask canvas and the point-selection image.
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab('Draw a Mask', id='mask'):
                        # NOTE(review): this uses gr.ImageMask, not the
                        # module-level ImageMask subclass -- confirm which one
                        # is intended.
                        mask = gr.ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
                    with gr.Tab('Setup Handle Points', id='input'):
                        image = gr.Image(to_image(sample)).style(height=768, width=768)

        image.select(on_click, [image, target_point, points, size], [image, text, target_point], queue=False)

        # NOTE: the save flow (on_show_save -> on_save_files feeding
        # `save_panel` / `files`) is intentionally not wired to this event.
        btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress])

        # The 'Target Point' checkbox is not user-editable, so it must be
        # re-synchronised whenever the point lists change outside on_click;
        # otherwise the click after a reset/undo can be filed under the wrong
        # list (e.g. a target point with no handle point).
        reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image], queue=False).then(
            lambda: False, outputs=target_point, queue=False)
        undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image], queue=False).then(
            _awaiting_target, inputs=points, outputs=target_point, queue=False)

        model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, size], queue=False)
        new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point], queue=False)
        max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress, queue=False)
    return demo
if __name__ == '__main__':
    # Imported here so `fire` is only required when run as a script.
    import fire

    demo = main()
    # Enable the request queue, then hand the bound `launch` method to Fire.
    fire.Fire(component=demo.queue(concurrency_count=1, max_size=20).launch)