import argparse
import os
import uuid

import gradio as gr
import imageio
import numpy as np
import torch
from PIL import Image

from drag_gan import drag_gan, stylegan2
from stylegan2.inversion import inverse_image

device = 'cuda'  # overridden by --device when run as a script

# Marker size (in pixels) for clicked points, keyed by model resolution.
SIZE_TO_CLICK_SIZE = {
    1024: 8,
    512: 5,
    256: 2
}

# Output resolution of each checkpoint.
CKPT_SIZE = {
    'stylegan2-ffhq-config-f.pt': 1024,
    'stylegan2-cat-config-f.pt': 256,
    'stylegan2-church-config-f.pt': 256,
    'stylegan2-horse-config-f.pt': 256,
    'ada/ffhq.pt': 1024,
    'ada/afhqcat.pt': 512,
    'ada/afhqdog.pt': 512,
    'ada/afhqwild.pt': 512,
    'ada/brecahad.pt': 512,
    'ada/metfaces.pt': 512,
}

DEFAULT_CKPT = 'stylegan2-ffhq-config-f.pt'


class grImage(gr.components.Image):
    is_template = True

    def preprocess(self, x):
        if x is None:
            return x
        if self.tool == "sketch" and self.source in ["upload", "webcam"]:
            # Synthesize a fully opaque mask when the client sends a bare image.
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            mask = np.zeros((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)


class ImageMask(gr.components.Image):
    """
    Sets: source="upload", tool="sketch"
    """
    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        if x is None:
            return x
        if self.tool == "sketch" and self.source in ["upload", "webcam"] and not isinstance(x, dict):
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            mask = np.zeros((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)


class ModelWrapper:
    def __init__(self, **kwargs):
        self.g_ema = stylegan2(**kwargs).to(device)


def to_image(tensor):
    """Convert a (1, C, H, W) tensor to a min-max normalized uint8 HWC image."""
    tensor = tensor.squeeze(0).permute(1, 2, 0)
    arr = tensor.detach().cpu().numpy()
    arr = (arr - arr.min()) / (arr.max() - arr.min())
    arr = arr * 255
    return arr.astype('uint8')


def add_points_to_image(image, points, size=5):
    h, w = image.shape[:2]

    # Points are stored as [row, col]; draw targets in red, handles in blue.
    for x, y in points['target']:
        image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [255, 0, 0]
    for x, y in points['handle']:
        image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [0, 0, 255]

    return image


def on_click(image, target_point, points, size, evt: gr.SelectData):
    # evt.index is (x, y); we store [row, col]. Clicks alternate between
    # handle and target points, toggled via the returned `not target_point`.
    if target_point:
        points['target'].append([evt.index[1], evt.index[0]])
        image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
        return image, str(evt.index), not target_point

    points['handle'].append([evt.index[1], evt.index[0]])
    image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
    return image, str(evt.index), not target_point


def on_drag(model, points, max_iters, state, size, mask):
    if len(points['handle']) == 0:
        raise gr.Error('You must select at least one pair of handle and target points.')
    if len(points['handle']) != len(points['target']):
        raise gr.Error('You have an unmatched handle point; select a target point for it or undo the handle point.')
    max_iters = int(max_iters)
    latent = state['latent']
    noise = state['noise']
    F = state['F']

    handle_points = [torch.tensor(p).float() for p in points['handle']]
    target_points = [torch.tensor(p).float() for p in points['target']]

    if mask is not None and mask.get('mask') is not None:
        mask = Image.fromarray(mask['mask']).convert('L')
        mask = np.array(mask) == 255
        mask = torch.from_numpy(mask).float().to(device)
        mask = mask.unsqueeze(0).unsqueeze(0)
    else:
        mask = None

    step = 0
    for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,
                                                      handle_points, target_points, mask,
                                                      max_iters=max_iters):
        image = to_image(sample2)

        state['F'] = F
        state['latent'] = latent
        state['sample'] = sample2
        points['handle'] = [p.cpu().numpy().astype('int') for p in handle_points]
        add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])

        state['history'].append(image)
        step += 1
        yield image, state, step


def on_reset(points, image, state):
    return {'target': [], 'handle': []}, to_image(state['sample'])


def on_undo(points, image, state, size):
    image = to_image(state['sample'])

    if len(points['target']) < len(points['handle']):
        # The last handle point has no matching target yet; drop only it.
        points['handle'] = points['handle'][:-1]
    else:
        # Drop the last completed handle/target pair.
        points['handle'] = points['handle'][:-1]
        points['target'] = points['target'][:-1]

    add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
    return points, image


def on_change_model(selected, model):
    size = CKPT_SIZE[selected]
    model = ModelWrapper(size=size, ckpt=selected)
    g_ema = model.g_ema
    sample_z = torch.randn([1, 512], device=device)
    latent, noise = g_ema.prepare([sample_z])
    sample, F = g_ema.generate(latent, noise)

    state = {
        'latent': latent,
        'noise': noise,
        'F': F,
        'sample': sample,
        'history': []
    }
    return model, state, to_image(sample), to_image(sample), size


def on_new_image(model):
    g_ema = model.g_ema
    sample_z = torch.randn([1, 512], device=device)
    latent, noise = g_ema.prepare([sample_z])
    sample, F = g_ema.generate(latent, noise)

    state = {
        'latent': latent,
        'noise': noise,
        'F': F,
        'sample': sample,
        'history': []
    }
    points = {'target': [], 'handle': []}
    target_point = False
    return to_image(sample), to_image(sample), state, points, target_point


def on_max_iter_change(max_iters):
    return gr.update(maximum=max_iters)


def on_save_files(image, state):
    os.makedirs('tmp', exist_ok=True)
    image_name = f'tmp/image_{uuid.uuid4()}.png'
    video_name = f'tmp/video_{uuid.uuid4()}.mp4'
    imageio.imsave(image_name, image)
    imageio.mimsave(video_name, state['history'])  # mp4 output requires the imageio ffmpeg plugin
    return [image_name, video_name]


def on_show_save():
    return gr.update(visible=True)


def on_image_change(model, image_size, image):
    image = Image.fromarray(image)
    result = inverse_image(
        model.g_ema,
        image,
        image_size=image_size
    )
    result['history'] = []
    image = to_image(result['sample'])
    points = {'target': [], 'handle': []}
    target_point = False
    return image, image, result, points, target_point


def on_mask_change(mask):
    return mask['image']


def main():
    torch.cuda.manual_seed(25)

    with gr.Blocks() as demo:
        wrapped_model = ModelWrapper(ckpt=DEFAULT_CKPT, size=CKPT_SIZE[DEFAULT_CKPT])
        model = gr.State(wrapped_model)
        sample_z = torch.randn([1, 512], device=device)
        latent, noise = wrapped_model.g_ema.prepare([sample_z])
        sample, F = wrapped_model.g_ema.generate(latent, noise)

        gr.Markdown(
            """
            # DragGAN

            Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)

            [Our Implementation](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)

            ## Tutorial

            1. (Optional) Draw a mask indicating the movable region.
            2. Set up at least one pair of handle and target points.
            3. Click "Drag it".

            ## Hints

            - Handle points (Blue): the points you want to drag.
            - Target points (Red): the destinations you want to drag towards.

            ## Preliminary Support for Custom Images

            - We now support dragging a user-uploaded image via GAN inversion.
            - **Please upload your image in the `Setup Handle Points` panel.** Uploading it from `Draw a Mask` would cause errors for now.
            - Due to the limitations of GAN inversion:
                - You might have to wait roughly 1 minute to see the GAN version of the uploaded image.
                - The shown image might be slightly different from the uploaded one.
                - Inversion could also fail outright and produce very poor results.
                - Ideally, choose the model closest to the uploaded image, e.g. `stylegan2-ffhq-config-f.pt` for human faces and `stylegan2-cat-config-f.pt` for cats.

            > Please file an issue if you encounter any problem. Also, don't forget to star the [Official Repo](https://github.com/XingangPan/DragGAN); [our project](https://github.com/Zeqiang-Lai/DragGAN) could not exist without it.
            """,
        )
        state = gr.State({
            'latent': latent,
            'noise': noise,
            'F': F,
            'sample': sample,
            'history': []
        })
        points = gr.State({'target': [], 'handle': []})
        size = gr.State(CKPT_SIZE[DEFAULT_CKPT])

        with gr.Row():
            with gr.Column(scale=0.3):
                with gr.Accordion("Model"):
                    model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value=DEFAULT_CKPT,
                                                 label='StyleGAN2 model')
                    max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
                    new_btn = gr.Button('New Image')
                with gr.Accordion('Drag'):
                    with gr.Row():
                        with gr.Column(min_width=100):
                            text = gr.Textbox(label='Selected Point', interactive=False)
                        with gr.Column(min_width=100):
                            target_point = gr.Checkbox(label='Target Point', interactive=False)
                    with gr.Row():
                        with gr.Column(min_width=100):
                            reset_btn = gr.Button('Reset All')
                        with gr.Column(min_width=100):
                            undo_btn = gr.Button('Undo Last')
                    with gr.Row():
                        btn = gr.Button('Drag it', variant='primary')

                with gr.Accordion('Save', visible=False) as save_panel:
                    files = gr.Files(value=[])

                progress = gr.Slider(value=0, maximum=20, label='Progress', interactive=False)

            with gr.Column():
                with gr.Tabs():
                    with gr.Tab('Draw a Mask', id='mask'):
                        mask = ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
                    with gr.Tab('Setup Handle Points', id='input'):
                        image = grImage(to_image(sample)).style(height=768, width=768)

        image.select(on_click, [image, target_point, points, size], [image, text, target_point])
        image.upload(on_image_change, [model, size, image], [image, mask, state, points, target_point])
        mask.upload(on_mask_change, [mask], [image])
        btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask],
                  outputs=[image, state, progress]).then(
            on_show_save, outputs=save_panel).then(
            on_save_files, inputs=[image, state], outputs=[files]
        )
        reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
        undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
        model_dropdown.change(on_change_model, inputs=[model_dropdown, model],
                              outputs=[model, state, image, mask, size])
        new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
        max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
    return demo


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cpu')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('-p', '--port', type=int, default=None)
    parser.add_argument('--ip', default=None)
    args = parser.parse_args()
    device = args.device
    demo = main()
    print('Successfully loaded, starting gradio demo')
    demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)
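
# Example invocation, sketched from the argparse flags defined above (the
# script filename is an assumption; adjust it to match your checkout):
#
#   python gradio_app.py --device cuda --share -p 7860 --ip 0.0.0.0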