{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interactive demo of Cross-view Completion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
    "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from models.croco import CroCoNet\n",
    "from ipywidgets import interact, interactive, fixed, interact_manual\n",
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "import quaternion\n",
    "import models.masking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load CroCo model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n",
    "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n",
    "msg = model.load_state_dict(ckpt['model'], strict=True)\n",
    "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
    "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
    "model = model.eval()\n",
    "model = model.to(device=device)\n",
    "print(msg)\n",
    "\n",
    "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n",
    "    \"\"\"\n",
    "    Perform Cross-View completion using two input images, specified using Numpy arrays.\n",
    "    \"\"\"\n",
    "    # Replace the mask generator\n",
    "    model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n",
    "\n",
    "    # ImageNet-1k color normalization\n",
    "    imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n",
    "    imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n",
    "\n",
    "    normalize_input_colors = True\n",
    "    is_output_normalized = True\n",
    "    with torch.no_grad():\n",
    "        # Cast data to torch\n",
    "        target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
    "        ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
    "\n",
    "        if normalize_input_colors:\n",
    "            ref_image = (ref_image - imagenet_mean) / imagenet_std\n",
    "            target_image = (target_image - imagenet_mean) / imagenet_std\n",
    "\n",
    "        out, mask, _ = model(target_image, ref_image)\n",
    "        # # get target\n",
    "        if not is_output_normalized:\n",
    "            predicted_image = model.unpatchify(out)\n",
    "        else:\n",
    "            # The output only contains higher order information,\n",
    "            # we retrieve mean and standard deviation from the actual target image\n",
    "            patchified = model.patchify(target_image)\n",
    "            mean = patchified.mean(dim=-1, keepdim=True)\n",
    "            var = patchified.var(dim=-1, keepdim=True)\n",
    "            pred_renorm = out * (var + 1.e-6)**.5 + mean\n",
    "            predicted_image = model.unpatchify(pred_renorm)\n",
    "\n",
    "        image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n",
    "        masked_target_image = (1 - image_masks) * target_image\n",
    "      \n",
    "        if not reconstruct_unmasked_patches:\n",
    "            # Replace unmasked patches by their actual values\n",
    "            predicted_image = predicted_image * image_masks + masked_target_image\n",
    "\n",
    "        # Unapply color normalization\n",
    "        if normalize_input_colors:\n",
    "            predicted_image = predicted_image * imagenet_std + imagenet_mean\n",
    "            masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n",
    "        \n",
    "        # Cast to Numpy\n",
    "        masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
    "        predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
    "        return masked_target_image, predicted_image"
   ]
  },
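  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CroCo is trained with per-patch normalized targets (as in MAE), so the decoder output is de-normalized using the mean and standard deviation of each target patch, $\\hat{p} = \\mathrm{out} \\cdot \\sqrt{\\sigma_p^2 + 10^{-6}} + \\mu_p$; this is what the `is_output_normalized` branch above implements. The cell below is a minimal sanity check of `process_images` on random noise, assuming the default 224x224 input resolution of the released checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sanity check: run process_images on random noise and verify the\n",
    "# output shapes and dtypes (assumes 224x224 inputs, the default resolution\n",
    "# of the released CroCo checkpoints)\n",
    "dummy_ref = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)\n",
    "dummy_target = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)\n",
    "masked, predicted = process_images(dummy_ref, dummy_target, masking_ratio=0.9)\n",
    "print(masked.shape, masked.dtype, predicted.shape, predicted.dtype)"
   ]
  },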
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)"
   ]
  },
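  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below expects the habitat-sim test scenes under `habitat-sim-data/`. If they are missing, they can typically be fetched with habitat-sim's dataset downloader; the exact uids and directory layout may vary across habitat-sim versions, so treat the commented command below as a starting point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: download the habitat-sim test scenes if they are not present.\n",
    "# Uncomment to run; uids and paths may differ between habitat-sim versions.\n",
    "# !python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data"
   ]
  },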
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n",
    "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n",
    "import habitat_sim\n",
    "\n",
    "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n",
    "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n",
    "\n",
    "sim_cfg = habitat_sim.SimulatorConfiguration()\n",
    "if use_gpu: sim_cfg.gpu_device_id = 0\n",
    "sim_cfg.scene_id = scene\n",
    "sim_cfg.load_semantic_mesh = False\n",
    "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n",
    "rgb_sensor_spec.uuid = \"color\"\n",
    "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n",
    "rgb_sensor_spec.resolution = (224,224)\n",
    "rgb_sensor_spec.hfov = 56.56\n",
    "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n",
    "rgb_sensor_spec.orientation = [0, 0, 0]\n",
    "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n",
    "\n",
    "\n",
    "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n",
    "sim = habitat_sim.Simulator(cfg)\n",
    "if navmesh is not None:\n",
    "    sim.pathfinder.load_nav_mesh(navmesh)\n",
    "agent = sim.initialize_agent(agent_id=0)\n",
    "\n",
    "def sample_random_viewpoint():\n",
    "    \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n",
    "    nav_point = sim.pathfinder.get_random_navigable_point()\n",
    "    # Sample a random viewpoint height\n",
    "    viewpoint_height = np.random.uniform(1.0, 1.6)\n",
    "    viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n",
    "    viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n",
    "    return viewpoint_position, viewpoint_orientation\n",
    "\n",
    "def render_viewpoint(position, orientation):\n",
    "    agent_state = habitat_sim.AgentState()\n",
    "    agent_state.position = position\n",
    "    agent_state.rotation = orientation\n",
    "    agent.set_state(agent_state)\n",
    "    viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n",
    "    image = viewpoint_observations['color'][:,:,:3]\n",
    "    image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n",
    "    return image"
   ]
  },
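  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check that the scene, navmesh and sensor are set up correctly, the two helpers can render a few random views (an illustrative cell, not part of the original demo):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render a few random viewpoints to verify the simulator setup\n",
    "fig, axes = plt.subplots(1, 3, squeeze=False)\n",
    "for i in range(3):\n",
    "    position, orientation = sample_random_viewpoint()\n",
    "    axes[0, i].imshow(render_viewpoint(position, orientation))\n",
    "    axes[0, i].set_xticks([])\n",
    "    axes[0, i].set_yticks([])"
   ]
  },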
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample a random reference view"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_position, ref_orientation = sample_random_viewpoint()\n",
    "ref_image = render_viewpoint(ref_position, ref_orientation)\n",
    "plt.clf()\n",
    "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n",
    "axes[0,0].imshow(ref_image)\n",
    "for ax in axes.flatten():\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])"
   ]
  },
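  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampled pose can also be inspected directly: Habitat positions are expressed in meters in the scene frame, and the orientation is a unit quaternion from the `numpy-quaternion` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the reference pose used by the interactive demo below\n",
    "print(\"position (m):\", ref_position)\n",
    "print(\"orientation:\", ref_orientation)"
   ]
  },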
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interactive cross-view completion using CroCo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reconstruct_unmasked_patches = False\n",
    "\n",
    "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n",
    "    R = quaternion.as_rotation_matrix(ref_orientation)\n",
    "    target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n",
    "    target_orientation = (ref_orientation\n",
    "         * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n",
    "         * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n",
    "    \n",
    "    ref_image = render_viewpoint(ref_position, ref_orientation)\n",
    "    target_image = render_viewpoint(target_position, target_orientation)\n",
    "\n",
    "    masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n",
    "\n",
    "    fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n",
    "    axes[0].imshow(ref_image)\n",
    "    axes[0].set_xlabel(\"Reference\")\n",
    "    axes[1].imshow(masked_target_image)\n",
    "    axes[1].set_xlabel(\"Masked target\")\n",
    "    axes[2].imshow(predicted_image)\n",
    "    axes[2].set_xlabel(\"Reconstruction\")        \n",
    "    axes[3].imshow(target_image)\n",
    "    axes[3].set_xlabel(\"Target\")\n",
    "    for ax in axes.flatten():\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "\n",
    "interact(show_demo,\n",
    "        masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n",
    "        x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
    "        y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
    "        z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
    "        panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n",
    "        elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));"
   ]
  }
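  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the widget front-end is not available (e.g. in a static render of this notebook), `show_demo` can be called directly with fixed values instead of the sliders; the values below are arbitrary examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-interactive fallback: run one cross-view completion with fixed\n",
    "# parameters (arbitrary example values) instead of the interactive sliders\n",
    "show_demo(masking_ratio=0.9, x=0.1, y=0.0, z=0.0, panorama=5.0, elevation=0.0)"
   ]
  }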
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}