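"""CLIP-based semantic consistency utilities for JaxNeRF.

Defines the semantic loss between a rendered view's CLIP embedding and a
target embedding, training-step helpers for the multi-device (pmapped) and
single-device cases, spherical camera-pose sampling, and CLIP preprocessing
and initialization.
"""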
import math
from typing import Optional
from absl import flags
from functools import partial

import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel

from jaxnerf.nerf import utils

FLAGS = flags.FLAGS

@partial(jax.jit, static_argnums=[0])
def semantic_loss(clip_model, src_image, target_embedding):
    # Gather the sharded rendering into a flat pixel array and reshape it to
    # an image (assumes 3 channels and a square rendering).
    src_image = utils.unshard(src_image)
    w = int(math.sqrt(src_image.size // 3))
    src_image = src_image.reshape([w, w, 3])

    # Embed the rendering with CLIP and L2-normalize the embedding.
    src_embedding = clip_model.get_image_features(
        pixel_values=preprocess_for_CLIP(
            jnp.expand_dims(src_image, 0).transpose(0, 3, 1, 2)))
    src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
    # Halved squared error between the source and target embeddings.
    sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
    return sc_loss, src_image

def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
    # Multi-device step: rays are sharded across devices and rendered with the
    # pmapped render function, all in float16.
    random_rays = jax.tree_map(lambda x: utils.shard(x).astype(jnp.float16),
                               batch["random_rays"])
    target_embedding = batch["embedding"].astype(jnp.float16)
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        src_image = render_pfn(variables, key_0, key_1, random_rays)
        sc_loss, src_image = semantic_loss(clip_model, src_image, target_embedding)
        return sc_loss * FLAGS.sc_loss_mult, src_image

    # Differentiate the loss with respect to an unreplicated copy of the
    # model parameters.
    variables = jax.device_get(
        jax.tree_map(lambda x: x[0].astype(jnp.float16), state)).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
    return sc_loss, grad, src_image

@partial(jax.jit, static_argnums=[0, 1])
def semantic_step_single(model, clip_model, rng, state, batch, lr):
    # Single-device step: the batch is not sharded, so the model is applied
    # directly instead of going through a pmapped render function.
    random_rays = batch["random_rays"]
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        src_image = model.apply(variables, key_0, key_1, random_rays, False,
                                rgb_only=True)
        # Reshape the flat pixels to an image (assumes 3 channels and a
        # square rendering).
        w = int(math.sqrt(src_image.shape[0]))
        src_image = src_image.reshape([w, w, 3])
        # Embed the rendering with CLIP and L2-normalize the embedding.
        src_embedding = clip_model.get_image_features(
            pixel_values=preprocess_for_CLIP(
                jnp.expand_dims(src_image, 0).transpose(0, 3, 1, 2)))
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
        target_embedding = batch["embedding"]
        # Halved squared error between the source and target embeddings.
        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
        return sc_loss * FLAGS.sc_loss_mult, src_image

    variables = jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
    return sc_loss, grad, src_image

def trans_t(t):
    """4x4 homogeneous translation by t along the z axis."""
    return jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)

def rot_phi(phi):
    """4x4 homogeneous rotation by phi about the x axis."""
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), jnp.sin(phi), 0],
        [0, -jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def rot_theta(th):
    """4x4 homogeneous rotation by th about the y axis."""
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def pose_spherical(radius, theta, phi):
    """Camera-to-world pose at (radius, theta, phi) in spherical coordinates."""
    c2w = trans_t(radius)
    c2w = rot_phi(phi) @ c2w
    c2w = rot_theta(theta) @ c2w
    # Swap the y and z axes and flip x to match the world convention.
    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w
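
# For example (illustrative, worked by hand): pose_spherical(4.0, 0.0, 0.0)
# yields a camera at world position (0, 4, 0), since the final matrix swaps
# the y and z axes and flips x.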

def random_pose(rng, bds):
    """Samples a random camera-to-world pose with radius drawn from bds."""
    rng, *rng_inputs = jax.random.split(rng, 4)
    # Each quantity gets its own key so radius, theta, and phi are independent
    # draws rather than copies of the same uniform sample.
    radius = random.uniform(rng_inputs[0], minval=bds[0], maxval=bds[1])
    theta = random.uniform(rng_inputs[1], minval=-jnp.pi, maxval=jnp.pi)
    phi = random.uniform(rng_inputs[2], minval=0, maxval=jnp.pi / 2)
    return pose_spherical(radius, theta, phi)
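
# Minimal usage sketch (hypothetical bounds, not from the training pipeline):
#   rng = random.PRNGKey(0)
#   c2w = random_pose(rng, bds=(2.0, 6.0))  # bds = (min radius, max radius)
#   assert c2w.shape == (4, 4)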

def preprocess_for_CLIP(image):
    """JAX-based preprocessing for CLIP.

    Args:
        image: [B, 3, H, W] batch of images.
    Returns:
        [B, 3, 224, 224] batch, resized and normalized for CLIP.
    """
    B, D, H, W = image.shape
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    # Resize directly to 224x224 (assumes square inputs; non-square images
    # would be distorted rather than center-cropped).
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image
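
# For example, a square placeholder batch resizes to CLIP's input resolution:
#   x = jnp.ones((2, 3, 100, 100))
#   assert preprocess_for_CLIP(x).shape == (2, 3, 224, 224)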

def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
    """Loads a FlaxCLIPModel in the requested precision."""
    if dtype == 'float16':
        dtype = jnp.float16
    elif dtype == 'float32':
        dtype = jnp.float32
    else:
        raise ValueError(f"Unsupported dtype: {dtype!r} (expected 'float16' or 'float32')")

    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
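
# Minimal end-to-end sketch (assumes network access to download the CLIP
# weights; the image below is a placeholder, not part of the pipeline):
#   clip_model = init_CLIP('float32', None)
#   image = jnp.ones((1, 3, 64, 64))
#   emb = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(image))
#   target_embedding = emb / jnp.linalg.norm(emb, axis=-1, keepdims=True)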