File size: 4,950 Bytes
f12ab4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import os
import shlex
import subprocess

### Parameters
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty token (including "False") becomes True. This converter accepts
    the usual spellings instead.

    Args:
        value: The raw command-line string, or an existing bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy: under ``type=bool`` it was the only token
    # that produced False, so existing invocations keep working.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# For all
parser.add_argument('--mode', type=str, required=True, choices=['pdg', 'ft', 'both'],
                    help="pdg: Pose-aware dataset generation, ft: Fine-tuning 3D generative models, both: Doing both")
# Registered exactly once: adding the same option string a second time makes
# argparse raise "conflicting option string" before any argument is parsed.
# Parsed as a boolean so that an explicit "True"/"False" on the command line
# compares correctly against True where the download source is chosen.
parser.add_argument('--down_src_eg3d_from_nvidia', default=True, type=_str2bool,
                    help="True: download the pretrained EG3D model from NVIDIA, False: download it from Hugging Face")
# Pose-aware dataset generation
parser.add_argument('--pdg_prompt', type=str, required=True)
parser.add_argument('--pdg_generator_type', default='ffhq', type=str, choices=['ffhq', 'cat'])  # ffhq, cat
parser.add_argument('--pdg_strength', default=0.7, type=float)
parser.add_argument('--pdg_guidance_scale', default=8, type=float)
parser.add_argument('--pdg_num_images', default=1000, type=int)
parser.add_argument('--pdg_sd_model_id', default='stabilityai/stable-diffusion-2-1-base', type=str)
parser.add_argument('--pdg_num_inference_steps', default=50, type=int)
parser.add_argument('--pdg_name_tag', default='', type=str)
# Fine-tuning 3D generative models
parser.add_argument('--ft_generator_type', default='same', help="same: The same type as pdg_generator_type", type=str, choices=['ffhq', 'cat', 'same'])
parser.add_argument('--ft_kimg', default=200, type=int)
parser.add_argument('--ft_batch', default=20, type=int)
parser.add_argument('--ft_tick', default=1, type=int)
parser.add_argument('--ft_snap', default=50, type=int)
parser.add_argument('--ft_outdir', default='../training_runs', type=str)
# Declared as a string, so the default is a string too; the value is only
# ever formatted into the training command line.
parser.add_argument('--ft_gpus', default='1', type=str)
parser.add_argument('--ft_workers', default=8, type=int)
parser.add_argument('--ft_data_max_size', default=500000000, type=int)
parser.add_argument('--ft_freeze_dec_sr', default=True, type=_str2bool)

args = parser.parse_args()


### Helpers shared by the pipeline stages
def _generator_checkpoint(generator_type):
    """Return the pretrained EG3D checkpoint file name for *generator_type*.

    'cat' selects 'afhqcats512-128.pkl'; every other value selects
    'ffhqrebalanced512-128.pkl'.
    """
    if generator_type == 'cat':
        return 'afhqcats512-128.pkl'
    return 'ffhqrebalanced512-128.pkl'


def _run_command(argv):
    """Run *argv* as a child process without a shell and return its exit status.

    Passing an argument list instead of a shell string keeps user-supplied
    values (such as the prompt) as single arguments, whatever quotes or shell
    metacharacters they contain.

    Returns:
        The child's exit status, or 127 when the executable cannot be found.
    """
    try:
        return subprocess.run(argv).returncode
    except FileNotFoundError:
        print(f"Command not found: {argv[0]}")
        return 127


def _run_stage_command(argv):
    """Echo *argv* as a shell-quoted line, run it, and abort the script on failure.

    Raises:
        SystemExit: If the command exits with a non-zero status, so that a
            later stage never runs on top of a failed one.
    """
    print(' '.join(shlex.quote(part) for part in argv) + ' \n')
    status = _run_command(argv)
    if status != 0:
        raise SystemExit(f"Command failed with exit status {status}: {argv[0]} {argv[1]}")


def _ensure_pretrained_generator(checkpoint_id, from_nvidia):
    """Make sure 'pretrained/<checkpoint_id>' exists, downloading it if needed.

    The path is relative to the current working directory, so this must be
    called after changing into the 'eg3d' directory.

    Args:
        checkpoint_id: Checkpoint file name, e.g. 'ffhqrebalanced512-128.pkl'.
        from_nvidia: True to download from the NVIDIA NGC URL, False to use
            the Hugging Face URL.

    Returns:
        The relative path of the checkpoint.

    Raises:
        SystemExit: If the download command fails.
    """
    checkpoint_path = f'pretrained/{checkpoint_id}'
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    os.makedirs('pretrained', exist_ok=True)
    print("Pretrained EG3D model cannot be found. Downloading the pretrained EG3D models.")
    if from_nvidia:
        download = ['wget', '-c',
                    f'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/{checkpoint_id}',
                    '-O', checkpoint_path]
    else:
        download = ['wget',
                    f'https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/nvidia_{checkpoint_id}',
                    '-O', checkpoint_path]

    status = _run_command(download)
    if status != 0:
        # `wget -O` can leave an empty or partial output file behind when it
        # fails. Remove it, otherwise the existence check above would skip
        # the download on the next run.
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
        raise SystemExit(f"Downloading {checkpoint_id} failed with exit status {status}.")
    return checkpoint_path


### Pose-aware target generation
if args.mode in ['pdg', 'both']:
    os.chdir('eg3d')
    try:
        _ensure_pretrained_generator(
            _generator_checkpoint(args.pdg_generator_type),
            args.down_src_eg3d_from_nvidia is True,
        )
        _run_stage_command([
            'python', 'datid3d_data_gen.py',
            f'--prompt={args.pdg_prompt}',
            f'--data_type={args.pdg_generator_type}',
            f'--strength={args.pdg_strength}',
            f'--guidance_scale={args.pdg_guidance_scale}',
            f'--num_images={args.pdg_num_images}',
            f'--sd_model_id={args.pdg_sd_model_id}',
            f'--num_inference_steps={args.pdg_num_inference_steps}',
            f'--name_tag={args.pdg_name_tag}',
        ])
    finally:
        # Return to the starting directory even when the stage aborts.
        os.chdir('..')

### Filtering process
# TODO


### Fine-tuning 3D generative models
if args.mode in ['ft', 'both']:
    os.chdir('eg3d')
    try:
        # 'same' means: fine-tune the generator type used for data generation.
        if args.ft_generator_type == 'same':
            args.ft_generator_type = args.pdg_generator_type

        ft_generator_path = _ensure_pretrained_generator(
            _generator_checkpoint(args.ft_generator_type),
            args.down_src_eg3d_from_nvidia is True,
        )

        # The dataset location is derived from the data-generation settings
        # (generator type, prompt with spaces replaced by underscores, name tag).
        dataset_id = f'data_{args.pdg_generator_type}_{args.pdg_prompt.replace(" ", "_")}{args.pdg_name_tag}'
        dataset_path = f'./exp_data/{dataset_id}/{dataset_id}.zip'

        _run_stage_command([
            'python', 'train.py',
            f'--outdir={args.ft_outdir}',
            f'--cfg={args.ft_generator_type}',
            f'--data={dataset_path}',
            f'--resume={ft_generator_path}',
            f'--freeze_dec_sr={args.ft_freeze_dec_sr}',
            f'--batch={args.ft_batch}',
            f'--workers={args.ft_workers}',
            f'--gpus={args.ft_gpus}',
            f'--tick={args.ft_tick}',
            f'--snap={args.ft_snap}',
            f'--data_max_size={args.ft_data_max_size}',
            f'--kimg={args.ft_kimg}',
            '--gamma=5',
            '--aug=ada',
            '--neural_rendering_resolution_final=128',
            '--gen_pose_cond=True',
            '--gpc_reg_prob=0.8',
            '--metrics=None',
        ])
    finally:
        # Return to the starting directory even when the stage aborts.
        os.chdir('..')