{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Interactive demo of Cross-view Completion." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n", "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from models.croco import CroCoNet\n", "from ipywidgets import interact, interactive, fixed, interact_manual\n", "import ipywidgets as widgets\n", "import matplotlib.pyplot as plt\n", "import quaternion\n", "import models.masking" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load CroCo model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n", "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n", "msg = model.load_state_dict(ckpt['model'], strict=True)\n", "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", "model = model.eval()\n", "model = model.to(device=device)\n", "print(msg)\n", "\n", "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n", " \"\"\"\n", " Perform Cross-View completion using two input images, specified using Numpy arrays.\n", " \"\"\"\n", " # Replace the mask generator\n", " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n", "\n", " # ImageNet-1k color normalization\n", " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n", " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n", "\n", " normalize_input_colors = True\n", " is_output_normalized = True\n", " with torch.no_grad():\n", " # Cast data to torch\n", " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", "\n", " if normalize_input_colors:\n", " ref_image = (ref_image - imagenet_mean) / imagenet_std\n", " target_image = (target_image - imagenet_mean) / imagenet_std\n", "\n", " out, mask, _ = model(target_image, ref_image)\n", " # # get target\n", " if not is_output_normalized:\n", " predicted_image = model.unpatchify(out)\n", " else:\n", " # The output only contains higher order information,\n", " # we retrieve mean and standard deviation from the actual target image\n", " patchified = model.patchify(target_image)\n", " mean = patchified.mean(dim=-1, keepdim=True)\n", " var = patchified.var(dim=-1, keepdim=True)\n", " pred_renorm = out * (var + 1.e-6)**.5 + mean\n", " predicted_image = model.unpatchify(pred_renorm)\n", "\n", " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n", " masked_target_image = (1 - image_masks) * target_image\n", " \n", " if not reconstruct_unmasked_patches:\n", " # Replace unmasked patches by their actual values\n", " predicted_image = predicted_image * image_masks + masked_target_image\n", "\n", " # Unapply color normalization\n", " if normalize_input_colors:\n", " predicted_image = predicted_image * imagenet_std + imagenet_mean\n", " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n", " \n", " # Cast to Numpy\n", " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", " return masked_target_image, predicted_image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n", "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n", "import habitat_sim\n", "\n", "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n", "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n", "\n", "sim_cfg = habitat_sim.SimulatorConfiguration()\n", "if use_gpu: sim_cfg.gpu_device_id = 0\n", "sim_cfg.scene_id = scene\n", "sim_cfg.load_semantic_mesh = False\n", "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n", "rgb_sensor_spec.uuid = \"color\"\n", "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n", "rgb_sensor_spec.resolution = (224,224)\n", "rgb_sensor_spec.hfov = 56.56\n", "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n", "rgb_sensor_spec.orientation = [0, 0, 0]\n", "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n", "\n", "\n", "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n", "sim = habitat_sim.Simulator(cfg)\n", "if navmesh is not None:\n", " sim.pathfinder.load_nav_mesh(navmesh)\n", "agent = sim.initialize_agent(agent_id=0)\n", "\n", "def sample_random_viewpoint():\n", " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n", " nav_point = sim.pathfinder.get_random_navigable_point()\n", " # Sample a random viewpoint height\n", " viewpoint_height = np.random.uniform(1.0, 1.6)\n", " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n", " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n", " return viewpoint_position, viewpoint_orientation\n", "\n", "def render_viewpoint(position, orientation):\n", " agent_state = habitat_sim.AgentState()\n", " agent_state.position = position\n", " agent_state.rotation = orientation\n", " agent.set_state(agent_state)\n", " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n", " image = viewpoint_observations['color'][:,:,:3]\n", " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n", " return image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sample a random reference view" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ref_position, ref_orientation = sample_random_viewpoint()\n", "ref_image = render_viewpoint(ref_position, ref_orientation)\n", "plt.clf()\n", "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n", "axes[0,0].imshow(ref_image)\n", "for ax in axes.flatten():\n", " ax.set_xticks([])\n", " ax.set_yticks([])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive cross-view completion using CroCo" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "reconstruct_unmasked_patches = False\n", "\n", "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n", " R = quaternion.as_rotation_matrix(ref_orientation)\n", " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n", " target_orientation = (ref_orientation\n", " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n", " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n", " \n", " ref_image = render_viewpoint(ref_position, ref_orientation)\n", " target_image = render_viewpoint(target_position, target_orientation)\n", "\n", " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n", "\n", " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n", " axes[0].imshow(ref_image)\n", " axes[0].set_xlabel(\"Reference\")\n", " axes[1].imshow(masked_target_image)\n", " axes[1].set_xlabel(\"Masked target\")\n", " axes[2].imshow(predicted_image)\n", " axes[2].set_xlabel(\"Reconstruction\") \n", " axes[3].imshow(target_image)\n", " axes[3].set_xlabel(\"Target\")\n", " for ax in axes.flatten():\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", "\n", "interact(show_demo,\n", " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n", " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n", " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67" } } }, "nbformat": 4, "nbformat_minor": 2 }