import math
from typing import Optional
from functools import partial

from absl import flags
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel

FLAGS = flags.FLAGS

# import jmp
# my_policy = jmp.Policy(compute_dtype=np.float16,
#                        param_dtype=np.float16,
#                        output_dtype=np.float16)
def update_semantic_loss(model, clip_model, rng, state, batch, lr):
    # the batch is not sharded across devices
    random_rays = batch["random_rays"]
    # rng, key_0, key_1 = rng
    rng, key_0, key_1 = random.split(rng, 3)

    def semantic_loss(variables):
        # TODO @Alex: (alt) sample less along a ray / sample on a strided grid (make change on model call)
        # TODO @Alex: (alt) apply mixed precision
        src_ret = model.apply(variables, key_0, key_1, random_rays, False)
        src_image, _, _ = src_ret[-1]
        # reshape flat pixels to an image (assumes 3 channels and a square shape)
        w = int(math.sqrt(src_image.shape[0]))
        src_image = src_image.reshape([-1, w, w, 3]).transpose(0, 3, 1, 2)
        # embed the rendered view with CLIP and compare it to the target embedding
        src_image = preprocess_for_CLIP(src_image)
        src_embedding = clip_model.get_image_features(pixel_values=src_image)
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
        src_embedding = jnp.array(src_embedding)
        target_embedding = batch["embedding"]
        sc_loss = 0.5 * FLAGS.sc_loss_mult * jnp.sum(
            (src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
        return sc_loss * 1e-2

    # evaluate loss and gradient on the parameters of the first (host-local) replica
    sc_loss, grad = jax.value_and_grad(semantic_loss)(
        jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target)
    return sc_loss, grad
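
# A minimal usage sketch (assumption: `model`, `state`, and `batch` follow the conventions
# used above, with batch["random_rays"] holding rays for a randomly sampled pose and
# batch["embedding"] the pre-computed CLIP embedding of the target view):
#
#   clip_model = init_CLIP('float16', None)
#   sc_loss, sc_grad = update_semantic_loss(model, clip_model, rng, state, batch, lr)
#   # sc_grad can then be combined with the photometric gradient before the optimizer step.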
def trans_t(t):
    return jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)
def rot_phi(phi):
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), -jnp.sin(phi), 0],
        [0, jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)
def rot_theta(th):
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)
def pose_spherical(theta, phi, radius):
    # theta and phi are given in degrees
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * jnp.pi) @ c2w
    c2w = rot_theta(theta / 180. * jnp.pi) @ c2w
    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w
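
# A minimal usage sketch: pose_spherical expects theta/phi in degrees, e.g. a camera on a
# sphere of radius 4 looking inward from 30 degrees of elevation:
#
#   c2w = pose_spherical(theta=90., phi=-30., radius=4.)   # [4, 4] camera-to-world matrix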
def random_pose(rng, bds):
    # draw radius, theta, and phi from independent keys; pose_spherical expects degrees
    rng, *rng_inputs = jax.random.split(rng, 4)
    radius = random.uniform(rng_inputs[0], minval=bds[0], maxval=bds[1])
    theta = random.uniform(rng_inputs[1], minval=0., maxval=360.)
    phi = random.uniform(rng_inputs[2], minval=0., maxval=90.)
    return pose_spherical(theta, phi, radius)
def preprocess_for_CLIP(image):
    """
    jax-based preprocessing for CLIP

    image [B, 3, H, W]: batch of images
    return [B, 3, 224, 224]: pre-processed images for CLIP
    """
    B, D, H, W = image.shape
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')  # resize to CLIP's 224x224 input resolution
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image
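
# A minimal usage sketch (the [B, 3, H, W] input here is a dummy batch, not project data):
#
#   images = jnp.zeros((2, 3, 64, 64))        # values assumed to lie in [0, 1]
#   clip_input = preprocess_for_CLIP(images)  # [2, 3, 224, 224], CLIP-normalized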
# TODO @Alex: VisionModel vs. original CLIP? (they differ by a projection matrix)
def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
    if dtype == 'float16':
        dtype = jnp.float16
    elif dtype == 'float32':
        dtype = jnp.float32
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
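
# A minimal usage sketch, combining init_CLIP with preprocess_for_CLIP to obtain unit-norm
# image embeddings (`clip_input` is assumed to come from preprocess_for_CLIP above):
#
#   clip_model = init_CLIP('float32', None)   # defaults to openai/clip-vit-base-patch32
#   emb = clip_model.get_image_features(pixel_values=clip_input)
#   emb /= jnp.linalg.norm(emb, axis=-1, keepdims=True)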
# def SC_loss(rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l):
#     """
#     target_emb [1, D]: pre-computed target embedding vector \phi(I)
#     source_img [1, 3, H, W]: source image \hat{I}
#     l: loss weight lambda
#     return: SC_loss
#     """
#     # _, H, W, D = rays.shape
#     rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l = my_policy.cast_to_compute(
#         (rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l))
#     _, H, W, _ = rays.shape
#     source_img = jnp.clip(render_fn(rng_inputs, model, params, None,
#                                     np.reshape(rays, (2, -1, 3)),
#                                     bds[0], bds[1], 1, rand=False),
#                           0, 1)
#     # source_img = np.clip(render_rays(rng_inputs, model, params, None, np.reshape(rays, (2, -1, 3)), bds[0], bds[1], 1, rand=False), 0, 1)
#     source_img = np.reshape(source_img, [1, H, W, 3]).transpose(0, 3, 1, 2)
#     source_img = preprocess_for_CLIP(source_img)
#     source_emb = CLIP_model.get_image_features(pixel_values=source_img)
#     source_emb /= np.linalg.norm(source_emb, axis=-1, keepdims=True)
#     return l/2 * (np.sum((source_emb - target_emb) ** 2) / source_emb.shape[0])