# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

from difflib import unified_diff
import os
import numpy as np
import torch

from . import obj
from . import util

#########################################################################################
# Base mesh class
#
# Minibatches are supported as long as every mesh in the batch shares the same
# connectivity (i.e. the same triangle index tensors).
#########################################################################################
class Mesh:
    """Triangle mesh container with optional minibatch support.

    Vertex attributes (``v_pos``, ``v_nrm``, ``v_tex``, ``v_tng``) and their
    per-triangle index tensors (``t_*_idx``) are stored exactly as given.
    The batch helpers (``extend``, ``deform``, ``get_m_to_n``, ``first_n``,
    ``get_n``) slice or repeat ``v_pos``/``v_tex`` along dim 0 and reuse the
    index tensors unchanged, so every mesh in a batch must share the same
    connectivity. Those helpers rebuild their result through ``make_mesh``,
    which recomputes normals and tangents.
    """

    # Names of all tensor-valued attributes (everything except ``material``).
    _TENSOR_ATTRS = (
        "v_pos", "t_pos_idx",
        "v_nrm", "t_nrm_idx",
        "v_tex", "t_tex_idx",
        "v_tng", "t_tng_idx",
    )

    def __init__(self,
                 v_pos=None,
                 t_pos_idx=None,
                 v_nrm=None,
                 t_nrm_idx=None,
                 v_tex=None,
                 t_tex_idx=None,
                 v_tng=None,
                 t_tng_idx=None,
                 material=None,
                 base=None):
        """Store the given attributes; any left as None are taken from ``base``.

        Attributes inherited from ``base`` are shared by reference, not cloned.
        """
        self.v_pos = v_pos
        self.v_nrm = v_nrm
        self.v_tex = v_tex
        self.v_tng = v_tng
        self.t_pos_idx = t_pos_idx
        self.t_nrm_idx = t_nrm_idx
        self.t_tex_idx = t_tex_idx
        self.t_tng_idx = t_tng_idx
        self.material = material

        if base is not None:
            self.copy_none(base)

    def __len__(self):
        """Return ``len(self.v_pos)``, i.e. the batch size for batched positions."""
        return len(self.v_pos)

    def copy_none(self, other):
        """Fill every attribute that is None on ``self`` with ``other``'s value.

        Values are copied by reference; nothing is cloned.
        """
        for name in self._TENSOR_ATTRS + ("material",):
            if getattr(self, name) is None:
                setattr(self, name, getattr(other, name))

    def clone(self):
        """Return a new Mesh whose tensors are detached copies of this mesh's.

        The material is shared by reference, not copied.
        """
        out = Mesh(base=self)
        for name in self._TENSOR_ATTRS:
            value = getattr(out, name)
            if value is not None:
                setattr(out, name, value.clone().detach())
        return out

    def detach(self):
        """Alias for ``clone()``, which already detaches every tensor."""
        return self.clone()

    def extend(self, N: int):
        """
        Create new Mesh class which contains each input mesh N times.

        The whole batch is tiled N times along dim 0 (``repeat(N, 1, 1)``),
        so the result is ``[m_0, ..., m_{B-1}, m_0, ..., m_{B-1}, ...]``.

        Args:
            N: number of new copies of each mesh.

        Returns:
            new Mesh object (built via ``make_mesh``).
        """
        verts = self.v_pos.repeat(N, 1, 1)
        uvs = self.v_tex.repeat(N, 1, 1)
        return make_mesh(verts, self.t_pos_idx, uvs, self.t_tex_idx, self.material)

    def deform(self, deformation):
        """
        Create new Mesh class which is obtained by performing the deformation to the self.

        Args:
            deformation: tensor with shape (B, V, 3); added to ``v_pos`` (with
                broadcasting), so ``v_pos`` must have batch size 1 or B.

        Returns:
            new Mesh object after the deformation. UVs are not deformed; they
            are repeated to the output batch size when their batch size differs.
        """
        assert deformation.shape[1] == self.v_pos.shape[1] and deformation.shape[2] == 3
        verts = self.v_pos + deformation
        uvs = self.v_tex
        # Only tile the UVs when needed: repeating an already B-batched v_tex
        # by B would yield B*B UV sets and fail make_mesh's batch-size check.
        if uvs.shape[0] != verts.shape[0]:
            uvs = uvs.repeat(verts.shape[0], 1, 1)
        return make_mesh(verts, self.t_pos_idx, uvs, self.t_tex_idx, self.material)

    def get_m_to_n(self, m: int, n: int):
        """
        Create new Mesh class with the m-th (included) mesh to the n-th (not included) mesh in the batch.

        Args:
            m: the index of the starting mesh to be contained.
            n: the index of the first mesh not to be contained.

        Returns:
            new Mesh object holding ``v_pos[m:n]`` / ``v_tex[m:n]``.
        """
        verts = self.v_pos[m:n, ...]
        uvs = self.v_tex[m:n, ...]
        return make_mesh(verts, self.t_pos_idx, uvs, self.t_tex_idx, self.material)

    def first_n(self, n: int):
        """
        Create new Mesh class with only the first n meshes in the batch.

        Args:
            n: number of meshes to be contained.

        Returns:
            new Mesh object with the first n meshes.
        """
        return self.get_m_to_n(0, n)

    def get_n(self, n: int):
        """
        Create new Mesh class with only the n-th mesh in the batch.

        Args:
            n: the index of the mesh to be contained. Negative values count
                from the end of the batch, as in normal Python indexing.

        Returns:
            new Mesh object (batch size 1) with the n-th mesh.
        """
        # Without this, n = -1 would slice [-1:0] and produce an empty batch.
        if n < 0:
            n += len(self)
        return self.get_m_to_n(n, n + 1)


######################################################################################
# Mesh loading helper
######################################################################################
def load_mesh(filename, mtl_override=None):
    """Load a mesh from disk, dispatching on the file extension.

    Only Wavefront OBJ files are supported. They are loaded with
    ``obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override)``.
    The extension comparison is case-insensitive, so ``.OBJ`` is accepted too.

    Args:
        filename: path to the mesh file.
        mtl_override: forwarded unchanged to ``obj.load_obj``.

    Returns:
        Whatever ``obj.load_obj`` returns.

    Raises:
        ValueError: if the file extension is not ``.obj``.
    """
    ext = os.path.splitext(filename)[1].lower()
    if ext == ".obj":
        return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override)
    # Explicit raise (not ``assert``): asserts are stripped under ``python -O``,
    # which would make this function silently return None.
    raise ValueError(f"Invalid mesh file extension: {filename!r}")

######################################################################################
# Compute AABB
######################################################################################
def aabb(mesh):
    """Return the axis-aligned bounding box ``(min_corner, max_corner)`` of ``mesh.v_pos``.

    All leading dimensions of ``v_pos`` are flattened before reducing. For
    batched positions of shape ``(B, V, 3)`` the result is the box enclosing
    every vertex of every mesh in the batch. For ``(V, 3)`` positions the
    result is unchanged. Each corner has shape ``(C,)``, where C is the last
    dimension of ``v_pos`` (3 for positions).
    """
    # Reducing only dim 0 of a (B, V, 3) tensor would give per-vertex extrema
    # across the batch, not a bounding box (and zero extent when B == 1).
    pos = mesh.v_pos.reshape(-1, mesh.v_pos.shape[-1])
    return torch.min(pos, dim=0).values, torch.max(pos, dim=0).values

######################################################################################
# Compute unique edge list from attribute/vertex index list
######################################################################################
def compute_edges(attr_idx, return_inverse=False):
    """Return the unique undirected edges of a shared triangle list.

    Args:
        attr_idx: batched index tensor; only ``attr_idx[0]`` (one row of
            vertex indices per triangle) is used, since connectivity is shared.
        return_inverse: passed to ``torch.unique``. When True, the function
            also returns, for each of the 3F per-triangle edges (ordered
            (v0, v1), (v1, v2), (v2, v0) per triangle), its row in the
            unique-edge tensor.

    Returns:
        ``(E, 2)`` tensor of unique edges, each stored as (smaller index,
        larger index) and sorted lexicographically by ``torch.unique``. If
        ``return_inverse`` is True, a ``(edges, inverse)`` tuple is returned.
    """
    with torch.no_grad():
        tri = attr_idx[0]

        # Edge k of a triangle runs from corner k to corner (k + 1) % 3; the
        # edges are laid out triangle by triangle.
        starts = tri[:, [0, 1, 2]]
        ends = tri[:, [1, 2, 0]]
        directed = torch.stack((starts, ends), dim=-1).reshape(-1, 2)

        # Canonical undirected form: smaller vertex index first.
        undirected = torch.sort(directed, dim=1).values

        return torch.unique(undirected, dim=0, return_inverse=return_inverse)

######################################################################################
# Compute unique edge to face mapping from attribute/vertex index list
######################################################################################
def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
    """Map every unique undirected edge to the (up to) two triangles that share it.

    Args:
        attr_idx: batched index tensor; only ``attr_idx[0]`` (one row of vertex
            indices per triangle) is used, since connectivity is shared.
        return_inverse: unused; kept for signature compatibility.

    Returns:
        int64 tensor of shape ``(E, 2)`` on the same device as ``attr_idx``.
        Row ``e`` corresponds to the ``e``-th edge returned by
        ``torch.unique(..., dim=0)`` on the (min, max)-sorted edges. Column 0
        holds the triangle in which the edge is traversed from its lower to
        its higher vertex index. Column 1 holds the triangle in which it is
        traversed from higher to lower. Slots with no such triangle keep the
        initial value 0, which is indistinguishable from triangle 0.
    """
    with torch.no_grad():
        tri = attr_idx[0]
        device = tri.device

        # All 3F directed edges, packed per triangle: (v0,v1), (v1,v2), (v2,v0).
        all_edges = torch.cat((
            torch.stack((tri[:, 0], tri[:, 1]), dim=-1),
            torch.stack((tri[:, 1], tri[:, 2]), dim=-1),
            torch.stack((tri[:, 2], tri[:, 0]), dim=-1),
        ), dim=-1).view(-1, 2)

        # order == 1 where the edge runs high -> low; swap so min index is first.
        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
        sorted_edges = torch.cat((
            torch.gather(all_edges, 1, order),
            torch.gather(all_edges, 1, 1 - order),
        ), dim=-1)

        # Eliminate duplicates; idx_map[k] is the unique-edge row of directed edge k.
        unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)

        # Triangle id of each directed edge (three consecutive edges per triangle).
        # Allocate on the input's device rather than calling .cuda(), so CPU
        # inputs work and GPU inputs stay on the device they already live on.
        tris = torch.arange(tri.shape[0], device=device).repeat_interleave(3)
        tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64, device=device)

        # Compute edge to face table. NOTE(review): if several triangles
        # traverse the same edge in the same direction (non-manifold or
        # inconsistently wound meshes), which write wins is unspecified.
        forward = order[:, 0] == 0
        backward = order[:, 0] == 1
        tris_per_edge[idx_map[forward], 0] = tris[forward]
        tris_per_edge[idx_map[backward], 1] = tris[backward]

        return tris_per_edge

######################################################################################
# Center mesh on the origin & rescale so its largest bounding-box extent is 2.
######################################################################################
def unit_size(mesh):
    """Return a copy of ``mesh`` centered on the origin and rescaled.

    The center is the midpoint of ``aabb(mesh)``. The scale factor is
    ``2 / (largest bounding-box extent)``, so the longest side of the box
    becomes 2. The computation runs under ``torch.no_grad()``, so the new
    positions are not tracked by autograd. All other attributes are shared
    with ``mesh``.

    Raises:
        ZeroDivisionError: if the bounding box has zero extent on every axis
            (the extent is converted to a Python float before dividing).
    """
    with torch.no_grad():
        lo, hi = aabb(mesh)
        factor = 2 / torch.max(hi - lo).item()
        center = (hi + lo) / 2
        return Mesh((mesh.v_pos - center) * factor, base=mesh)

######################################################################################
# Center & scale mesh for rendering
######################################################################################
def center_by_reference(base_mesh, ref_aabb, scale):
    """Center and scale ``base_mesh`` using a reference bounding box.

    Args:
        base_mesh: mesh whose ``v_pos`` is transformed; all other attributes
            are shared with the returned mesh.
        ref_aabb: ``(min_corner, max_corner)`` pair, e.g. as returned by ``aabb``.
        scale: the largest extent of ``ref_aabb`` is mapped to this length.

    Returns:
        New Mesh with ``v_pos = (v_pos - center) * scale / max_extent``, where
        ``center`` is the midpoint of ``ref_aabb``.
    """
    lo, hi = ref_aabb[0], ref_aabb[1]
    factor = scale / torch.max(hi - lo).item()
    offset = (lo + hi) * 0.5
    # Leading singleton axis so the offset broadcasts over the vertex axis.
    return Mesh((base_mesh.v_pos - offset[None, ...]) * factor, base=base_mesh)

######################################################################################
# Simple smooth vertex normal computation
######################################################################################
def auto_normals(imesh):
    """Compute smooth per-vertex normals for a batched mesh.

    Each face contributes its unnormalized normal ``(v1 - v0) x (v2 - v0)``,
    whose length is proportional to the face area, to each of its three
    vertices. The accumulated sums are then normalized. Vertices whose summed
    normal is (near) zero fall back to ``(0, 0, 1)``.

    Args:
        imesh: Mesh with ``v_pos`` of shape (B, V, 3) and ``t_pos_idx`` whose
            first entry (F, 3) is the connectivity shared by the whole batch.

    Returns:
        New Mesh sharing ``imesh``'s other attributes, with ``v_nrm`` set
        (same shape, dtype and device as ``v_pos``) and
        ``t_nrm_idx = imesh.t_pos_idx``.
    """
    batch_size = imesh.v_pos.shape[0]

    # Connectivity is shared across the batch, so only index set 0 is used.
    i0 = imesh.t_pos_idx[0, :, 0]  # Shape: (F)
    i1 = imesh.t_pos_idx[0, :, 1]  # Shape: (F)
    i2 = imesh.t_pos_idx[0, :, 2]  # Shape: (F)

    v0 = imesh.v_pos[:, i0, :]  # Shape: (B, F, 3)
    v1 = imesh.v_pos[:, i1, :]  # Shape: (B, F, 3)
    v2 = imesh.v_pos[:, i2, :]  # Shape: (B, F, 3)

    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)  # Shape: (B, F, 3)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(imesh.v_pos)  # Shape: (B, V, 3)
    for corner in (i0, i1, i2):
        v_nrm.scatter_add_(1, corner[None, :, None].repeat(batch_size, 1, 3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value.
    # The fallback follows v_nrm's device/dtype instead of a hard-coded
    # 'cuda'/float32, so CPU and non-float32 meshes work too.
    fallback = torch.tensor([0.0, 0.0, 1.0], dtype=v_nrm.dtype, device=v_nrm.device)
    v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, fallback)
    v_nrm = util.safe_normalize(v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(v_nrm))

    return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh)

######################################################################################
# Compute tangent space from texture map coordinates
# Follows http://www.mikktspace.com/ conventions
######################################################################################
def compute_tangents(imesh):
    """Compute per-vertex tangents from the mesh's texture coordinates.

    Per-triangle tangents are derived from position and UV edge differences,
    accumulated onto each corner's normal-indexed vertex and averaged. They
    are then normalized and made perpendicular to ``v_nrm``.

    Args:
        imesh: batched Mesh that already carries normals (``v_nrm`` and
            ``t_nrm_idx``), plus ``v_pos``/``t_pos_idx`` and
            ``v_tex``/``t_tex_idx``. Connectivity is taken from index set 0.

    Returns:
        New Mesh sharing ``imesh``'s other attributes, with ``v_tng`` set
        (same shape as ``v_nrm``) and ``t_tng_idx = imesh.t_nrm_idx``.
    """
    batch_size = imesh.v_pos.shape[0]

    # Per-corner positions and UVs, each of shape (B, F, C).
    pos = [imesh.v_pos[:, imesh.t_pos_idx[0, :, i]] for i in range(3)]
    tex = [imesh.v_tex[:, imesh.t_tex_idx[0, :, i]] for i in range(3)]
    # Normal-index column of each corner, kept 3-D via i:i+1 slicing
    # (presumably (1, F, 1) -- TODO confirm t_nrm_idx is always (1, F, 3)).
    vn_idx = [imesh.t_nrm_idx[..., i:i + 1] for i in range(3)]

    tangents = torch.zeros_like(imesh.v_nrm)
    tansum   = torch.zeros_like(imesh.v_nrm)

    # Compute tangent space for each triangle
    uve1 = tex[1] - tex[0]  # Shape: (B, F, 2)
    uve2 = tex[2] - tex[0]  # Shape: (B, F, 2)
    pe1  = pos[1] - pos[0]  # Shape: (B, F, 3)
    pe2  = pos[2] - pos[0]  # Shape: (B, F, 3)

    nom   = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]  # Shape: (B, F, 3)
    denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]  # Shape: (B, F, 1)

    # Avoid division by zero for degenerated texture coordinates: keep
    # |denom| >= 1e-6 with its sign preserved (an exact 0 maps to -1e-6).
    tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))  # Shape: (B, F, 3)

    # Update all 3 vertices
    for corner in vn_idx:
        idx = corner.repeat(batch_size, 1, 3)               # Shape: (B, F, 3)
        tangents.scatter_add_(1, idx, tang)                 # tangents[n_i] += tang
        tansum.scatter_add_(1, idx, torch.ones_like(tang))  # tansum[n_i] += 1
    # Average. Vertices referenced by no triangle have tansum == 0; clamping the
    # count to >= 1 leaves their tangent at 0 instead of producing 0/0 = NaN.
    # Referenced vertices always have tansum >= 1, so they are unaffected.
    tangents = tangents / torch.clamp(tansum, min=1.0)

    # Normalize and make sure tangent is perpendicular to normal
    tangents = util.safe_normalize(tangents)
    tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(tangents))

    return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)

######################################################################################
# Create new Mesh from verts, faces, uvs, and uv_idx; normals and tangents are computed automatically.
######################################################################################
def make_mesh(verts, faces, uvs, uv_idx, material):
    """
    Build a batched Mesh from positions, UVs and their index tensors, then
    derive vertex normals (``auto_normals``) and tangents (``compute_tangents``).

    Args:
        verts: tensor of shape (B, V, 3) with vertex positions.
        faces: tensor of shape (1, F, 3) with position indices, shared by the batch.
        uvs: tensor of shape (B, T, 2) with texture coordinates; T is the
            number of UV vertices (presumably often equal to V -- not required).
        uv_idx: tensor of shape (1, F, 3) with UV indices, shared by the batch.
        material: a Material instance, specifying the material of the mesh.

    Returns:
        new Mesh object with normals and tangents filled in.
    """
    assert all(len(t.shape) == 3 for t in (verts, faces, uvs, uv_idx)), "All components must be batched."
    assert (faces.shape[0], uv_idx.shape[0]) == (1, 1), "Every mesh must share the same edge connectivity."
    assert verts.shape[0] == uvs.shape[0], "Batch size must be consistent."

    base = Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
    return compute_tangents(auto_normals(base))