import argparse
import gradio as gr
import os
import shutil
from glob import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image
from torchvision.io import read_image
import torchvision.transforms.functional as F
from functools import partial
from datetime import datetime

plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs):
    # Display one or more image tensors side by side, without axis ticks.
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = F.to_pil_image(img.detach())
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


class Intermediate:
    # Caches the last uploaded image and its timestamp so repeated clicks
    # reuse earlier alignment/inversion results instead of recomputing them.
    def __init__(self):
        self.input_img = None
        self.input_img_time = 0


model_ckpts = {
    "elf": "ffhq-elf.pkl",
    "greek_statue": "ffhq-greek_statue.pkl",
    "hobbit": "ffhq-hobbit.pkl",
    "lego": "ffhq-lego.pkl",
    "masquerade": "ffhq-masquerade.pkl",
    "neanderthal": "ffhq-neanderthal.pkl",
    "orc": "ffhq-orc.pkl",
    "pixar": "ffhq-pixar.pkl",
    "skeleton": "ffhq-skeleton.pkl",
    "stone_golem": "ffhq-stone_golem.pkl",
    "super_mario": "ffhq-super_mario.pkl",
    "tekken": "ffhq-tekken.pkl",
    "yoda": "ffhq-yoda.pkl",
    "zombie": "ffhq-zombie.pkl",
    "cat_in_Zootopia": "cat-cat_in_Zootopia.pkl",
    "fox_in_Zootopia": "cat-fox_in_Zootopia.pkl",
    "golden_aluminum_animal": "cat-golden_aluminum_animal.pkl",
}

# Subset of the FFHQ-based models offered in the manipulation tab.
manip_model_ckpts = {
    "super_mario": "ffhq-super_mario.pkl",
    "lego": "ffhq-lego.pkl",
    "neanderthal": "ffhq-neanderthal.pkl",
    "orc": "ffhq-orc.pkl",
    "pixar": "ffhq-pixar.pkl",
    "skeleton": "ffhq-skeleton.pkl",
    "stone_golem": "ffhq-stone_golem.pkl",
    "tekken": "ffhq-tekken.pkl",
    "greek_statue": "ffhq-greek_statue.pkl",
    "yoda": "ffhq-yoda.pkl",
    "zombie": "ffhq-zombie.pkl",
    "elf": "ffhq-elf.pkl",
}


def TextGuidedImageTo3D(intermediate, img, model_name, num_inversion_steps, truncation):
    # A new input image invalidates all cached results: reset the input
    # directory and timestamp the run so its outputs land in a fresh folder.
    if img != intermediate.input_img:
        if os.path.exists('input_imgs_gradio'):
            shutil.rmtree('input_imgs_gradio')
        os.makedirs('input_imgs_gradio', exist_ok=True)
        img.save('input_imgs_gradio/input.png')
        intermediate.input_img = img
        now = datetime.now()
        intermediate.input_img_time = now.strftime('%Y-%m-%d_%H:%M:%S')

    all_model_names = manip_model_ckpts.keys()
    generator_type = 'ffhq'

    if model_name == 'all':
        # Only models whose result video is missing need to be (re)run.
        _no_video_models = []
        for _model_name in all_model_names:
            if not os.path.exists(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/finetuned___{model_ckpts[_model_name]}__input_inv.mp4'):
                _no_video_models.append(_model_name)

        # Download any missing checkpoints and build the --network argument list.
        model_names_command = ''
        for _model_name in _no_video_models:
            if not os.path.exists(f'finetuned/{model_ckpts[_model_name]}'):
                command = f"wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{model_ckpts[_model_name]} -O finetuned/{model_ckpts[_model_name]}"
                os.system(command)
            model_names_command += f"finetuned/{model_ckpts[_model_name]} "

        # Reuse a previous GAN inversion of this image if one exists.
        w_pths = sorted(glob(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/3_inversion_result/*.pt'))
        if len(w_pths) == 0:
            mode = 'manip'
        else:
            mode = 'manip_from_inv'

        if len(_no_video_models) > 0:
            command = f"""python datid3d_test.py --mode {mode} \
--indir='input_imgs_gradio' \
--generator_type={generator_type} \
--outdir='test_runs' \
--trunc={truncation} \
--network {model_names_command} \
--num_inv_steps={num_inversion_steps} \
--down_src_eg3d_from_nvidia=False \
--name_tag='_gradio_{intermediate.input_img_time}' \
--shape=False \
--w_frames 60"""
            print(command)
            os.system(command)
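
        # After the run, collect the aligned input produced by the pose-estimation
        # stage and stack the per-model manipulation stills into a single grid.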
        aligned_img_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/2_pose_result/*.png'))[0]
        aligned_img = Image.open(aligned_img_pth)

        result_imgs = []
        for _model_name in all_model_names:
            img_pth = f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/finetuned___{model_ckpts[_model_name]}__input_inv.png'
            result_imgs.append(read_image(img_pth))
        result_grid_pt = make_grid(result_imgs, nrow=1)
        result_img = F.to_pil_image(result_grid_pt)
    else:
        if not os.path.exists(f'finetuned/{model_ckpts[model_name]}'):
            command = f"wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{model_ckpts[model_name]} -O finetuned/{model_ckpts[model_name]}"
            os.system(command)

        if not os.path.exists(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/finetuned___{model_ckpts[model_name]}__input_inv.mp4'):
            # Reuse a previous GAN inversion of this image if one exists.
            w_pths = sorted(glob(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/3_inversion_result/*.pt'))
            if len(w_pths) == 0:
                mode = 'manip'
            else:
                mode = 'manip_from_inv'

            command = f"""python datid3d_test.py --mode {mode} \
--indir='input_imgs_gradio' \
--generator_type={generator_type} \
--outdir='test_runs' \
--trunc={truncation} \
--network finetuned/{model_ckpts[model_name]} \
--num_inv_steps={num_inversion_steps} \
--down_src_eg3d_from_nvidia=False \
--name_tag='_gradio_{intermediate.input_img_time}' \
--shape=False \
--w_frames 60"""
            print(command)
            os.system(command)

        aligned_img_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/2_pose_result/*.png'))[0]
        aligned_img = Image.open(aligned_img_pth)
        result_img_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/*{model_ckpts[model_name]}*.png'))[0]
        result_img = Image.open(result_img_pth)

    if model_name == 'all':
        # Concatenate the twelve per-model videos into one result video.
        result_video_pth = f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/finetuned___ffhq-all__input_inv.mp4'
        if os.path.exists(result_video_pth):
            os.remove(result_video_pth)
        command = 'ffmpeg '
        for _model_name in all_model_names:
            command += f'-i test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/finetuned___{model_ckpts[_model_name]}__input_inv.mp4 '
        # command += '-filter_complex "[0:v]scale=2*iw:-1[v0];[1:v]scale=2*iw:-1[v1];[2:v]scale=2*iw:-1[v2];[3:v]scale=2*iw:-1[v3];[4:v]scale=2*iw:-1[v4];[5:v]scale=2*iw:-1[v5];[6:v]scale=2*iw:-1[v6];[7:v]scale=2*iw:-1[v7];[8:v]scale=2*iw:-1[v8];[9:v]scale=2*iw:-1[v9];[10:v]scale=2*iw:-1[v10];[11:v]scale=2*iw:-1[v11];[v0][v1][v2][v3][v4][v5][v6][v7][v8][v9][v10][v11]xstack=inputs=12:layout=0_0|w0_0|w0+w1_0|w0+w1+w2_0|0_h0|w4_h0|w4+w5_h0|w4+w5+w6_h0|0_h0+h4|w8_h0+h4|w8+w9_h0+h4|w8+w9+w10_h0+h4" '
        command += '-filter_complex "[0:v][1:v][2:v][3:v][4:v][5:v][6:v][7:v][8:v][9:v][10:v][11:v]concat=n=12:v=1:a=0[output]" -map "[output]"'
        command += f" -vcodec libx264 {result_video_pth}"
        print(command)
        os.system(command)
    else:
        result_video_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{intermediate.input_img_time}/4_manip_result/*{model_ckpts[model_name]}*.mp4'))[0]

    return aligned_img, result_img, result_video_pth


def SampleImage(model_name, num_samples, truncation, seed):
    # Derive a deterministic list of per-sample seeds from the UI seed.
    seed_list = np.random.RandomState(seed).choice(np.arange(10000), num_samples).tolist()
    seeds = ','.join(str(s) for s in seed_list)

    if model_name in ["fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"]:
        generator_type = 'cat'
    else:
        generator_type = 'ffhq'
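
    # The fine-tuned checkpoint is fetched from Hugging Face on first use and
    # cached under finetuned/, so subsequent calls skip the download.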
    if not os.path.exists(f'finetuned/{model_ckpts[model_name]}'):
        command = f"wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{model_ckpts[model_name]} -O finetuned/{model_ckpts[model_name]}"
        os.system(command)

    command = f"""python datid3d_test.py --mode image \
--generator_type={generator_type} \
--outdir='test_runs' \
--seeds={seeds} \
--trunc={truncation} \
--network=finetuned/{model_ckpts[model_name]} \
--shape=False"""
    print(command)
    os.system(command)

    # Stack the sampled images into a single-column grid for display.
    result_img_pths = sorted(glob(f'test_runs/image/*{model_ckpts[model_name]}*.png'))
    result_imgs = [read_image(img_pth) for img_pth in result_img_pths]
    result_grid_pt = make_grid(result_imgs, nrow=1)
    result_grid_pil = F.to_pil_image(result_grid_pt)
    return result_grid_pil


def SampleVideo(model_name, grid_height, truncation, seed):
    # The result video is a grid_height x grid_height grid, one seed per cell.
    seed_list = np.random.RandomState(seed).choice(np.arange(10000), grid_height ** 2).tolist()
    seeds = ','.join(str(s) for s in seed_list)

    if model_name in ["fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"]:
        generator_type = 'cat'
    else:
        generator_type = 'ffhq'

    if not os.path.exists(f'finetuned/{model_ckpts[model_name]}'):
        command = f"wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{model_ckpts[model_name]} -O finetuned/{model_ckpts[model_name]}"
        os.system(command)

    command = f"""python datid3d_test.py --mode video \
--generator_type={generator_type} \
--outdir='test_runs' \
--seeds={seeds} \
--trunc={truncation} \
--grid={grid_height}x{grid_height} \
--network=finetuned/{model_ckpts[model_name]} \
--shape=False"""
    print(command)
    os.system(command)

    result_video_pth = sorted(glob(f'test_runs/video/*{model_ckpts[model_name]}*.mp4'))[0]
    return result_video_pth


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="create a public Gradio URL")
    args = parser.parse_args()

    demo = gr.Blocks(title="DATID-3D Interactive Demo")
    os.makedirs('finetuned', exist_ok=True)
    intermediate = Intermediate()

    with demo:
        gr.Markdown("# DATID-3D Interactive Demo")
        gr.Markdown("### Demo of the CVPR 2023 paper \"DATID-3D: Diversity-Preserved Domain Adaptation Using Text-to-Image Diffusion for 3D Generative Model\"")

        with gr.Tab("Text-guided Manipulated 3D Reconstruction"):
            gr.Markdown("Text-guided Image-to-3D Translation")
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    t_image_input = gr.Image(source='upload', type="pil", interactive=True)
                    t_model_name = gr.Radio(
                        ["super_mario", "lego", "neanderthal", "orc", "pixar", "skeleton",
                         "stone_golem", "tekken", "greek_statue", "yoda", "zombie", "elf", "all"],
                        label="Model fine-tuned through DATID-3D", value="super_mario", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        t_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0,
                                                 step=0.01, randomize=False, value=0.8)
                        t_num_inversion_steps = gr.Slider(200, 1000, value=200, step=1,
                                                          label='Number of steps for the inversion')
                    with gr.Row():
                        t_button_gen_result = gr.Button("Generate Result", variant='primary')
                        # t_button_gen_video = gr.Button("Generate Video", variant='primary')
                        # t_button_gen_image = gr.Button("Generate Image", variant='secondary')
                    with gr.Row():
                        t_align_image_result = gr.Image(label="Alignment result", interactive=False)
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        t_video_result = gr.Video(label="Video result", interactive=False)
                    with gr.Row():
                        t_image_result = gr.Image(label="Image result", interactive=False)
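
        # The two sampling tabs below generate random samples from a chosen
        # fine-tuned generator; no input image or inversion is involved.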
        with gr.Tab("Sample Images"):
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    i_model_name = gr.Radio(
                        ["elf", "greek_statue", "hobbit", "lego", "masquerade", "neanderthal",
                         "orc", "pixar", "skeleton", "stone_golem", "super_mario", "tekken",
                         "yoda", "zombie", "fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"],
                        label="Model fine-tuned through DATID-3D", value="super_mario", interactive=True)
                    i_num_samples = gr.Slider(1, 20, value=4, step=1, label='Number of samples')
                    i_seed = gr.Slider(label="Seed", minimum=0, maximum=1000000000, step=1, value=1235)
                    with gr.Accordion("Advanced Options", open=False):
                        i_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0,
                                                 step=0.01, randomize=False, value=0.8)
                    with gr.Row():
                        i_button_gen_image = gr.Button("Generate Image", variant='primary')
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        i_image_result = gr.Image(label="Image result", interactive=False)

        with gr.Tab("Sample Videos"):
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    v_model_name = gr.Radio(
                        ["elf", "greek_statue", "hobbit", "lego", "masquerade", "neanderthal",
                         "orc", "pixar", "skeleton", "stone_golem", "super_mario", "tekken",
                         "yoda", "zombie", "fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"],
                        label="Model fine-tuned through DATID-3D", value="super_mario", interactive=True)
                    v_grid_height = gr.Slider(1, 5, value=2, step=1, label='Height of the grid')
                    v_seed = gr.Slider(label="Seed", minimum=0, maximum=1000000000, step=1, value=1235)
                    with gr.Accordion("Advanced Options", open=False):
                        v_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0,
                                                 step=0.01, randomize=False, value=0.8)
                    with gr.Row():
                        v_button_gen_video = gr.Button("Generate Video", variant='primary')
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        v_video_result = gr.Video(label="Video result", interactive=False)

        # Wire each button to its handler; the manipulation handler carries the
        # shared Intermediate cache via functools.partial.
        t_button_gen_result.click(fn=partial(TextGuidedImageTo3D, intermediate),
                                  inputs=[t_image_input, t_model_name, t_num_inversion_steps, t_truncation],
                                  outputs=[t_align_image_result, t_image_result, t_video_result])
        i_button_gen_image.click(fn=SampleImage,
                                 inputs=[i_model_name, i_num_samples, i_truncation, i_seed],
                                 outputs=[i_image_result])
        v_button_gen_video.click(fn=SampleVideo,
                                 inputs=[v_model_name, v_grid_height, v_truncation, v_seed],
                                 outputs=[v_video_result])

    demo.queue(concurrency_count=1)
    demo.launch(share=args.share)
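
# Usage (the script filename is an assumption; this file does not pin its own name):
#   python gradio_app.py           # serve the demo locally
#   python gradio_app.py --share   # also expose a temporary public Gradio URL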