|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utils for GIVT stage I and II trainers.""" |
|
|
|
from typing import Any |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
def unbin_depth(
    depth: jax.Array,
    *,
    min_depth: float,
    max_depth: float,
    num_bins: int,
) -> jax.Array:
    """Convert a binned depth map into a float-valued depth map.

    The bin index is taken as the argmax over the last axis; each index is
    mapped to the center of its bin and rescaled to [min_depth, max_depth].

    Args:
      depth: Depth map whose bins are encoded along the last dimension
        (e.g. one-hot); the last axis is reduced with argmax.
      min_depth: Depth value at the lower edge of the first bin.
      max_depth: Depth value at the upper edge of the last bin.
      num_bins: Number of depth bins.

    Returns:
      Float32 depth map without the last (bin) axis of the input.
    """
    bin_index = jnp.argmax(depth, axis=-1)
    # The +0.5 offset selects the center of the bin rather than its lower edge.
    bin_center = (bin_index.astype(jnp.float32) + 0.5) / num_bins
    depth_range = max_depth - min_depth
    return bin_center * depth_range + min_depth
|
|
|
|
|
def get_local_rng( |
|
seed: int | jax.Array, |
|
batch: Any, |
|
) -> jax.Array: |
|
"""Generate a per-image seed based on the image id or the image values. |
|
|
|
Args: |
|
seed: Random seed from which per-image seeds should be derived. |
|
batch: Pytree containing a batch of images (key "image") and optionally |
|
image ids (key "image/id"). |
|
|
|
Returns: |
|
Array containing per-image ids. |
|
""" |
|
fake_id = None |
|
if "image" in batch: |
|
fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32) |
|
return jax.lax.scan( |
|
lambda k, x: (jax.random.fold_in(k, x), None), |
|
jax.random.PRNGKey(seed), |
|
batch.get("image/id", fake_id), |
|
)[0] |
|
|
|
|