import os
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from jaxtyping import Float
from omegaconf import OmegaConf
from PIL import Image
from safetensors.torch import load_model
from torch import Tensor

from spar3d.models.diffusion.gaussian_diffusion import (
    SpacedDiffusion,
    get_named_beta_schedule,
    space_timesteps,
)
from spar3d.models.diffusion.sampler import PointCloudSampler
from spar3d.models.isosurface import MarchingTetrahedraHelper
from spar3d.models.mesh import Mesh
from spar3d.models.utils import (
    BaseModule,
    ImageProcessor,
    convert_data,
    dilate_fill,
    find_class,
    float32_to_uint8_np,
    normalize,
    scale_tensor,
)
from spar3d.utils import (
    create_intrinsic_from_fov_rad,
    default_cond_c2w,
    get_device,
    normalize_pc_bbox,
)

try:
    from texture_baker import TextureBaker
except ImportError as e:
    # Fail fast with an actionable message instead of erroring later at bake time.
    raise ImportError(
        "Could not import texture_baker. Please install it via `pip install texture-baker/`"
    ) from e


class SPAR3D(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        cond_image_size: int
        isosurface_resolution: int
        isosurface_threshold: float = 10.0
        radius: float = 1.0
        background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
        default_fovy_rad: float = 0.591627
        default_distance: float = 2.2

        camera_embedder_cls: str = ""
        camera_embedder: dict = field(default_factory=dict)

        image_tokenizer_cls: str = ""
        image_tokenizer: dict = field(default_factory=dict)

        point_embedder_cls: str = ""
        point_embedder: dict = field(default_factory=dict)

        tokenizer_cls: str = ""
        tokenizer: dict = field(default_factory=dict)

        backbone_cls: str = ""
        backbone: dict = field(default_factory=dict)

        post_processor_cls: str = ""
        post_processor: dict = field(default_factory=dict)

        decoder_cls: str = ""
        decoder: dict = field(default_factory=dict)

        image_estimator_cls: str = ""
        image_estimator: dict = field(default_factory=dict)

        global_estimator_cls: str = ""
        global_estimator: dict = field(default_factory=dict)

        # Point diffusion modules
        pdiff_camera_embedder_cls: str = ""
        pdiff_camera_embedder: dict = field(default_factory=dict)

        pdiff_image_tokenizer_cls: str = ""
        pdiff_image_tokenizer: dict = field(default_factory=dict)

        pdiff_backbone_cls: str = ""
        pdiff_backbone: dict = field(default_factory=dict)

        scale_factor_xyz: float = 1.0
        scale_factor_rgb: float = 1.0
        bias_xyz: float = 0.0
        bias_rgb: float = 0.0
        train_time_steps: int = 1024
        inference_time_steps: int = 64

        mean_type: str = "epsilon"
        var_type: str = "fixed_small"
        diffu_sched: str = "cosine"
        diffu_sched_exp: float = 12.0
        guidance_scale: float = 3.0
        sigma_max: float = 120.0
        s_churn: float = 3.0

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
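        """Load a SPAR3D model from a local directory or the Hugging Face Hub.

        If `pretrained_model_name_or_path` is a directory relative to the
        package root, `config_name` and `weight_name` are read from it;
        otherwise both files are downloaded from the Hub. Weights are loaded
        with `strict=False`, so keys absent from the checkpoint are tolerated.
        """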
        base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        if os.path.isdir(os.path.join(base_dir, pretrained_model_name_or_path)):
            config_path = os.path.join(
                base_dir, pretrained_model_name_or_path, config_name
            )
            weight_path = os.path.join(
                base_dir, pretrained_model_name_or_path, weight_name
            )
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        load_model(model, weight_path, strict=False)
        return model

    @property
    def device(self):
        return next(self.parameters()).device

    def configure(self):
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.point_embedder = find_class(self.cfg.point_embedder_cls)(
            self.cfg.point_embedder
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
            self.cfg.camera_embedder
        )
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.image_estimator = find_class(self.cfg.image_estimator_cls)(
            self.cfg.image_estimator
        )
        self.global_estimator = find_class(self.cfg.global_estimator_cls)(
            self.cfg.global_estimator
        )

        # point diffusion modules
        self.pdiff_image_tokenizer = find_class(self.cfg.pdiff_image_tokenizer_cls)(
            self.cfg.pdiff_image_tokenizer
        )
        self.pdiff_camera_embedder = find_class(self.cfg.pdiff_camera_embedder_cls)(
            self.cfg.pdiff_camera_embedder
        )
        self.pdiff_backbone = find_class(self.cfg.pdiff_backbone_cls)(
            self.cfg.pdiff_backbone
        )

        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )
        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            os.path.join(
                os.path.dirname(__file__),
                "..",
                "load",
                "tets",
                f"{self.cfg.isosurface_resolution}_tets.npz",
            ),
        )

        self.baker = TextureBaker()
        self.image_processor = ImageProcessor()

        channel_scales = [self.cfg.scale_factor_xyz] * 3
        channel_scales += [self.cfg.scale_factor_rgb] * 3
        channel_biases = [self.cfg.bias_xyz] * 3
        channel_biases += [self.cfg.bias_rgb] * 3
        channel_scales = np.array(channel_scales)
        channel_biases = np.array(channel_biases)

        betas = get_named_beta_schedule(
            self.cfg.diffu_sched, self.cfg.train_time_steps, self.cfg.diffu_sched_exp
        )

        diffusion_kwargs = dict(
            betas=betas,
            model_mean_type=self.cfg.mean_type,
            model_var_type=self.cfg.var_type,
            channel_scales=channel_scales,
            channel_biases=channel_biases,
        )
        self.diffusion_spaced = SpacedDiffusion(
            use_timesteps=space_timesteps(
                self.cfg.train_time_steps,
                "ddim" + str(self.cfg.inference_time_steps),
            ),
            **diffusion_kwargs,
        )
        self.sampler = PointCloudSampler(
            model=self.pdiff_backbone,
            diffusion=self.diffusion_spaced,
            num_points=512,
            point_dim=6,
            guidance_scale=self.cfg.guidance_scale,
            clip_denoised=True,
            sigma_min=1e-3,
            sigma_max=self.cfg.sigma_max,
            s_churn=self.cfg.s_churn,
        )

    def triplane_to_meshes(
        self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
    ) -> list[Mesh]:
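        """Extract one surface mesh per triplane via marching tetrahedra.

        Grid vertices are queried against each triplane; the decoder predicts a
        density (thresholded into an SDF) and per-vertex offsets, which the
        isosurface helper turns into a mesh rescaled back into the bbox.
        """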
        meshes = []
        for i in range(triplanes.shape[0]):
            triplane = triplanes[i]
            grid_vertices = scale_tensor(
                self.isosurface_helper.grid_vertices.to(triplanes.device),
                self.isosurface_helper.points_range,
                self.bbox,
            )

            values = self.query_triplane(grid_vertices, triplane)
            decoded = self.decoder(values, include=["vertex_offset", "density"])
            sdf = decoded["density"] - self.cfg.isosurface_threshold

            deform = decoded["vertex_offset"].squeeze(0)

            mesh: Mesh = self.isosurface_helper(
                sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
            )
            mesh.v_pos = scale_tensor(
                mesh.v_pos, self.isosurface_helper.points_range, self.bbox
            )

            meshes.append(mesh)

        return meshes

    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"],
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
    ) -> Float[Tensor, "*B N F"]:
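        """Sample triplane features at 3D positions.

        Positions are rescaled from [-radius, radius] to [-1, 1], projected
        onto the three axis-aligned planes (xy, xz, yz), bilinearly sampled
        from the corresponding feature planes, and the three results are
        concatenated along the feature dimension.
        """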
        batched = positions.ndim == 3
        if not batched:
            # no batch dimension
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]
        assert triplanes.ndim == 5 and positions.ndim == 3

        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
            (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
            dim=-3,
        ).to(triplanes.dtype)
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
            align_corners=True,
            mode="bilinear",
        )
        out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)

        return out

    def get_scene_codes(self, batch) -> Tuple[Float[Tensor, "B 3 Cp Hp Wp"], Tensor]:
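        """Encode the conditioning batch into triplane scene codes.

        Returns both the post-processed scene codes (used for geometry and
        texture queries) and the raw detokenized codes (used by the global
        estimator).
        """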
        # If batch["rgb_cond"] holds a single view, add an explicit view dimension.
        if len(batch["rgb_cond"].shape) == 4:
            batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
            batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
            batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
            batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
            batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)

        batch_size, n_input_views = batch["rgb_cond"].shape[:2]

        camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
        camera_embeds = self.camera_embedder(**batch)

        pc_embeds = self.point_embedder(batch["pc_cond"])

        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
            modulation_cond=camera_embeds,
        )

        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
        )

        tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

        cross_tokens = input_image_tokens
        cross_tokens = torch.cat([cross_tokens, pc_embeds], dim=1)

        tokens = self.backbone(
            tokens,
            encoder_hidden_states=cross_tokens,
            modulation_cond=None,
        )

        direct_codes = self.tokenizer.detokenize(tokens)
        scene_codes = self.post_processor(direct_codes)
        return scene_codes, direct_codes

    def forward_pdiff_cond(self, batch: Dict[str, Any]) -> Float[Tensor, "B Nt Cit"]:
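        """Produce image/camera conditioning tokens for the point diffusion model.

        Mirrors the single-view handling in `get_scene_codes`, but runs the
        dedicated `pdiff_*` image tokenizer and camera embedder.
        """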
        if len(batch["rgb_cond"].shape) == 4:
            batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
            batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
            batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
            batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
            batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)

        _batch_size, n_input_views = batch["rgb_cond"].shape[:2]

        # Camera modulation
        camera_embeds: Float[Tensor, "B Nv Cc"] = self.pdiff_camera_embedder(**batch)

        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.pdiff_image_tokenizer(
            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
            modulation_cond=camera_embeds,
        )

        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
        )

        return input_image_tokens

    def run_image(
        self,
        image: Union[Image.Image, List[Image.Image]],
        bake_resolution: int,
        pointcloud: Optional[Union[List[float], np.ndarray, Tensor]] = None,
        remesh: Literal["none", "triangle", "quad"] = "none",
        vertex_count: int = -1,
        estimate_illumination: bool = False,
        return_points: bool = False,
    ) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
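        """Reconstruct textured meshes from one or more RGBA images.

        Builds a default camera batch (fixed FoV and distance), optionally uses
        a user-provided point cloud instead of sampling one, and delegates to
        `generate_mesh`. Returns a single mesh for a single input image, or a
        list of meshes for a list of inputs, together with a dict of global
        estimates (plus the conditioning point clouds when `return_points` is
        set).
        """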
        if isinstance(image, list):
            rgb_cond = []
            mask_cond = []
            for img in image:
                mask, rgb = self.prepare_image(img)
                mask_cond.append(mask)
                rgb_cond.append(rgb)
            rgb_cond = torch.stack(rgb_cond, 0)
            mask_cond = torch.stack(mask_cond, 0)
            batch_size = rgb_cond.shape[0]
        else:
            mask_cond, rgb_cond = self.prepare_image(image)
            batch_size = 1

        c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
        intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_rad(
            self.cfg.default_fovy_rad,
            self.cfg.cond_image_size,
            self.cfg.cond_image_size,
        )

        batch = {
            "rgb_cond": rgb_cond,
            "mask_cond": mask_cond,
            "c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1),
            "intrinsic_cond": intrinsic.to(self.device)
            .view(1, 1, 3, 3)
            .repeat(batch_size, 1, 1, 1),
            "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device)
            .view(1, 1, 3, 3)
            .repeat(batch_size, 1, 1, 1),
        }

        meshes, global_dict = self.generate_mesh(
            batch,
            bake_resolution,
            pointcloud,
            remesh,
            vertex_count,
            estimate_illumination,
        )

        if return_points:
            point_clouds = []
            for i in range(batch_size):
                xyz = batch["pc_cond"][i, :, :3].cpu().numpy()
                color_rgb = (
                    (batch["pc_cond"][i, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
                )
                pc_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)
                point_clouds.append(pc_trimesh)
            global_dict["point_clouds"] = point_clouds

        if batch_size == 1:
            return meshes[0], global_dict
        else:
            return meshes, global_dict

    def prepare_image(self, image):
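        """Convert an RGBA PIL image into (mask, rgb) conditioning tensors.

        The image is resized to `cond_image_size` and scaled to [0, 1]; the RGB
        channels are composited over the configured background color using the
        alpha channel as the blend weight. Raises ValueError for non-RGBA input.
        """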
        if image.mode != "RGBA":
            raise ValueError("Image must be in RGBA mode")
        img_cond = (
            torch.from_numpy(
                np.asarray(
                    image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
                ).astype(np.float32)
                / 255.0
            )
            .float()
            .clip(0, 1)
            .to(self.device)
        )
        mask_cond = img_cond[:, :, -1:]
        rgb_cond = torch.lerp(
            torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
            img_cond[:, :, :3],
            mask_cond,
        )

        return mask_cond, rgb_cond

    def generate_mesh(
        self,
        batch,
        bake_resolution: int,
        pointcloud: Optional[Union[List[float], np.ndarray, Tensor]] = None,
        remesh: Literal["none", "triangle", "quad"] = "none",
        vertex_count: int = -1,
        estimate_illumination: bool = False,
    ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
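        """Run the full reconstruction: point cloud -> triplanes -> baked meshes.

        If no point cloud is supplied, one is sampled with the point diffusion
        model. Scene codes are then decoded into meshes, optionally remeshed to
        the requested vertex count, and PBR textures (albedo, roughness,
        metallic, tangent-space normals) are baked into one `trimesh.Trimesh`
        per batch element.
        """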
        batch["rgb_cond"] = self.image_processor(
            batch["rgb_cond"], self.cfg.cond_image_size
        )
        batch["mask_cond"] = self.image_processor(
            batch["mask_cond"], self.cfg.cond_image_size
        )

        batch_size = batch["rgb_cond"].shape[0]

        if pointcloud is not None:
            if isinstance(pointcloud, list):
                # Flat list of floats: reshape into (N, 6) rows of [x, y, z, r, g, b].
                cond_tensor = (
                    torch.tensor(pointcloud).float().to(self.device).view(-1, 6)
                )
                xyz = cond_tensor[:, :3]
                color_rgb = cond_tensor[:, 3:]
            elif isinstance(pointcloud, np.ndarray):
                xyz = torch.tensor(pointcloud[:, :3]).float().to(self.device)
                color_rgb = torch.tensor(pointcloud[:, 3:]).float().to(self.device)
            elif isinstance(pointcloud, Tensor):
                xyz = pointcloud[:, :3].float().to(self.device)
                color_rgb = pointcloud[:, 3:].float().to(self.device)
            else:
                raise ValueError("Invalid point cloud type")

            # Add a batch dimension: (1, N, 6)
            batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(0)

        if "pc_cond" not in batch:
            cond_tokens = self.forward_pdiff_cond(batch)
            sample_iter = self.sampler.sample_batch_progressive(
                batch_size, cond_tokens, device=self.device
            )
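            # Run the progressive sampler to completion; the final "xstart"
            # is the fully denoised point cloud.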
            for x in sample_iter:
                samples = x["xstart"]

            denoised_pc = samples.permute(0, 2, 1).float()  # [B, C, N] -> [B, N, C]
            denoised_pc = normalize_pc_bbox(denoised_pc)

            # predict the full 3D conditioned on the denoised point cloud
            batch["pc_cond"] = denoised_pc

        scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)

        global_dict = {}
        if self.image_estimator is not None:
            global_dict.update(
                self.image_estimator(
                    torch.cat([batch["rgb_cond"], batch["mask_cond"]], dim=-1)
                )
            )
        if self.global_estimator is not None and estimate_illumination:
            global_dict.update(self.global_estimator(non_postprocessed_codes))

        global_dict["pointcloud"] = batch["pc_cond"]

        device = get_device()
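        # On CUDA, explicitly disable autocast so mesh extraction and texture
        # baking run in full precision; other devices use a plain nullcontext.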
        with torch.no_grad():
            with (
                torch.autocast(device_type=device, enabled=False)
                if "cuda" in device
                else nullcontext()
            ):
                meshes = self.triplane_to_meshes(scene_codes)

                rets = []
                for i, mesh in enumerate(meshes):
                    # Check for empty mesh
                    if mesh.v_pos.shape[0] == 0:
                        rets.append(trimesh.Trimesh())
                        continue

                    if remesh == "triangle":
                        mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
                    elif remesh == "quad":
                        mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
                    else:
                        if vertex_count > 0:
                            print(
                                "Warning: vertex_count is ignored when remesh is none"
                            )

                    if remesh != "none":
                        print(
                            f"After {remesh} remesh the mesh has {mesh.v_pos.shape[0]} verts and {mesh.t_pos_idx.shape[0]} faces",
                        )
                        mesh.unwrap_uv()

                    # Build textures
                    rast = self.baker.rasterize(
                        mesh.v_tex, mesh.t_pos_idx, bake_resolution
                    )
                    bake_mask = self.baker.get_mask(rast)

                    pos_bake = self.baker.interpolate(
                        mesh.v_pos,
                        rast,
                        mesh.t_pos_idx,
                    )
                    gb_pos = pos_bake[bake_mask]

                    tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
                    decoded = self.decoder(
                        tri_query, exclude=["density", "vertex_offset"]
                    )

                    nrm = self.baker.interpolate(
                        mesh.v_nrm,
                        rast,
                        mesh.t_pos_idx,
                    )
                    gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
                    decoded["normal"] = gb_nrm

                    # Forward any global estimates whose keys start with
                    # "decoder_" into the decoder outputs for this batch element.
                    for k, v in global_dict.items():
                        if k.startswith("decoder_"):
                            decoded[k.replace("decoder_", "")] = v[i]

                    mat_out = {
                        "albedo": decoded["features"],
                        "roughness": decoded["roughness"],
                        "metallic": decoded["metallic"],
                        "normal": normalize(decoded["perturb_normal"]),
                        "bump": None,
                    }

                    for k, v in mat_out.items():
                        if v is None:
                            continue
                        if v.shape[0] == 1:
                            # Per-mesh scalar: store it directly instead of baking a texture.
                            mat_out[k] = v[0]
                        else:
                            f = torch.zeros(
                                bake_resolution,
                                bake_resolution,
                                v.shape[-1],
                                dtype=v.dtype,
                                device=v.device,
                            )
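                            # Already a full-resolution texture; keep it unchanged.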
                            if v.shape == f.shape:
                                continue
                            if k == "normal":
                                # Use un-normalized tangents here so that larger
                                # or smaller triangles don't affect the tangents
                                # too much.
                                tng = self.baker.interpolate(
                                    mesh.v_tng,
                                    rast,
                                    mesh.t_pos_idx,
                                )
                                gb_tng = tng[bake_mask]
                                gb_tng = F.normalize(gb_tng, dim=-1)
                                gb_btng = F.normalize(
                                    torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1
                                )
                                normal = F.normalize(mat_out["normal"], dim=-1)

                                # Create tangent space matrix and transform normal
                                tangent_matrix = torch.stack(
                                    [gb_tng, gb_btng, gb_nrm], dim=-1
                                )
                                normal_tangent = torch.bmm(
                                    tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
                                ).squeeze(-1)

                                # Convert from [-1,1] to [0,1] range for storage
                                normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(
                                    0, 1
                                )

                                f[bake_mask] = normal_tangent.view(-1, 3)
                                mat_out["bump"] = f
                            else:
                                f[bake_mask] = v.view(-1, v.shape[-1])
                                mat_out[k] = f

                    def uv_padding(arr):
                        if arr.ndim == 1:
                            return arr
                        return (
                            dilate_fill(
                                arr.permute(2, 0, 1)[None, ...].contiguous(),
                                bake_mask.unsqueeze(0).unsqueeze(0),
                                iterations=bake_resolution // 150,
                            )
                            .squeeze(0)
                            .permute(1, 2, 0)
                            .contiguous()
                        )

                    verts_np = convert_data(mesh.v_pos)
                    faces = convert_data(mesh.t_pos_idx)
                    uvs = convert_data(mesh.v_tex)

                    basecolor_tex = Image.fromarray(
                        float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
                    ).convert("RGB")
                    basecolor_tex.format = "JPEG"

                    metallic = mat_out["metallic"].squeeze().cpu().item()
                    roughness = mat_out["roughness"].squeeze().cpu().item()

                    if "bump" in mat_out and mat_out["bump"] is not None:
                        bump_np = convert_data(uv_padding(mat_out["bump"]))
                        bump_up = np.ones_like(bump_np)
                        bump_up[..., :2] = 0.5
                        bump_up[..., 2:] = 1
                        bump_tex = Image.fromarray(
                            float32_to_uint8_np(
                                bump_np,
                                dither=True,
                                # Do not dither if something is perfectly flat
                                dither_mask=np.all(
                                    bump_np == bump_up, axis=-1, keepdims=True
                                ).astype(np.float32),
                            )
                        ).convert("RGB")
                        # PNG would be lossless, but JPEG keeps the assets smaller.
                        bump_tex.format = "JPEG"
                    else:
                        bump_tex = None

                    material = trimesh.visual.material.PBRMaterial(
                        baseColorTexture=basecolor_tex,
                        roughnessFactor=roughness,
                        metallicFactor=metallic,
                        normalTexture=bump_tex,
                    )

                    tmesh = trimesh.Trimesh(
                        vertices=verts_np,
                        faces=faces,
                        visual=trimesh.visual.texture.TextureVisuals(
                            uv=uvs, material=material
                        ),
                    )
                    rot = trimesh.transformations.rotation_matrix(
                        np.radians(-90), [1, 0, 0]
                    )
                    tmesh.apply_transform(rot)
                    tmesh.apply_transform(
                        trimesh.transformations.rotation_matrix(
                            np.radians(90), [0, 1, 0]
                        )
                    )

                    tmesh.invert()

                    rets.append(tmesh)

        return rets, global_dict
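

# Minimal usage sketch. The repo id, config, and weight names below are
# assumptions; adjust them to your checkpoint layout:
#
#   model = SPAR3D.from_pretrained(
#       "stabilityai/stable-point-aware-3d",
#       config_name="config.yaml",
#       weight_name="model.safetensors",
#   )
#   model.to(get_device())
#   model.eval()
#
#   image = Image.open("input.png").convert("RGBA")
#   mesh, global_dict = model.run_image(image, bake_resolution=1024)
#   mesh.export("output.glb")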