|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utils for GIVT stage I and II trainers."""
|
|
|
|
from typing import Any
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
|
|
def unbin_depth(
    depth: jax.Array,
    *,
    min_depth: float,
    max_depth: float,
    num_bins: int,
) -> jax.Array:
  """Decode a one-hot binned depth map into continuous depth values.

  The interval [min_depth, max_depth] is treated as `num_bins` equal-width
  bins. Each position is assigned the centre of the bin selected by the
  argmax over the last axis of `depth`.

  Args:
    depth: Array whose last axis holds a one-hot encoding of the depth bin
      for each position.
    min_depth: Lower end of the binned depth range.
    max_depth: Upper end of the binned depth range.
    num_bins: Number of bins the depth range is divided into.

  Returns:
    Float32 depth array with the last axis of `depth` reduced away.
  """
  # Adding 0.5 to the integer bin index moves from the bin's left edge to its
  # centre.
  bin_centre = jnp.argmax(depth, axis=-1).astype(jnp.float32) + 0.5
  relative_position = bin_centre / num_bins
  return relative_position * (max_depth - min_depth) + min_depth
|
|
|
|
|
|
def get_local_rng(
|
|
seed: int | jax.Array,
|
|
batch: Any,
|
|
) -> jax.Array:
|
|
"""Generate a per-image seed based on the image id or the image values.
|
|
|
|
Args:
|
|
seed: Random seed from which per-image seeds should be derived.
|
|
batch: Pytree containing a batch of images (key "image") and optionally
|
|
image ids (key "image/id").
|
|
|
|
Returns:
|
|
Array containing per-image ids.
|
|
"""
|
|
fake_id = None
|
|
if "image" in batch:
|
|
fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32)
|
|
return jax.lax.scan(
|
|
lambda k, x: (jax.random.fold_in(k, x), None),
|
|
jax.random.PRNGKey(seed),
|
|
batch.get("image/id", fake_id),
|
|
)[0]
|
|
|
|
|