# Interactive demo of Cross-view Completion.

In [None]:
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

In [None]:
import torch
import numpy as np
from models.croco import CroCoNet
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib.pyplot as plt
import quaternion
import models.masking

### Load CroCo model

In [None]:
# Pick the compute device: the first CUDA GPU when one is visible, the CPU otherwise.
use_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
device = torch.device('cuda:0') if use_gpu else torch.device('cpu')

# Deserialize the checkpoint onto the CPU so that loading works without a GPU.
ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', map_location='cpu')

# Build the network from the constructor arguments stored in the checkpoint
# (falling back to the defaults of CroCoNet when none are stored).
model = CroCoNet(**ckpt.get('croco_kwargs', {}))

# strict=True: any missing or unexpected key in the state dict raises an error.
msg = model.load_state_dict(ckpt['model'], strict=True)
print(msg)

# Inference only: switch to evaluation mode and move the weights to the chosen device.
model = model.eval().to(device=device)

def _to_uint8_image(image_tensor):
    """Convert a (1, C, H, W) tensor with values in [0, 1] to a (H, W, C) uint8 Numpy array."""
    scaled = torch.clamp(image_tensor.squeeze(0).permute(1, 2, 0) * 255, 0, 255)
    return scaled.cpu().numpy().astype(np.uint8)


def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False,
                   *, normalize_input_colors=True, is_output_normalized=True):
    """
    Perform Cross-View completion using two input images, specified using Numpy arrays.

    The target image is partially masked and reconstructed by the global `model`
    using the reference image. As a side effect, `model.mask_generator` is replaced
    by a random mask generator using `masking_ratio`.

    Args:
        ref_image: reference view, indexed as (height, width, channel) with 3 channels
            and values on a 0-255 scale.
        target_image: view to mask and reconstruct, same layout as `ref_image`.
        masking_ratio: masking ratio forwarded to `models.masking.RandomMask`.
        reconstruct_unmasked_patches: if False (default), patches that were visible
            to the model are copied from the target image instead of being taken
            from the prediction.
        normalize_input_colors: apply the ImageNet-1k color normalization to the
            inputs, and undo it on the outputs.
        is_output_normalized: if True, the network output is treated as normalized
            per patch and is rescaled with the mean and variance of the
            corresponding target patch.

    Returns:
        A tuple `(masked_target_image, predicted_image)` of uint8 Numpy arrays
        indexed as (height, width, channel).
    """
    # Replace the mask generator so that the requested masking ratio is used
    model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)

    # ImageNet-1k color normalization
    imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1).to(device)
    imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1).to(device)

    with torch.no_grad():
        # Cast data to torch: channel-first layout, values scaled to [0, 1], batch of one
        target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2, 0, 1) / 255)[None]
        ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2, 0, 1) / 255)[None]

        if normalize_input_colors:
            ref_image = (ref_image - imagenet_mean) / imagenet_std
            target_image = (target_image - imagenet_mean) / imagenet_std

        out, mask, _ = model(target_image, ref_image)

        if is_output_normalized:
            # The output only contains higher order information,
            # we retrieve mean and standard deviation from the actual target image
            patchified = model.patchify(target_image)
            mean = patchified.mean(dim=-1, keepdim=True)
            var = patchified.var(dim=-1, keepdim=True)
            predicted_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
        else:
            predicted_image = model.unpatchify(out)

        # Expand the per-patch mask to pixel resolution. The mask describes patches of
        # the target image, so the template is built from the target, not the reference.
        image_masks = model.unpatchify(model.patchify(torch.ones_like(target_image)) * mask[:, :, None])
        masked_target_image = (1 - image_masks) * target_image

        if not reconstruct_unmasked_patches:
            # Replace unmasked patches by their actual values
            predicted_image = predicted_image * image_masks + masked_target_image

        # Unapply color normalization
        # (masked pixels are zero in normalized space, so they end up with the ImageNet mean color)
        if normalize_input_colors:
            predicted_image = predicted_image * imagenet_std + imagenet_mean
            masked_target_image = masked_target_image * imagenet_std + imagenet_mean

        # Cast to Numpy
        masked_target_image = _to_uint8_image(masked_target_image)
        predicted_image = _to_uint8_image(predicted_image)
    return masked_target_image, predicted_image

### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)

In [None]:
import os
os.environ["MAGNUM_LOG"]="quiet"
os.environ["HABITAT_SIM_LOG"]="quiet"
import habitat_sim

# Test scene shipped with the Habitat test data, and its navigation mesh.
scene = "habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb"
navmesh = "habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh"

# Simulator backend: render on the first GPU when one is available.
sim_cfg = habitat_sim.SimulatorConfiguration()
if use_gpu:
    sim_cfg.gpu_device_id = 0
sim_cfg.scene_id = scene
sim_cfg.load_semantic_mesh = False

# A single color camera attached to the agent, with no offset and no rotation
# with respect to the agent frame.
rgb_sensor_spec = habitat_sim.CameraSensorSpec()
rgb_sensor_spec.uuid = "color"  # key used to retrieve the rendered observation
rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
rgb_sensor_spec.resolution = (224, 224)
rgb_sensor_spec.hfov = 56.56
rgb_sensor_spec.position = [0.0, 0.0, 0.0]
rgb_sensor_spec.orientation = [0, 0, 0]
agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])

cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
sim = habitat_sim.Simulator(cfg)
if navmesh is not None:
    sim.pathfinder.load_nav_mesh(navmesh)
    # Fail early with a clear message: random viewpoint sampling relies on the navmesh.
    if not sim.pathfinder.is_loaded:
        raise RuntimeError(f"Could not load the navigation mesh '{navmesh}'")
agent = sim.initialize_agent(agent_id=0)

def sample_random_viewpoint():
    """
    Draw a random camera pose above a random navigable point of the navmesh.

    Returns:
        A tuple `(position, orientation)`: the navigable point raised along the
        simulator UP axis by a height drawn uniformly in [1.0, 1.6), and a quaternion
        for a rotation about the UP axis with an angle drawn uniformly in [-pi, pi).
    """
    up_axis = habitat_sim.geo.UP

    ground_point = sim.pathfinder.get_random_navigable_point()
    height_offset = np.random.uniform(1.0, 1.6)
    heading_angle = np.random.uniform(-np.pi, np.pi)

    position = ground_point + height_offset * up_axis
    orientation = quaternion.from_rotation_vector(heading_angle * up_axis)
    return position, orientation

def render_viewpoint(position, orientation):
    """
    Render the color sensor of the agent from the given pose.

    Args:
        position: agent position to render from.
        orientation: agent rotation to render with.

    Returns:
        The first three channels of the 'color' observation as a uint8 Numpy array,
        with intensities multiplied by 1.5 and clipped to [0, 255].
    """
    # Move the agent to the requested pose
    state = habitat_sim.AgentState()
    state.position, state.rotation = position, orientation
    agent.set_state(state)

    # Keep the first three channels of the observation, dropping any extra channel
    rgb = sim.get_sensor_observations(agent_ids=0)['color'][:, :, :3]

    # Brighten in floating point so that the scaling saturates instead of overflowing
    brightened = np.clip(np.asarray(rgb, dtype=float) * 1.5, 0, 255)
    return brightened.astype(np.uint8)

### Sample a random reference view

In [None]:
# Draw a new random reference pose and render the corresponding view.
ref_position, ref_orientation = sample_random_viewpoint()
ref_image = render_viewpoint(ref_position, ref_orientation)

# Display the reference view alone in figure 1, without axis ticks.
plt.clf()
fig, axes = plt.subplots(nrows=1, ncols=1, squeeze=False, num=1)
ax = axes[0, 0]
ax.imshow(ref_image)
ax.set_xticks([])
ax.set_yticks([])

### Interactive cross-view completion using CroCo

In [None]:
# Global switch read by show_demo and forwarded to process_images:
# when False, only the masked patches of the reconstruction come from the prediction.
reconstruct_unmasked_patches = False


def show_demo(masking_ratio, x, y, z, panorama, elevation):
    """
    Render a target view relative to the reference view and display its completion.

    Args:
        masking_ratio: masking ratio forwarded to `process_images`.
        x, y, z: offsets of the target position along the three columns of the
            rotation matrix of the reference orientation.
        panorama: angle in degrees of the additional rotation about `habitat_sim.geo.UP`.
        elevation: angle in degrees of the additional rotation about `habitat_sim.geo.LEFT`.

    Displays a row of four images: reference, masked target, reconstruction and target.
    """
    # Translate along the axes of the reference frame
    axis_x, axis_y, axis_z = quaternion.as_rotation_matrix(ref_orientation).T
    target_position = ref_position + x * axis_x + y * axis_y + z * axis_z

    # Compose the reference orientation with the elevation, then the panorama rotation
    tilt = quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT)
    pan = quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP)
    target_orientation = ref_orientation * tilt * pan

    reference_view = render_viewpoint(ref_position, ref_orientation)
    target_view = render_viewpoint(target_position, target_orientation)

    masked_view, completed_view = process_images(
        reference_view, target_view, masking_ratio, reconstruct_unmasked_patches)

    # One panel per image, from left to right
    panels = (
        ("Reference", reference_view),
        ("Masked target", masked_view),
        ("Reconstruction", completed_view),
        ("Target", target_view),
    )
    fig, axes = plt.subplots(1, len(panels), squeeze=True, dpi=300)
    for ax, (label, picture) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_xlabel(label)
        ax.set_xticks([])
        ax.set_yticks([])

# One slider per argument of show_demo: the masking ratio, the three position
# offsets, and the two rotation angles (in degrees).
demo_controls = dict(
    masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),
    x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),
    y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),
    z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),
    panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),
    elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),
)
# The trailing semicolon keeps the notebook from echoing the return value of interact.
interact(show_demo, **demo_controls);