# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Sample procedural-generator parameters from a pre-trained DiT, conditioned on
input images; optionally regenerate 3D assets from the sampled parameters and
report the average per-parameter reconstruction error.
"""
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import yaml
import json
import numpy as np
from pathlib import Path
import gin
import importlib
import logging
import cv2
import matplotlib.pyplot as plt

logging.basicConfig(
    format="[%(asctime)s.%(msecs)03d] [%(module)s] [%(levelname)s] | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torch.utils.data import DataLoader

from core.diffusion import create_diffusion
from core.models import DiT_models
from core.dataset import ImageParamsDataset
from core.utils.train_utils import load_model
from core.utils.math_utils import unnormalize_params
from scripts.prepare_data import generate


def main(cfg, generator):
    # Setup PyTorch:
    torch.manual_seed(cfg["seed"])
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model:
    latent_size = cfg["num_params"]
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)
    # Load a custom DiT checkpoint from train.py:
    state_dict = load_model(cfg["ckpt_path"])
    model.load_state_dict(state_dict)
    model.eval()  # important!
    diffusion = create_diffusion(str(cfg["num_sampling_steps"]))

    # Load dataset:
    dataset = ImageParamsDataset(cfg["data_root"], cfg["test_file"], cfg["params_dict_file"])
    loader = DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=cfg["num_workers"],
        pin_memory=True,
        drop_last=False,
    )

    with open(cfg["params_dict_file"]) as f:
        params_dict = json.load(f)

    idx = 0
    total_error = np.zeros(cfg["num_params"])
    for x, img_feat, img in loader:
        # Sample from random noise, conditioned on image features.
        img_feat = img_feat.to(device)
        model_kwargs = dict(y=img_feat)
        # Use the actual batch size: with drop_last=False the final batch
        # can be smaller than cfg["batch_size"].
        bs = x.shape[0]
        z = torch.randn(bs, 1, latent_size, device=device)

        # Sample target params:
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False,
            model_kwargs=model_kwargs, progress=True, device=device,
        )
        samples = samples.reshape(bs, 1, -1)
        samples = samples.squeeze(1).cpu().numpy()
        x = x.squeeze(1).cpu().numpy()
        img = img.cpu().numpy()

        if cfg["run_generate"]:
            # Save GT & sampled params & images.
            for x_, params, img_ in zip(x, samples, img):
                # Generate a 3D asset using the sampled params.
                params_original = unnormalize_params(params, params_dict)
                save_dir = os.path.join(cfg["save_dir"], "{:05d}".format(idx))
                os.makedirs(save_dir, exist_ok=True)
                save_name = "sampled"
                asset, _ = generate(
                    generator, params_original, seed=cfg["seed"],
                    save_dir=save_dir, save_name=save_name,
                    save_blend=True, save_img=True, save_gif=False, save_mesh=True,
                    cam_dists=cfg["r_cam_dists"], cam_elevations=cfg["r_cam_elevations"],
                    cam_azimuths=cfg["r_cam_azimuths"], zoff=cfg["r_zoff"],
                    resolution="256x256", sample=100,
                )
                np.save(os.path.join(save_dir, "params.npy"), params_original)
                print("Generated model from sampled parameters. Saved in {}".format(save_dir))

                # Also save GT image & GT params.
                x_original = unnormalize_params(x_, params_dict)
                np.save(os.path.join(save_dir, "gt_params.npy"), x_original)
                cv2.imwrite(os.path.join(save_dir, "gt.png"), img_[:, :, ::-1])  # RGB -> BGR

                idx += 1

        # Accumulate the absolute error between sampled params & GT params.
        error = np.abs(x - samples)
        total_error += error.sum(axis=0)  # reduce over the batch dimension

    # Print the average error for each parameter.
    avg_error = total_error / len(dataset)
    param_names = list(params_dict.keys())
    for param_name, err in zip(param_names, avg_error):
        print(f"{param_name}: {err:.4f}")

    # Plot the error for each parameter.
    fig, ax = plt.subplots()
    fig.set_size_inches(20, 15)
    ax.barh(param_names, avg_error)
    ax.set_xlabel("Average Error")
    ax.set_ylabel("Parameters")
    ax.set_title("Average Error for Each Parameter")
    plt.yticks(fontsize=10)
    fig.tight_layout()
    fig.savefig(os.path.join(cfg["save_dir"], "avg_error.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    # Load the Blender procedural generator by scanning the asset modules.
    OBJECTS_PATH = Path(cfg["generator_root"])
    assert OBJECTS_PATH.exists(), OBJECTS_PATH
    generator = None
    for subdir in sorted(list(OBJECTS_PATH.iterdir())):
        clsname = subdir.name.split(".")[0].strip()
        with gin.unlock_config():
            module = importlib.import_module(f"core.assets.{clsname}")
        if hasattr(module, cfg["generator"]):
            generator = getattr(module, cfg["generator"])
            logger.info("Found {} in {}".format(cfg["generator"], subdir))
            break
        logger.debug("{} not found in {}".format(cfg["generator"], subdir))
    if generator is None:
        raise ModuleNotFoundError("{} not found.".format(cfg["generator"]))
    gen = generator(cfg["seed"])

    # Create the visualization dir.
    os.makedirs(cfg["save_dir"], exist_ok=True)

    main(cfg, gen)