import sys
import os

sys.path.append(os.getcwd())

from os.path import join as opj
import zipfile
import json
import pickle
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from torch import autocast
from torchvision.transforms import ToPILImage
from tqdm import tqdm
from diffusers import StableDiffusionImg2ImgPipeline, PNDMScheduler

from camera_utils import LookAtPoseSampler, FOV_to_intrinsics


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Pose-aware dataset generation')
    parser.add_argument('--strength', default=0.7, type=float)
    parser.add_argument('--prompt', type=str)
    parser.add_argument('--data_type', default='ffhq', type=str)  # ffhq, cat
    parser.add_argument('--guidance_scale', default=8, type=float)
    parser.add_argument('--num_images', default=1000, type=int)
    parser.add_argument('--sd_model_id', default='stabilityai/stable-diffusion-2-1-base', type=str)
    parser.add_argument('--num_inference_steps', default=30, type=int)
    parser.add_argument('--ffhq_eg3d_path', default='pretrained/ffhqrebalanced512-128.pkl', type=str)
    parser.add_argument('--cat_eg3d_path', default='pretrained/afhqcats512-128.pkl', type=str)
    parser.add_argument('--ffhq_pivot', default=0.2, type=float)
    parser.add_argument('--cat_pivot', default=0.05, type=float)
    parser.add_argument('--pitch_range', default=0.3, type=float)
    parser.add_argument('--yaw_range', default=0.3, type=float)
    parser.add_argument('--name_tag', default='', type=str)
    parser.add_argument('--seed', default=15, type=int)
    return parser.parse_args()


def make_zip(base_dir, prompt, data_type='ffhq', name_tag=''):
    """Bundle dataset.json and all target images into an uncompressed ZIP."""
    base_dir = os.path.abspath(base_dir)
    owd = os.path.abspath(os.getcwd())
    os.chdir(base_dir)  # archive members are stored with paths relative to base_dir

    json_path = opj(base_dir, "dataset.json")
    zip_path = opj(base_dir, f'data_{data_type}_{prompt.replace(" ", "_")}{name_tag}.zip')

    with open(json_path, 'r') as file:
        data = json.load(file)

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        zip_file.write(os.path.relpath(json_path, base_dir), compress_type=zipfile.ZIP_STORED)
        for label in data['labels']:
            trg_img_path = label[0]
            zip_file.write(trg_img_path, compress_type=zipfile.ZIP_STORED)

    os.chdir(owd)


def pts2pil(pts):
    """Map a tensor in [-1, 1] to a PIL image."""
    pts = (pts + 1) / 2
    pts = pts.clamp(0, 1)
    return ToPILImage()(pts[0])


if __name__ == '__main__':
    args = parse_args()
    device = "cuda"
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    data_type = args.data_type
    prompt = args.prompt
    strength = args.strength
    guidance_scale = args.guidance_scale
    num_inference_steps = args.num_inference_steps
    num_images = args.num_images
    name_tag = args.name_tag

    # 3D generator (EG3D) options
    ffhq_eg3d_path = args.ffhq_eg3d_path
    cat_eg3d_path = args.cat_eg3d_path
    cat_pivot = args.cat_pivot
    ffhq_pivot = args.ffhq_pivot
    pitch_range = args.pitch_range
    yaw_range = args.yaw_range
    num_frames = 240
    truncation_psi = 0.7
    truncation_cutoff = 14
    fov_deg = 18.837
    ft_img_size = 512

    # Load the 3D generator
    if data_type == 'ffhq':
        eg3d_path = ffhq_eg3d_path
        pivot = ffhq_pivot
    elif data_type == 'cat':
        eg3d_path = cat_eg3d_path
        pivot = cat_pivot
    else:
        raise ValueError(f'Unknown data_type: {data_type}')

    with open(eg3d_path, 'rb') as f:
        G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module
    G.train()
    for param in G.parameters():
        param.requires_grad_(True)

    # SD options
    model_id = args.sd_model_id
    negative_prompt = None
    eta = 0.0
    batch_size = 1
    model_inversion = False

    # Load the SD img2img pipeline
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
        scheduler=PNDMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            set_alpha_to_one=False,
            steps_offset=1,
            skip_prk_steps=True,
        ),
    ).to(device)
    pipe.safety_checker = None  # disable the NSFW filter for raw generator renders
    print('SD model is loaded')

    # Output directories
    base_dir = opj(f'./exp_data/data_{data_type}_{prompt.replace(" ", "_")}{name_tag}')
    src_img_dir = opj(base_dir, "src_imgs")
    trg_img_dir = opj(base_dir, "trg_imgs")
    os.makedirs('exp_data', exist_ok=True)
    os.makedirs(base_dir, exist_ok=True)
    os.makedirs(src_img_dir, exist_ok=True)
    os.makedirs(trg_img_dir, exist_ok=True)

    labels = []

    # Generate the pose-annotated dataset: render a random view with the 3D
    # generator, then stylize it with SD img2img under the given prompt.
    for i in tqdm(range(num_images)):
        G.eval()
        z = torch.from_numpy(np.random.randn(batch_size, G.z_dim)).to(device)
        intrinsics = FOV_to_intrinsics(fov_deg, device=device)

        with torch.no_grad():
            # Sample a camera pose on a smooth orbit around the pivot.
            yaw_idx = np.random.randint(num_frames)
            pitch_idx = np.random.randint(num_frames)
            cam_pivot = torch.tensor([0, 0, pivot], device=device)
            cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)
            cam2world_pose = LookAtPoseSampler.sample(
                np.pi / 2 + yaw_range * np.sin(2 * np.pi * yaw_idx / num_frames),
                np.pi / 2 - 0.05 + pitch_range * np.cos(2 * np.pi * pitch_idx / num_frames),
                cam_pivot, radius=cam_radius, device=device, batch_size=batch_size)
            conditioning_cam2world_pose = LookAtPoseSampler.sample(
                np.pi / 2, np.pi / 2, cam_pivot,
                radius=cam_radius, device=device, batch_size=batch_size)
            camera_params = torch.cat(
                [cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9).repeat(batch_size, 1)], 1)
            conditioning_params = torch.cat(
                [conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9).repeat(batch_size, 1)], 1)

            # Map the latent under the frontal conditioning pose, then render the sampled view.
            ws = G.mapping(z, conditioning_params,
                           truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
            img_pts = G.synthesis(ws, camera_params)['image']

        src_img_pts = img_pts.detach()
        src_img_pts = F.interpolate(src_img_pts, (ft_img_size, ft_img_size),
                                    mode='bilinear', align_corners=False)

        # Stylize the rendered view with SD img2img.
        with autocast("cuda"):
            trg_img_pil = pipe(prompt=prompt,
                               image=src_img_pts,
                               strength=strength,
                               guidance_scale=guidance_scale,
                               num_inference_steps=num_inference_steps,
                               )['images'][0]

        src_idx = f'{i:05d}_src.png'
        trg_idx = f'{i:05d}_trg.png'
        src_img_pil_path = opj(src_img_dir, src_idx)
        trg_img_pil_path = opj(trg_img_dir, trg_idx)

        src_img_pil = pts2pil(src_img_pts.cpu())
        src_img_pil.save(src_img_pil_path)
        trg_img_pil.save(trg_img_pil_path)

        # Store the target path relative to base_dir together with its camera label
        # (16 extrinsic + 9 intrinsic values).
        label = [trg_img_pil_path.replace(base_dir, '').replace('/trg_', 'trg_'),
                 camera_params[0].tolist()]
        labels.append(label)

    json_path = opj(base_dir, "dataset.json")
    json_data = {'labels': labels}
    with open(json_path, 'w') as outfile:
        json.dump(json_data, outfile, indent=4)

    make_zip(base_dir, prompt, data_type, name_tag)
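
# ---------------------------------------------------------------------------
# Example invocation -- a sketch only. The script filename and the prompt below
# are assumptions (neither is given in the source); all flags map directly to
# parse_args above, and the EG3D checkpoints must exist under pretrained/.
#
#   python generate_pose_aware_dataset.py \
#       --prompt "a portrait in Pixar style" \
#       --data_type ffhq \
#       --strength 0.7 \
#       --guidance_scale 8 \
#       --num_images 1000 \
#       --name_tag _v1
#
# The run writes, under ./exp_data/data_ffhq_a_portrait_in_Pixar_style_v1/:
#   src_imgs/     raw EG3D renders
#   trg_imgs/     SD-stylized targets
#   dataset.json  per-image camera labels
#   data_*.zip    ZIP_STORED archive of dataset.json plus all target images
# ---------------------------------------------------------------------------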