# python3.8
"""Contains the implementation of generator described in EG3D."""

import torch
import torch.nn as nn
import numpy as np
from models.utils.official_stylegan2_model_helper import MappingNetwork
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
from models.volumegan_generator import FeatureVolume
from models.volumegan_generator import PositionEncoder


class EG3DGeneratorFV(nn.Module):
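    """Defines the feature-volume (FV) variant of the EG3D generator.

    Instead of the tri-plane representation used in the original EG3D, this
    generator adopts a VolumeGAN-style `FeatureVolume` as the reference
    representation. The feature extractor queries this volume during
    rendering, and a super-resolution network upsamples the rendered feature
    image to the final resolution.
    """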

    def __init__(
            self,
            # Input latent (Z) dimensionality.
            z_dim,
            # Conditioning label (C) dimensionality.
            c_dim,
            # Intermediate latent (W) dimensionality.
            w_dim,
            # Final output image resolution.
            img_resolution,
            # Number of output color channels.
            img_channels,
            # Number of fp16 layers in the super-resolution network.
            sr_num_fp16_res=0,
            # Arguments for MappingNetwork.
            mapping_kwargs={},
            # Arguments for rendering.
            rendering_kwargs={},
            # Arguments for SuperResolution Network.
            sr_kwargs={},
            # Configs for FeatureVolume.
            fv_cfg=dict(feat_res=32,
                        init_res=4,
                        base_channels=256,
                        output_channels=32,
                        w_dim=512),
            # Configs for position encoder.
            embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10),
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Set up mapping network.
        # Here `num_ws = 2`: one for FeatureVolume Network injection and one for
        # post_neural_renderer injection.
        num_ws = 2
        self.mapping_network = MappingNetwork(z_dim=z_dim,
                                              c_dim=c_dim,
                                              w_dim=w_dim,
                                              num_ws=num_ws,
                                              **mapping_kwargs)

        # Set up the overall renderer.
        self.renderer = Renderer()

        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')

        # Set up the reference representation generator.
        self.ref_representation_generator = FeatureVolume(**fv_cfg)

        # Set up the position encoder.
        self.position_encoder = PositionEncoder(**embed_cfg)

        # Set up the post module in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer.
        self.post_neural_renderer = None
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
        )
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            raise ValueError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected layer head.
        self.fc_head = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })

        # Set up some rendering related arguments.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
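        """Maps latent codes `z` (with conditioning `c`) to W-space codes.

        The conditioning is zeroed out when `c_gen_conditioning_zero` is set
        in `rendering_kwargs`, and is scaled by `c_scale` (default 0) before
        being fed into the mapping network.
        """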
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        c = c * self.rendering_kwargs.get('c_scale', 0)
        return self.mapping_network(z,
                                    c,
                                    truncation_psi=truncation_psi,
                                    truncation_cutoff=truncation_cutoff,
                                    update_emas=update_emas)

    def synthesis(self,
                  wp,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
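        """Renders images from intermediate latent codes `wp` and camera `c`.

        `c[:, :16]` holds the flattened 4x4 camera-to-world matrix used for
        ray casting (ignored when `random_pose` is enabled). The renderer
        first produces a feature image and a depth image at
        `neural_rendering_resolution`, and the post neural renderer (a
        super-resolution network) then upsamples the feature image to
        `img_resolution`.

        Returns a dict with keys `image` (super-resolved RGB), `image_raw`
        (low-resolution RGB), and `image_depth` (low-resolution depth).
        """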
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        feature_volume = self.ref_representation_generator(wp)

        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=self.position_encoder,
            ref_representation=feature_volume,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to stay consistent with the 'raw' neural-rendered image.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer to get the final image.
        # Here, the post neural renderer is a super-resolution network.
        rgb_image = feature_image[:, :3]
        sr_image = self.post_neural_renderer(
            rgb_image,
            feature_image,
            wp,
            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
            **{
                k: synthesis_kwargs[k]
                for k in synthesis_kwargs.keys() if k != 'noise_mode'
            })

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False):
        """Computes RGB features and density for arbitrary 3D coordinates.

        Mostly used for extracting shapes.
        """
        wp = self.mapping_network(z,
                                  c,
                                  truncation_psi=truncation_psi,
                                  truncation_cutoff=truncation_cutoff,
                                  update_emas=update_emas)
        feature_volume = self.ref_representation_generator(wp)
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=feature_volume,
            position_encoder=self.position_encoder,
            post_module=self.post_module,
            ray_dirs=directions)

        return result

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp):
        """Same as `self.sample()`, but expects intermediate latent codes `wp`
        instead of Gaussian noise `z`.
        """
        feature_volume = self.ref_representation_generator(wp)
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=feature_volume,
            position_encoder=self.position_encoder,
            post_module=self.post_module,
            ray_dirs=directions)

        return result

    def forward(self,
                z,
                c,
                c_swapped=None,      # `c_swapped` is the swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):
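        """Connects mapping, synthesis, and density sampling.

        The conditioning fed to the mapping network can be swapped via
        `c_swapped`, while rendering still uses the original `c`. When
        `sample_mixed` is set, only densities at the given `coordinates` are
        returned (used for density regularization during training);
        otherwise, full images are synthesized.
        """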

        # Render a batch of generated images.
        c_wp = c.clone()
        if c_swapped is not None:
            c_wp = c_swapped.clone()
        wp = self.mapping_network(z,
                                  c_wp,
                                  truncation_psi=truncation_psi,
                                  truncation_cutoff=truncation_cutoff,
                                  update_emas=update_emas)
        if style_mixing_prob > 0:
            cutoff = torch.empty([], dtype=torch.int64,
                                 device=wp.device).random_(1, wp.shape[1])
            cutoff = torch.where(
                torch.rand([], device=wp.device) < style_mixing_prob, cutoff,
                torch.full_like(cutoff, wp.shape[1]))
            wp[:, cutoff:] = self.mapping_network(
                torch.randn_like(z), c, update_emas=update_emas)[:, cutoff:]
        if not sample_mixed:
            gen_output = self.synthesis(
                wp,
                c,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)

            return {
                'wp': wp,
                'gen_output': gen_output,
            }

        else:
            # Only used for density regularization during training.
            assert coordinates is not None
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             wp)['sigma']

            return {
                'wp': wp,
                'sample_sigma': sample_sigma
            }


class OSGDecoder(nn.Module):
    """Defines fully-connected layer head in EG3D."""
    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64

        self.net = nn.Sequential(
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=options['decoder_lr_mul']),
            nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=options['decoder_lr_mul']))

    def forward(self, point_features, wp=None, dirs=None):
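        """Decodes per-point features into RGB features and density.

        `point_features` is expected to have shape [N, C, M, 1], where N is
        the batch size, C the feature dimension, and M the number of points.
        Returns `rgb` with shape [N, M, decoder_output_dim] and `sigma` with
        shape [N, M, 1].
        """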
        # point_features.shape: [N, C, M, 1].
        point_features = point_features.squeeze(-1)
        point_features = point_features.permute(0, 2, 1)
        x = point_features

        N, M, C = x.shape
        x = x.reshape(N * M, C)

        x = self.net(x)
        x = x.reshape(N, M, -1)

        # Use sigmoid clamping from MipNeRF to map features to (-0.001, 1.001).
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]

        return {'rgb': rgb, 'sigma': sigma}
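

# A minimal usage sketch (for illustration only). The exact set of keys
# required in `rendering_kwargs` depends on `Renderer` and is not fully
# listed in this file, so the configuration below is an assumption rather
# than a complete recipe:
#
#   G = EG3DGeneratorFV(z_dim=512, c_dim=25, w_dim=512,
#                       img_resolution=256, img_channels=3,
#                       rendering_kwargs=dict(
#                           sr_antialias=True,
#                           c_gen_conditioning_zero=False,
#                           superresolution_noise_mode='none',
#                           resolution=64))
#   z = torch.randn(1, 512)
#   c = torch.randn(1, 25)  # Flattened 4x4 cam2world (16) + intrinsics (9).
#   outputs = G(z, c)
#   image = outputs['gen_output']['image']  # Super-resolved RGB image.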