# DATID-3D / datid3d_train.py
# Web-page header text captured together with this file (not code):
#   gwang-kim's picture / u / f12ab4c
import argparse
import os
import shlex
import subprocess
import sys
### Parameters
def _str2bool(value):
    """Convert a command-line token into a real bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy. This converter accepts the usual
    spellings instead and rejects everything else with a clear error.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False: with the previous ``type=bool`` it was
    # the only token that produced False, so existing invocations keep working.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}")


parser = argparse.ArgumentParser()

# For all
parser.add_argument('--mode', type=str, required=True, choices=['pdg', 'ft', 'both'],
                    help="pdg: Pose-aware dataset generation, ft: Fine-tuning 3D generative models, both: Doing both")
# Registered exactly once: adding the same option string a second time makes
# argparse raise a "conflicting option string" error before parsing starts.
parser.add_argument('--down_src_eg3d_from_nvidia', default=True, type=_str2bool,
                    help="true: download the source EG3D checkpoint from NVIDIA NGC, "
                         "false: download it from the Hugging Face mirror")

# Pose-aware dataset generation
parser.add_argument('--pdg_prompt', type=str, required=True)
parser.add_argument('--pdg_generator_type', default='ffhq', type=str, choices=['ffhq', 'cat'])  # ffhq, cat
parser.add_argument('--pdg_strength', default=0.7, type=float)
parser.add_argument('--pdg_guidance_scale', default=8, type=float)
parser.add_argument('--pdg_num_images', default=1000, type=int)
parser.add_argument('--pdg_sd_model_id', default='stabilityai/stable-diffusion-2-1-base', type=str)
parser.add_argument('--pdg_num_inference_steps', default=50, type=int)
parser.add_argument('--pdg_name_tag', default='', type=str)

# Fine-tuning 3D generative models
parser.add_argument('--ft_generator_type', default='same', type=str, choices=['ffhq', 'cat', 'same'],
                    help="same: The same type as pdg_generator_type")
parser.add_argument('--ft_kimg', default=200, type=int)
parser.add_argument('--ft_batch', default=20, type=int)
parser.add_argument('--ft_tick', default=1, type=int)
parser.add_argument('--ft_snap', default=50, type=int)
parser.add_argument('--ft_outdir', default='../training_runs', type=str)
# A GPU count: parsed as int so a non-numeric value is rejected here rather
# than being forwarded verbatim to the training command.
parser.add_argument('--ft_gpus', default=1, type=int)
parser.add_argument('--ft_workers', default=8, type=int)
parser.add_argument('--ft_data_max_size', default=500000000, type=int)
parser.add_argument('--ft_freeze_dec_sr', default=True, type=_str2bool)
args = parser.parse_args()

### Pose-aware target generation
# Runs only for --mode pdg/both. Works from inside the `eg3d/` directory:
# makes sure the source EG3D checkpoint is present under `pretrained/`
# (downloading it with wget if not), then launches `datid3d_data_gen.py`
# with the --pdg_* options. Any failing step aborts the script with a
# non-zero exit status instead of being silently ignored.
if args.mode in ['pdg', 'both']:
    os.chdir('eg3d')
    try:
        if args.pdg_generator_type == 'cat':
            pdg_generator_id = 'afhqcats512-128.pkl'
        else:
            pdg_generator_id = 'ffhqrebalanced512-128.pkl'
        pdg_generator_path = f'pretrained/{pdg_generator_id}'

        if not os.path.exists(pdg_generator_path):
            os.makedirs('pretrained', exist_ok=True)
            print("Pretrained EG3D model cannot be found. Downloading the pretrained EG3D models.")

            # Download to a side file and rename only on success. wget creates
            # its -O target before any data arrives, so writing straight to the
            # final path would make an interrupted download look like a valid
            # checkpoint to the os.path.exists() check on the next run.
            pdg_partial_path = f'{pdg_generator_path}.part'
            # `is True` (not plain truthiness): if the option arrives as an
            # unconverted string it must still select the mirror, as before.
            if args.down_src_eg3d_from_nvidia is True:
                pdg_download = [
                    'wget', '-c',
                    f'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/{pdg_generator_id}',
                    '-O', pdg_partial_path,
                ]
            else:
                pdg_download = [
                    'wget',
                    f'https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/nvidia_{pdg_generator_id}',
                    '-O', pdg_partial_path,
                ]
            try:
                subprocess.run(pdg_download, check=True)
            except (OSError, subprocess.CalledProcessError) as download_error:
                # OSError covers a missing wget binary; CalledProcessError a
                # non-zero wget exit status.
                sys.exit(f"Failed to download {pdg_generator_id}: {download_error}")
            os.replace(pdg_partial_path, pdg_generator_path)

        # Argument list instead of a shell string: the prompt is free text and
        # may contain quotes, `$` or backticks that a shell would interpret.
        # sys.executable runs the child with the interpreter running this script.
        command = [
            sys.executable, 'datid3d_data_gen.py',
            f'--prompt={args.pdg_prompt}',
            f'--data_type={args.pdg_generator_type}',
            f'--strength={args.pdg_strength}',
            f'--guidance_scale={args.pdg_guidance_scale}',
            f'--num_images={args.pdg_num_images}',
            f'--sd_model_id={args.pdg_sd_model_id}',
            f'--num_inference_steps={args.pdg_num_inference_steps}',
            f'--name_tag={args.pdg_name_tag}',
        ]
        # Echo a shell-quoted form so the line can be copy-pasted to reproduce.
        print(' '.join(shlex.quote(part) for part in command) + " \n")

        pdg_status = subprocess.run(command).returncode
        if pdg_status != 0:
            # Stop here so that --mode both does not go on to fine-tune on a
            # dataset that was never produced.
            sys.exit(f"datid3d_data_gen.py failed with exit status {pdg_status}.")
    finally:
        # Always return to the starting directory, even when aborting.
        os.chdir('..')
### Filtering process
# TODO
### Fine-tuning 3D generative models
# Runs only for --mode ft/both. Works from inside the `eg3d/` directory:
# makes sure the EG3D checkpoint to resume from is present under `pretrained/`
# (downloading it with wget if not), then launches `train.py` with the --ft_*
# options. Any failing step aborts the script with a non-zero exit status
# instead of being silently ignored.
if args.mode in ['ft', 'both']:
    os.chdir('eg3d')
    try:
        # 'same' means: reuse the generator type chosen for dataset generation.
        if args.ft_generator_type == 'same':
            args.ft_generator_type = args.pdg_generator_type

        if args.ft_generator_type == 'cat':
            ft_generator_id = 'afhqcats512-128.pkl'
        else:
            ft_generator_id = 'ffhqrebalanced512-128.pkl'
        ft_generator_path = f'pretrained/{ft_generator_id}'

        if not os.path.exists(ft_generator_path):
            os.makedirs('pretrained', exist_ok=True)
            print("Pretrained EG3D model cannot be found. Downloading the pretrained EG3D models.")

            # Download to a side file and rename only on success. wget creates
            # its -O target before any data arrives, so writing straight to the
            # final path would make an interrupted download look like a valid
            # checkpoint to the os.path.exists() check on the next run.
            ft_partial_path = f'{ft_generator_path}.part'
            # `is True` (not plain truthiness): if the option arrives as an
            # unconverted string it must still select the mirror, as before.
            if args.down_src_eg3d_from_nvidia is True:
                ft_download = [
                    'wget', '-c',
                    f'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/{ft_generator_id}',
                    '-O', ft_partial_path,
                ]
            else:
                ft_download = [
                    'wget',
                    f'https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/nvidia_{ft_generator_id}',
                    '-O', ft_partial_path,
                ]
            try:
                subprocess.run(ft_download, check=True)
            except (OSError, subprocess.CalledProcessError) as download_error:
                # OSError covers a missing wget binary; CalledProcessError a
                # non-zero wget exit status.
                sys.exit(f"Failed to download {ft_generator_id}: {download_error}")
            os.replace(ft_partial_path, ft_generator_path)

        # Dataset archive location derived from the generation options
        # (spaces in the prompt become underscores).
        # NOTE(review): presumably this mirrors the naming used by
        # datid3d_data_gen.py -- confirm against that script.
        dataset_id = f'data_{args.pdg_generator_type}_{args.pdg_prompt.replace(" ", "_")}{args.pdg_name_tag}'
        dataset_path = f'./exp_data/{dataset_id}/{dataset_id}.zip'

        # Argument list instead of a shell string: the dataset path embeds the
        # free-text prompt, which may contain characters a shell would interpret.
        # sys.executable runs the child with the interpreter running this script.
        command = [
            sys.executable, 'train.py',
            f'--outdir={args.ft_outdir}',
            f'--cfg={args.ft_generator_type}',
            f'--data={dataset_path}',
            f'--resume={ft_generator_path}',
            f'--freeze_dec_sr={args.ft_freeze_dec_sr}',
            f'--batch={args.ft_batch}',
            f'--workers={args.ft_workers}',
            f'--gpus={args.ft_gpus}',
            f'--tick={args.ft_tick}',
            f'--snap={args.ft_snap}',
            f'--data_max_size={args.ft_data_max_size}',
            f'--kimg={args.ft_kimg}',
            '--gamma=5',
            '--aug=ada',
            '--neural_rendering_resolution_final=128',
            '--gen_pose_cond=True',
            '--gpc_reg_prob=0.8',
            '--metrics=None',
        ]
        # Echo a shell-quoted form so the line can be copy-pasted to reproduce.
        print(' '.join(shlex.quote(part) for part in command) + " \n")

        ft_status = subprocess.run(command).returncode
        if ft_status != 0:
            sys.exit(f"train.py failed with exit status {ft_status}.")
    finally:
        # Always return to the starting directory, even when aborting.
        os.chdir('..')