# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Sample new images from a pre-trained DiT.
"""
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import importlib
import json
import logging
from pathlib import Path

import cv2
import gin
import numpy as np
import yaml
from huggingface_hub import hf_hub_download

logging.basicConfig(
    format="[%(asctime)s.%(msecs)03d] [%(module)s] [%(levelname)s] | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from core.diffusion import create_diffusion
from core.models import DiT_models
from core.utils.train_utils import load_model
from core.utils.math_utils import unnormalize_params
from scripts.prepare_data import generate
from core.utils.dinov2 import Dinov2Model


def _str2bool(value):
    """Parse a CLI boolean value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because
    any non-empty string is truthy. This converter keeps the original
    ``--flag <value>`` CLI shape while interpreting the value correctly.
    """
    return str(value).strip().lower() in ("1", "true", "yes", "y")


def main(cfg, generator):
    """Sample generator parameters conditioned on images and build 3D assets.

    For every image in ``cfg["condition_img_dir"]``: extract DINOv2 features,
    run DiT diffusion sampling to obtain normalized generator parameters,
    un-normalize them via the generator's ``params_dict``, dump them as JSON,
    and call the procedural ``generate`` pipeline to render/save the asset
    into ``cfg["save_dir"]``.

    Args:
        cfg: configuration dict loaded from the YAML file given on the CLI.
        generator: instantiated procedural generator (provides ``params_dict``).
    """
    # Setup PyTorch: inference only, with a fixed seed for reproducibility.
    torch.manual_seed(cfg["seed"])
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model: the DiT input size equals the number of generator params.
    latent_size = cfg["num_params"]
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)

    # Load a custom DiT checkpoint from train.py; download it from the
    # Hugging Face Hub first if it is not already on disk.
    if not os.path.exists(cfg["ckpt_path"]):
        model_dir = os.path.dirname(cfg["ckpt_path"])
        model_name = os.path.basename(cfg["ckpt_path"])
        os.makedirs(model_dir, exist_ok=True)
        # Announce *before* the (potentially slow) download, not after.
        print("Downloading checkpoint {} from Hugging Face Hub...".format(model_name))
        hf_hub_download(repo_id="TencentARC/DI-PCG", local_dir=model_dir, filename=model_name)
    print("Loading model from {}".format(cfg["ckpt_path"]))
    state_dict = load_model(cfg["ckpt_path"])
    model.load_state_dict(state_dict)
    model.eval()  # important!

    diffusion = create_diffusion(str(cfg["num_sampling_steps"]))

    # Feature model used to encode the conditioning images.
    feature_model = Dinov2Model()

    img_names = sorted(os.listdir(cfg["condition_img_dir"]))
    for name in img_names:
        img_path = os.path.join(cfg["condition_img_dir"], name)
        # Load condition image and extract features.
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        # Pre-process: resize to 256x256 for the feature extractor.
        img = cv2.resize(img, (256, 256))
        img = np.array(img).astype(np.uint8)
        img_feat = feature_model.encode_batch_imgs([img], global_feat=False)
        if len(img_feat.shape) == 2:
            img_feat = img_feat.unsqueeze(1)

        # Create sampling noise:
        z = torch.randn(1, 1, latent_size, device=device)
        y = img_feat

        # No classifier-free guidance:
        model_kwargs = dict(y=y)

        # Sample target params:
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False,
            model_kwargs=model_kwargs, progress=True, device=device
        )
        samples = samples[0].squeeze(0).cpu().numpy()

        # Un-normalize params back into the generator's native value ranges.
        params_dict = generator.params_dict
        params_original = unnormalize_params(samples, params_dict)

        # Save params. Use a context manager so the handle is closed
        # (the original json.dump(..., open(...)) leaked the file object).
        with open("{}/{}_params.txt".format(cfg["save_dir"], name), "w") as f:
            json.dump(params_original, f, default=str)

        # Generate 3D asset using the sampled params.
        asset, _ = generate(generator, params_original, seed=cfg["seed"],
                            save_dir=cfg["save_dir"], save_name=name,
                            save_blend=True, save_img=True,
                            save_untexture_img=True, save_gif=False,
                            save_mesh=True,
                            cam_dists=cfg["r_cam_dists"],
                            cam_elevations=cfg["r_cam_elevations"],
                            cam_azimuths=cfg["r_cam_azimuths"],
                            zoff=cfg["r_zoff"],
                            resolution='720x720', sample=200)
        print("Generating model using sampled parameters. \nSaved in {}".format(cfg["save_dir"]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # NOTE: type=bool would treat "--remove_bg False" as True (any non-empty
    # string is truthy); _str2bool parses the value correctly while keeping
    # the same "--remove_bg <value>" invocation backward-compatible.
    parser.add_argument("--remove_bg", type=_str2bool, default=False)
    args = parser.parse_args()

    with open(args.config) as f:
        # safe_load is sufficient for plain config YAML and avoids
        # constructing arbitrary Python objects from the file.
        cfg = yaml.safe_load(f)
    cfg["remove_bg"] = args.remove_bg

    # Load the Blender procedural generator: scan the asset modules under
    # cfg["generator_root"] for a class named cfg["generator"].
    OBJECTS_PATH = Path(cfg["generator_root"])
    assert OBJECTS_PATH.exists(), OBJECTS_PATH
    generator = None
    for subdir in sorted(OBJECTS_PATH.iterdir()):
        clsname = subdir.name.split(".")[0].strip()
        with gin.unlock_config():
            module = importlib.import_module(f"core.assets.{clsname}")
        if hasattr(module, cfg["generator"]):
            generator = getattr(module, cfg["generator"])
            logger.info("Found {} in {}".format(cfg["generator"], subdir))
            break
        logger.debug("{} not found in {}".format(cfg["generator"], subdir))
    if generator is None:
        raise ModuleNotFoundError("{} not Found.".format(cfg["generator"]))
    gen = generator(cfg["seed"])

    # Create visualize dir
    os.makedirs(cfg["save_dir"], exist_ok=True)
    main(cfg, gen)