# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Sample new images from a pre-trained DiT.
"""
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import yaml
import json
import numpy as np
from pathlib import Path
import gin
import importlib
import logging
import cv2
import matplotlib.pyplot as plt


logging.basicConfig(
    format="[%(asctime)s.%(msecs)03d] [%(module)s] [%(levelname)s] | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

import torch
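# allow TF32 on Ampere+ GPUs: faster matmuls/convolutions at slightly reduced precision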
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torch.utils.data import DataLoader

from core.diffusion import create_diffusion
from core.models import DiT_models
from core.dataset import ImageParamsDataset
from core.utils.train_utils import load_model
from core.utils.math_utils import unnormalize_params
from scripts.prepare_data import generate

def main(cfg, generator):
    # Setup PyTorch:
    torch.manual_seed(cfg["seed"])
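    # inference only: disable autograd globally to save memory and compute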
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model:
    latent_size = cfg["num_params"]
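    # the DiT's "latent" here is a flat 1 x num_params parameter vector, not an image latent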
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)
    # load a custom DiT checkpoint from train.py:
    state_dict = load_model(cfg["ckpt_path"])
    model.load_state_dict(state_dict)
    model.eval()  # important! switch off dropout and other train-time behavior for inference
    diffusion = create_diffusion(str(cfg["num_sampling_steps"]))
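    # create_diffusion expects the sampling-step count as a string; fewer steps trade fidelity for speed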

    # Load dataset
    dataset = ImageParamsDataset(cfg["data_root"], cfg["test_file"], cfg["params_dict_file"])
    loader = DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=cfg["num_workers"],
        pin_memory=True,
        drop_last=False
    )
    with open(cfg["params_dict_file"]) as f:
        params_dict = json.load(f)
    idx = 0
    total_error = np.zeros(cfg["num_params"])
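    # each batch: x = normalized GT params, img_feat = conditioning image features, img = GT RGB image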
    for x, img_feat, img in loader:
        # sample from random noise, conditioned on image features
        img_feat = img_feat.to(device)

        model_kwargs = dict(y=img_feat)

        # use the actual batch size: the last batch can be smaller because drop_last=False
        bs = img_feat.shape[0]
        z = torch.randn(bs, 1, latent_size, device=device)

        # Sample target params:
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
        )
        samples = samples.reshape(bs, 1, -1)
        samples = samples.squeeze(1).cpu().numpy()
        x = x.squeeze(1).cpu().numpy()
        img = img.cpu().numpy()
        if cfg["run_generate"]:
            # save GT & sampled params & images
            for x_, params, img_ in zip(x, samples, img):
                # generate 3D using sampled params
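                # map normalized values back to the generator's native ranges recorded in params_dict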
                params_original = unnormalize_params(params, params_dict)
                save_dir = os.path.join(cfg["save_dir"], "{:05d}".format(idx))
                os.makedirs(save_dir, exist_ok=True)
                save_name = "sampled"
                asset, _ = generate(
                    generator, params_original, seed=cfg["seed"], save_dir=save_dir, save_name=save_name,
                    save_blend=True, save_img=True, save_gif=False, save_mesh=True,
                    cam_dists=cfg["r_cam_dists"], cam_elevations=cfg["r_cam_elevations"],
                    cam_azimuths=cfg["r_cam_azimuths"], zoff=cfg["r_zoff"],
                    resolution='256x256', sample=100)
                np.save(os.path.join(save_dir, "params.npy"), params_original)
                print("Generating model using sampled parameters. Saved in {}".format(save_dir))
                # also save GT image & GT params
                x_original = unnormalize_params(x_, params_dict)
                np.save(os.path.join(save_dir, "gt_params.npy"), x_original)
                cv2.imwrite(os.path.join(save_dir, "gt.png"), img_[:,:,::-1])
                idx += 1
        
        # calculate metrics for sampled params & GT params
        # sum over the batch dimension so the accumulator keeps shape (num_params,)
        error = np.abs(x - samples)
        total_error += error.sum(axis=0)
    
    # print the average error for each parameter
    avg_error = total_error / len(dataset)
    param_names = list(params_dict.keys())  # materialize: barh needs a sequence, not a dict view
    for param_name, error in zip(param_names, avg_error):
        print(f"{param_name}: {error:.4f}")
    # plot the error for each parameter
    fig, ax = plt.subplots(figsize=(20, 15))
    ax.barh(param_names, avg_error)
    ax.set_xlabel("Average Error")
    ax.set_ylabel("Parameters")
    ax.set_title("Average Error for Each Parameter")
    ax.tick_params(axis="y", labelsize=10)
    fig.tight_layout()
    fig.savefig(os.path.join(cfg["save_dir"], "avg_error.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config) as f:
        cfg = yaml.safe_load(f)
    
    # load the Blender procedural generator
    OBJECTS_PATH = Path(cfg["generator_root"])
    assert OBJECTS_PATH.exists(), OBJECTS_PATH
    generator = None
    for subdir in sorted(OBJECTS_PATH.iterdir()):
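        # derive the asset module name from the entry name, dropping any extension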
        clsname = subdir.name.split(".")[0].strip()
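        # asset modules may touch gin bindings on import, so temporarily unlock the gin config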
        with gin.unlock_config():
            module = importlib.import_module(f"core.assets.{clsname}")
        if hasattr(module, cfg["generator"]):
            generator = getattr(module, cfg["generator"])
            logger.info("Found {} in {}".format(cfg["generator"], subdir))
            break
        logger.debug("{} not found in {}".format(cfg["generator"], subdir))
    if generator is None:
        raise ModuleNotFoundError("{} not found under {}.".format(cfg["generator"], OBJECTS_PATH))
    gen = generator(cfg["seed"])
    # create the output directory for sampled assets and metrics
    os.makedirs(cfg["save_dir"], exist_ok=True)
    main(cfg, gen)