DietNerf-Demo / jaxnerf /nerf /clip_utils.py
alexlau's picture
first deploy demo
19677a1
raw
history blame
No virus
5.18 kB
import math
from typing import Optional
from absl import flags
from functools import partial
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel
FLAGS = flags.FLAGS
# import jmp
# my_policy = jmp.Policy(compute_dtype=np.float16,
# param_dtype=np.float16,
# output_dtype=np.float16)
@partial(jax.jit, static_argnums=[0, 1])
def update_semantic_loss(model, clip_model, rng, state, batch, lr):
    """Compute the CLIP semantic-consistency loss and its gradient.

    Rays in ``batch["random_rays"]`` are rendered with ``model``, the rendered
    pixels are reshaped into a square image, embedded with ``clip_model`` and
    compared against the pre-computed embedding in ``batch["embedding"]``.

    Args:
        model: object whose ``apply(variables, key_0, key_1, rays, False)`` is
            called to render; treated as a static argument by ``jax.jit``.
        clip_model: object providing ``get_image_features(pixel_values=...)``;
            treated as a static argument by ``jax.jit``.
        rng: JAX PRNG key; split to obtain the two keys given to ``model.apply``.
        state: training state. Index 0 is taken along the leading axis of every
            leaf before reading ``.optimizer.target`` (presumably the state is
            replicated across devices -- confirm against the caller).
        batch: mapping with keys ``"random_rays"`` and ``"embedding"``.
        lr: unused here; kept so existing callers keep working.

    Returns:
        Tuple ``(sc_loss, grad)``: the scaled loss value and its gradient with
        respect to the variables in ``state.optimizer.target``.
    """
    # the batch is without shard
    random_rays = batch["random_rays"]
    _, key_0, key_1 = random.split(rng, 3)

    def semantic_loss(variables):
        # TODO @Alex: (alt) sample less along a ray/ sample on a strided grid (make change on model call)
        # TODO @Alex: (alt) apply mixed precision
        src_ret = model.apply(variables, key_0, key_1, random_rays, False)
        # Only the first element of the last returned entry is used.
        src_image, _, _ = src_ret[-1]
        # reshape flat pixel to an image (assume 3 channels & square shape)
        w = int(math.sqrt(src_image.shape[0]))
        src_image = src_image.reshape([-1, w, w, 3]).transpose(0, 3, 1, 2)
        src_image = preprocess_for_CLIP(src_image)
        src_embedding = clip_model.get_image_features(pixel_values=src_image)
        # Normalise each embedding to unit length before comparing.
        src_embedding = src_embedding / jnp.linalg.norm(
            src_embedding, axis=-1, keepdims=True)
        target_embedding = batch["embedding"]
        squared_error = jnp.sum((src_embedding - target_embedding) ** 2)
        sc_loss = 0.5 * FLAGS.sc_loss_mult * squared_error / src_embedding.shape[0]
        return sc_loss * 1e-2

    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map
    # alias, which newer JAX releases no longer provide.
    first_replica = jax.device_get(
        jax.tree_util.tree_map(lambda x: x[0], state))
    sc_loss, grad = jax.value_and_grad(semantic_loss)(
        first_replica.optimizer.target)
    return sc_loss, grad
def trans_t(t):
    """Return a 4x4 float32 homogeneous matrix that translates by ``t`` along z."""
    rows = (
        (1, 0, 0, 0),
        (0, 1, 0, 0),
        (0, 0, 1, t),  # translation sits in the z row of the last column
        (0, 0, 0, 1),
    )
    return jnp.asarray(rows, dtype=jnp.float32)
def rot_phi(phi):
    """Return a 4x4 float32 homogeneous rotation about the x axis.

    Args:
        phi: rotation angle in radians (passed straight to ``jnp.cos`` /
            ``jnp.sin``).

    Returns:
        A ``(4, 4)`` float32 array.
    """
    # Use jax.numpy for both trig calls: plain NumPy functions cannot consume
    # traced values, so mixing them in breaks this under jax.jit / jax.grad.
    cos_phi = jnp.cos(phi)
    sin_phi = jnp.sin(phi)
    return jnp.array([
        [1, 0, 0, 0],
        [0, cos_phi, -sin_phi, 0],
        [0, sin_phi, cos_phi, 0],
        [0, 0, 0, 1]], dtype=jnp.float32)
def rot_theta(th):
    """Return a 4x4 float32 homogeneous rotation about the y axis.

    The sign convention is the one written in the matrix below
    (``-sin`` in the first row, ``+sin`` in the third row).

    Args:
        th: rotation angle in radians (passed straight to ``jnp.cos`` /
            ``jnp.sin``).

    Returns:
        A ``(4, 4)`` float32 array.
    """
    # Use jax.numpy for every trig call: plain NumPy functions cannot consume
    # traced values, so mixing them in breaks this under jax.jit / jax.grad.
    cos_th = jnp.cos(th)
    sin_th = jnp.sin(th)
    return jnp.array([
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1]], dtype=jnp.float32)
def pose_spherical(theta, phi, radius):
    """Build a 4x4 camera-to-world matrix from spherical coordinates.

    Args:
        theta: angle in degrees, converted to radians for ``rot_theta``.
        phi: angle in degrees, converted to radians for ``rot_phi``.
        radius: translation distance handed to ``trans_t``.

    Returns:
        The composed ``(4, 4)`` matrix.
    """
    phi_rotation = rot_phi(phi / 180. * jnp.pi)
    theta_rotation = rot_theta(theta / 180. * jnp.pi)
    # Fixed matrix that negates x and swaps the y and z axes.
    axis_swap = jnp.array([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])
    # Applied right to left: translate, rotate by phi, rotate by theta, swap.
    return axis_swap @ (theta_rotation @ (phi_rotation @ trans_t(radius)))
def random_pose(rng, bds):
    """Sample a random camera-to-world pose.

    Args:
        rng: JAX PRNG key.
        bds: indexable pair; ``bds[0]`` and ``bds[1]`` are the minimum and
            maximum radius.

    Returns:
        The ``(4, 4)`` matrix produced by ``pose_spherical`` for the sampled
        ``theta``, ``phi`` and ``radius``.
    """
    # One independent key per sampled quantity; reusing a single key for all
    # three draws would make radius, theta and phi deterministically coupled.
    _, radius_key, theta_key, phi_key = random.split(rng, 4)
    radius = random.uniform(radius_key, minval=bds[0], maxval=bds[1])
    # pose_spherical converts its angles from degrees to radians, so sample
    # the full turn / quarter turn directly in degrees.
    theta = random.uniform(theta_key, minval=0., maxval=360.)
    phi = random.uniform(phi_key, minval=0., maxval=90.)
    # Argument order must match pose_spherical(theta, phi, radius).
    return pose_spherical(theta, phi, radius)
def preprocess_for_CLIP(image):
    """
    JAX-based preprocessing for CLIP.

    image [B, 3, H, W]: batch of channel-first images
    return [B, 3, 224, 224]: images resized with bicubic interpolation and
        normalised per channel with the mean / std constants below
    """
    batch, channels, _, _ = image.shape
    # Only the two spatial axes change size; batch and channel axes are kept.
    resized = jax.image.resize(image, (batch, channels, 224, 224), 'bicubic')
    channel_mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    channel_std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    # Cast the constants to the resized image's dtype before normalising.
    centered = resized - channel_mean.astype(resized.dtype)
    return centered / channel_std.astype(resized.dtype)
# TODO @Alex: VisionModel v.s. original CLIP? (differ by a projection matrix)
def init_CLIP(dtype: str, model_name: Optional[str] = None) -> FlaxCLIPModel:
    """Load a pretrained Flax CLIP model with the requested parameter dtype.

    Args:
        dtype: ``'float16'`` or ``'float32'``.
        model_name: identifier passed to ``FlaxCLIPModel.from_pretrained``;
            ``None`` selects ``'openai/clip-vit-base-patch32'``.

    Returns:
        The model returned by ``FlaxCLIPModel.from_pretrained``.

    Raises:
        ValueError: if ``dtype`` is neither ``'float16'`` nor ``'float32'``.
    """
    if dtype == 'float16':
        jnp_dtype = jnp.float16
    elif dtype == 'float32':
        jnp_dtype = jnp.float32
    else:
        # Name the offending value so a bad flag is easy to track down.
        raise ValueError(
            f"Unsupported CLIP dtype {dtype!r}; expected 'float16' or 'float32'.")
    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=jnp_dtype)
# def SC_loss(rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l):
# """
# target_emb [1, D]: pre-computed target embedding vector \phi(I)
# source_img [1, 3, H, W]: source image \hat{I}
# l: loss weight lambda
# return: SC_loss
# """
# # _,H,W,D = rays.shape
# rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l = my_policy.cast_to_compute(
# (rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l))
# _, H, W, _ = rays.shape
# source_img = jnp.clip(render_fn(rng_inputs, model, params, None,
# np.reshape(rays, (2, -1, 3)),
# bds[0], bds[1], 1, rand=False),
# 0, 1)
# # source_img = np.clip(render_rays(rng_inputs, model, params, None, np.reshape(rays, (2, -1, 3)), bds[0], bds[1], 1, rand=False), 0, 1)
# source_img = np.reshape(source_img, [1, H, W, 3]).transpose(0, 3, 1, 2)
# source_img = preprocess_for_CLIP(source_img)
# source_emb = CLIP_model.get_image_features(pixel_values=source_img)
# source_emb /= np.linalg.norm(source_emb, axis=-1, keepdims=True)
# return l/2 * (np.sum((source_emb - target_emb) ** 2) / source_emb.shape[0])