import os import yaml import numpy as np from PIL import Image import rembg import importlib import torch import tempfile import json import spaces from core.models import DiT_models from core.diffusion import create_diffusion from core.utils.dinov2 import Dinov2Model from core.utils.math_utils import unnormalize_params from huggingface_hub import hf_hub_download # Setup PyTorch: torch.set_grad_enabled(False) device = torch.device('cuda') # Define the cache directory for model files #model_cache_dir = './ckpts/' #os.makedirs(model_cache_dir, exist_ok=True) # load generators & models generators_choices = ["chair", "table", "vase", "basket", "flower", "dandelion"] factory_names = ["ChairFactory", "TableDiningFactory", "VaseFactory", "BasketBaseFactory", "FlowerFactory", "DandelionFactory"] generator_path = "./core/assets/" generators, configs, models = [], [], [] for category, factory in zip(generators_choices, factory_names): # load generator module = importlib.import_module(f"core.assets.{category}") gen = getattr(module, factory) generator = gen(0) generators.append(generator) # load configs config_path = f"./configs/demo/{category}_demo.yaml" with open(config_path) as f: cfg = yaml.load(f, Loader=yaml.FullLoader) configs.append(cfg) # load models latent_size = cfg["num_params"] model = DiT_models[cfg["model"]](input_size=latent_size).to(device) # load a custom DiT checkpoint from train.py: # download the checkpoint if not found: if not os.path.exists(cfg["ckpt_path"]): model_dir, model_name = os.path.dirname(cfg["ckpt_path"]), os.path.basename(cfg["ckpt_path"]) os.makedirs(model_dir, exist_ok=True) checkpoint_path = hf_hub_download(repo_id="TencentARC/DI-PCG", local_dir=model_dir, filename=model_name) print("Downloading checkpoint {} from Hugging Face Hub...".format(model_name)) print("Loading model from {}".format(cfg["ckpt_path"])) state_dict = torch.load(cfg["ckpt_path"], map_location=lambda storage, loc: storage) if "ema" in state_dict: # supports checkpoints from train.py state_dict = state_dict["ema"] model.load_state_dict(state_dict) model.eval() models.append(model) diffusion = create_diffusion(str(cfg["num_sampling_steps"])) # feature model feature_model = Dinov2Model() def check_input_image(input_image): if input_image is None: raise gr.Error("No image uploaded!") def preprocess(input_image, do_remove_background): # resize if input_image.size[0] != 256 or input_image.size[1] != 256: input_image = input_image.resize((256, 256)) # remove background if do_remove_background: processed_image = rembg.remove(np.array(input_image)) # white background else: processed_image = input_image return processed_image @spaces.GPU def sample(image, seed, category): # seed np.random.seed(seed) torch.manual_seed(seed) # generator & model idx = generators_choices.index(category) generator, cfg, model = generators[idx], configs[idx], models[idx] # encode condition image feature # convert RGBA images to RGB, white background input_image_np = np.array(image) mask = input_image_np[:, :, -1:] > 0 input_image_np = input_image_np[:, :, :3] * mask + 255 * (1 - mask) image = input_image_np.astype(np.uint8) img_feat = feature_model.encode_batch_imgs([np.array(image)], global_feat=False) # Create sampling noise: latent_size = int(cfg['num_params']) z = torch.randn(1, 1, latent_size, device=device) y = img_feat # No classifier-free guidance: model_kwargs = dict(y=y) # Sample target params: samples = diffusion.p_sample_loop( model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device ) samples = samples[0].squeeze(0).cpu().numpy() # unnormalize params params_dict = generator.params_dict params_original = unnormalize_params(samples, params_dict) mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False).name params_fpath = tempfile.NamedTemporaryFile(suffix=f".npy", delete=False).name np.save(params_fpath, params_original) print(mesh_fpath) print(params_fpath) # generate 3D using sampled params - TODO: this is a hacky way to go through PCG pipeline, avoiding conflict with gradio command = f"python ./scripts/generate.py --config ./configs/demo/{category}_demo.yaml --output_path {mesh_fpath} --seed {seed} --params_path {params_fpath}" os.system(command) return mesh_fpath import gradio as gr _HEADER_ = '''