import math
import trimesh
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import raymarching
from .utils import custom_meshgrid, get_audio_features, euler_angles_to_matrix, convert_poses

def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
    # return: [B, n_samples], new_z_vals

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)
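    # e.g. if u falls in [cdf[k], cdf[k+1]), then below = k and above = k + 1,
    # and the final sample is a linear interpolation inside bin k.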

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
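

# Hedged usage sketch (not part of the original pipeline): sample_pdf implements
# NeRF's hierarchical sampling. Given coarse depths and their compositing
# weights, it draws fine samples concentrated where the weights are large.
def _example_hierarchical_sampling(z_vals, weights, n_importance=64, training=True):
    # z_vals: [B, T] coarse sample depths, weights: [B, T] compositing weights
    z_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) # [B, T - 1] bin midpoints
    new_z = sample_pdf(z_mid, weights[..., 1:-1], n_importance, det=not training)
    return new_z.detach() # stop gradients from flowing through the resampling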


def plot_pointcloud(pc, color=None):
    # pc: [N, 3]
    # color: [N, 3/4]
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
    pc = trimesh.PointCloud(pc, color)
    # axis
    axes = trimesh.creation.axis(axis_length=4)
    # sphere
    sphere = trimesh.creation.icosphere(radius=1)
    trimesh.Scene([pc, axes, sphere]).show()
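
# Debug-only usage sketch: e.g. plot_pointcloud(xyzs.detach().cpu().numpy())
# to check that marched sample points fall inside the expected bound (the
# radius-1 icosphere and the axes are drawn only for scale reference).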


class NeRFRenderer(nn.Module):
    def __init__(self, opt):

        super().__init__()

        self.opt = opt
        self.bound = opt.bound
        self.cascade = 1 + math.ceil(math.log2(opt.bound))
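        # e.g. bound = 1 -> 1 cascade, bound = 2 -> 2 cascades, bound = 4 -> 3 cascades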
        self.grid_size = 128
        self.density_scale = 1

        self.min_near = opt.min_near
        self.density_thresh = opt.density_thresh
        self.density_thresh_torso = opt.density_thresh_torso

        self.exp_eye = opt.exp_eye
        self.test_train = opt.test_train
        self.smooth_lips = opt.smooth_lips

        self.torso = opt.torso
        self.cuda_ray = opt.cuda_ray

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-opt.bound, -opt.bound/2, -opt.bound, opt.bound, opt.bound/2, opt.bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # individual codes
        self.individual_num = opt.ind_num

        self.individual_dim = opt.ind_dim
        if self.individual_dim > 0:
            self.individual_codes = nn.Parameter(torch.randn(self.individual_num, self.individual_dim) * 0.1) 
        
        if self.torso:
            self.individual_dim_torso = opt.ind_dim_torso
            if self.individual_dim_torso > 0:
                self.individual_codes_torso = nn.Parameter(torch.randn(self.individual_num, self.individual_dim_torso) * 0.1) 

        # optimize camera pose
        self.train_camera = self.opt.train_camera
        if self.train_camera:
            self.camera_dR = nn.Parameter(torch.zeros(self.individual_num, 3)) # euler angle
            self.camera_dT = nn.Parameter(torch.zeros(self.individual_num, 3)) # xyz offset

        # extra state for cuda raymarching
    
        # 3D head density grid
        density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
        density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
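        # 1 bit per cell: each cascade costs 128**3 / 8 = 262144 bytes (256 KiB) in the bitfield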
        self.register_buffer('density_grid', density_grid)
        self.register_buffer('density_bitfield', density_bitfield)
        self.mean_density = 0
        self.iter_density = 0

        # 2D torso density grid
        if self.torso:
            density_grid_torso = torch.zeros([self.grid_size ** 2]) # [H * H]
            self.register_buffer('density_grid_torso', density_grid_torso)
        self.mean_density_torso = 0

        # step counter
        step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
        self.register_buffer('step_counter', step_counter)
        self.mean_count = 0
        self.local_step = 0
        
        # decay for enc_a
        if self.smooth_lips:
            self.enc_a = None
    
    def forward(self, x, d):
        raise NotImplementedError()

    # separated density and color query (can accelerate non-cuda-ray mode.)
    def density(self, x):
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        raise NotImplementedError()

    def reset_extra_state(self):
        if not self.cuda_ray:
            return 
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0


    def run_cuda(self, rays_o, rays_d, auds, bg_coords, poses, eye=None, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # auds: [B, 16]
        # index: [B]
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        bg_coords = bg_coords.contiguous().view(-1, 2)

        # only add camera offset at training!
        if self.train_camera and (self.training or self.test_train):
            dT = self.camera_dT[index] # [1, 3]
            dR = euler_angles_to_matrix(self.camera_dR[index] / 180 * np.pi + 1e-8).squeeze(0) # [1, 3] --> [3, 3]
            
            rays_o = rays_o + dT
            rays_d = rays_d @ dR

        N = rays_o.shape[0] # N = B * N, in fact
        device = rays_o.device

        results = {}

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
        nears = nears.detach()
        fars = fars.detach()

        # encode audio
        enc_a = self.encode_audio(auds) # [1, 64]

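        # optionally smooth the audio encoding over time with an EMA
        # (keep 35% of the previous frame's encoding) to reduce lip jitter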
        if enc_a is not None and self.smooth_lips:
            if self.enc_a is not None:
                _lambda = 0.35
                enc_a = _lambda * self.enc_a + (1 - _lambda) * enc_a
            self.enc_a = enc_a

        
        if self.individual_dim > 0:
            if self.training:
                ind_code = self.individual_codes[index]
            else:
                # use a fixed ind code for the unknown test data.
                ind_code = self.individual_codes[0]
        else:
            ind_code = None

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_() # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
            sigmas, rgbs, amb_aud, amb_eye, uncertainty = self(xyzs, dirs, enc_a, ind_code, eye)
            sigmas = self.density_scale * sigmas

            weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_triplane(sigmas, rgbs, amb_aud.abs().sum(-1), amb_eye.abs().sum(-1), uncertainty, deltas, rays)

            # for training only
            results['weights_sum'] = weights_sum
            results['ambient_aud'] = amb_aud_sum
            results['ambient_eye'] = amb_eye_sum
            results['uncertainty'] = uncertainty_sum

            results['rays'] = xyzs, dirs, enc_a, ind_code, eye

        else:
           
            dtype = torch.float32
            
            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)
            amb_aud_sum = torch.zeros(N, dtype=dtype, device=device)
            amb_eye_sum = torch.zeros(N, dtype=dtype, device=device)
            uncertainty_sum = torch.zeros(N, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
            rays_t = nears.clone() # [N]
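            # rays_t tracks each alive ray's current marching distance along the ray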

            step = 0
            
            while step < max_steps:

                # count alive rays 
                n_alive = rays_alive.shape[0]
                
                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
                n_step = max(min(N // n_alive, 8), 1)
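                # e.g. with N = 4096 rays and 512 still alive, each alive ray
                # marches min(4096 // 512, 8) = 8 steps before the next compaction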

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs, ambients_aud, ambients_eye, uncertainties = self(xyzs, dirs, enc_a, ind_code, eye)
                sigmas = self.density_scale * sigmas

                raymarching.composite_rays_triplane(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients_aud, ambients_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh)

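                # drop rays the compositor terminated (their ids are overwritten with -1)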
                rays_alive = rays_alive[rays_alive >= 0]

                step += n_step
            
        torso_results = self.run_torso(rays_o, bg_coords, poses, index, bg_color)
        bg_color = torso_results['bg_color']

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)
        image = image.clamp(0, 1)

        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        depth = depth.view(*prefix)

        amb_aud_sum = amb_aud_sum.view(*prefix)
        amb_eye_sum = amb_eye_sum.view(*prefix)

        results['depth'] = depth
        results['image'] = image # head image if training, else fully composited image
        results['ambient_aud'] = amb_aud_sum
        results['ambient_eye'] = amb_eye_sum
        results['uncertainty'] = uncertainty_sum

        return results
    

    def run_torso(self, rays_o, bg_coords, poses, index=0, bg_color=None, **kwargs):
        # rays_o: [B, N, 3], assumes B == 1
        # bg_coords: [1, N, 2]
        # index: [B]
        # return: bg_color: [N, 3], background with the torso composited when enabled

        rays_o = rays_o.contiguous().view(-1, 3)
        bg_coords = bg_coords.contiguous().view(-1, 2)

        N = rays_o.shape[0] # N = B * N, in fact
        device = rays_o.device

        results = {}

        # background
        if bg_color is None:
            bg_color = 1

        # composite the torso over the background first (the head is composited later by the caller)
        if self.torso:
            # torso ind code
            if self.individual_dim_torso > 0:
                if self.training:
                    ind_code_torso = self.individual_codes_torso[index]
                else:
                    # use a fixed ind code for the unknown test data.
                    ind_code_torso = self.individual_codes_torso[0]
            else:
                ind_code_torso = None
            
            # 2D density grid for acceleration...
            density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
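            # bg_coords lie in [-1, 1], so grid_sample bilinearly reads the 2D occupancy grid at each pixel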
            occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1)
            mask = occupancy > density_thresh_torso

            # masked query of torso
            torso_alpha = torch.zeros([N, 1], device=device)
            torso_color = torch.zeros([N, 3], device=device)

            if mask.any():
                torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, ind_code_torso)

                torso_alpha[mask] = torso_alpha_mask.float()
                torso_color[mask] = torso_color_mask.float()

                results['deform'] = deform
            
            # alpha-composite the torso over the background
            bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha)

            results['torso_alpha'] = torso_alpha
            results['torso_color'] = bg_color # NOTE: already composited over the background

        results['bg_color'] = bg_color
        
        return results


    @torch.no_grad()
    def mark_untrained_grid(self, poses, intrinsic, S=64):
        # poses: [B, 4, 4]
        # intrinsic: [4], (fx, fy, cx, cy)

        if not self.cuda_ray:
            return
        
        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        B = poses.shape[0]
        
        fx, fy, cx, cy = intrinsic
        
        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        count = torch.zeros_like(self.density_grid)
        poses = poses.to(count.device)

        # 5-level loop, forgive me...

        for xs in X:
            for ys in Y:
                for zs in Z:
                    
                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long() # [N]
                    world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)

                        # split batch to avoid OOM
                        head = 0
                        while head < B:
                            tail = min(head + S, B)

                            # world-to-cam transform: poses is cam-to-world, so we need R^T;
                            # with row-vector points, the batched matmul below applies that transpose implicitly.
                            cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
                            
                            # query if point is covered by any camera
                            mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
                            mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]

                            # update count 
                            count[cas, indices] += mask
                            head += S
    
        # mark untrained grid as -1
        self.density_grid[count == 0] = -1


    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        # call before each epoch to update extra states.

        if not self.cuda_ray:
            return 
        
        # use random auds (different expressions should have similar density grid...)
        rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
        auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)

        # encode audio
        enc_a = self.encode_audio(auds)

        ### update density grid
        if not self.torso: # skip updating the head grid while training the torso

            tmp_grid = torch.zeros_like(self.density_grid)

            # use a random eye area based on training dataset's statistics...
            if self.exp_eye:
                eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
            else:
                eye = None
            
            # full update
            X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

            for xs in X:
                for ys in Y:
                    for zs in Z:
                        
                        # construct points
                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
                        coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
                        indices = raymarching.morton3D(coords).long() # [N]
                        xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]

                        # cascading
                        for cas in range(self.cascade):
                            bound = min(2 ** cas, self.bound)
                            half_grid_size = bound / self.grid_size
                            # scale to current cascade's resolution
                            cas_xyzs = xyzs * (bound - half_grid_size)
                            # add noise in [-hgs, hgs]
                            cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                            # query density
                            sigmas = self.density(cas_xyzs, enc_a, eye)['sigma'].reshape(-1).detach().to(tmp_grid.dtype)
                            sigmas *= self.density_scale
                            # assign 
                            tmp_grid[cas, indices] = sigmas
            
            # dilate the density_grid (less aggressive culling)
            tmp_grid = raymarching.morton3D_dilation(tmp_grid)

            # ema update
            valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
            self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
            self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # untrained regions (marked -1) count as zero density
            self.iter_density += 1

            # convert to bitfield
            density_thresh = min(self.mean_density, self.density_thresh)
            self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update torso density grid
        if self.torso:
            tmp_grid_torso = torch.zeros_like(self.density_grid_torso)

            # random pose, random ind_code
            rand_idx = random.randint(0, self.poses.shape[0] - 1)
            # pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device)
            pose = self.poses[[rand_idx]].to(self.density_bitfield.device)

            if self.opt.ind_dim_torso > 0:
                ind_code = self.individual_codes_torso[[rand_idx]]
            else:
                ind_code = None

            X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

            half_grid_size = 1 / self.grid_size

            for xs in X:
                for ys in Y:
                    xx, yy = custom_meshgrid(xs, ys)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128)
                    indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed!
                    xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1]
                    xys = xys * (1 - half_grid_size)
                    # add noise in [-hgs, hgs]
                    xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size
                    # query density
                    alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1]
                    
                    # assign 
                    tmp_grid_torso[indices] = alphas.squeeze(1).float()

            # dilate
            tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size)
            tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2)
            tmp_grid_torso = tmp_grid_torso.view(-1)
            
            self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso)
            self.mean_density_torso = torch.mean(self.density_grid_torso).item()


        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0


    @torch.no_grad()
    def get_audio_grid(self, S=128):
        # build and return a grid of audio-ambient norms (analogous to update_extra_state, but without modifying state).

        if not self.cuda_ray:
            return 
        
        # use random auds (different expressions should have similar density grid...)
        rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
        auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)

        # encode audio
        enc_a = self.encode_audio(auds)
        tmp_grid = torch.zeros_like(self.density_grid)

        # use a random eye area based on training dataset's statistics...
        if self.exp_eye:
            eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
        else:
            eye = None
        
        # full update
        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:
                    
                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long() # [N]
                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_xyzs = xyzs * (bound - half_grid_size)
                        # add noise in [-hgs, hgs]
                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                        # query the audio-ambient norm
                        aud_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_aud'].reshape(-1).detach().to(tmp_grid.dtype)
                        # assign 
                        tmp_grid[cas, indices] = aud_norms
        
        # dilate the density_grid (less aggressive culling)
        tmp_grid = raymarching.morton3D_dilation(tmp_grid)
        return tmp_grid


    @torch.no_grad()
    def get_eye_grid(self, S=128):
        # build and return a grid of eye-ambient norms (analogous to update_extra_state, but without modifying state).

        if not self.cuda_ray:
            return 
        
        # use random auds (different expressions should have similar density grid...)
        rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
        auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)

        # encode audio
        enc_a = self.encode_audio(auds)
        tmp_grid = torch.zeros_like(self.density_grid)

        # use a random eye area based on training dataset's statistics...
        if self.exp_eye:
            eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
        else:
            eye = None
        
        # full update
        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:
                    
                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long() # [N]
                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_xyzs = xyzs * (bound - half_grid_size)
                        # add noise in [-hgs, hgs]
                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                        # query the eye-ambient norm
                        eye_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_eye'].reshape(-1).detach().to(tmp_grid.dtype)
                        # assign 
                        tmp_grid[cas, indices] = eye_norms
        
        # dilate the density_grid (less aggressive culling)
        tmp_grid = raymarching.morton3D_dilation(tmp_grid)
        return tmp_grid



    def render(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # auds: [B, 29, 16]
        # eye: [B, 1]
        # bg_coords: [1, N, 2]
        # return: pred_rgb: [B, N, 3]

        _run = self.run_cuda
        
        B, N = rays_o.shape[:2]
        device = rays_o.device

        # staged rendering is only meaningful without cuda_ray, and that path is not implemented
        if staged and not self.cuda_ray:
            raise NotImplementedError

        else:
            results = _run(rays_o, rays_d, auds, bg_coords, poses, **kwargs)

        return results
    
    
    def render_torso(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # auds: [B, 29, 16]
        # eye: [B, 1]
        # bg_coords: [1, N, 2]
        # return: pred_rgb: [B, N, 3]

        _run = self.run_torso
        
        B, N = rays_o.shape[:2]
        device = rays_o.device

        # staged rendering is only meaningful without cuda_ray, and that path is not implemented
        if staged and not self.cuda_ray:
            raise NotImplementedError

        else:
            results = _run(rays_o, bg_coords, poses, **kwargs)

        return results
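

# Hedged usage sketch (illustrative, not part of this file): a concrete network
# subclass implements forward / density / color on top of NeRFRenderer, then:
#
#   model = NeRFNetwork(opt)              # subclass of NeRFRenderer
#   out = model.render(rays_o, rays_d, auds, bg_coords, poses, index=index)
#   pred_rgb, pred_depth = out['image'], out['depth'] # [B, N, 3], [B, N]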