import math
from typing import Optional
from absl import flags
from functools import partial

import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel

FLAGS = flags.FLAGS
# import jmp
# my_policy = jmp.Policy(compute_dtype=np.float16,
#                        param_dtype=np.float16,
#                        output_dtype=np.float16)


@partial(jax.jit, static_argnums=[0, 1])
def update_semantic_loss(model, clip_model, rng, state, batch, lr):
    # the batch is not sharded across devices; lr is currently unused
    random_rays = batch["random_rays"]
    rng, key_0, key_1 = random.split(rng, 3)

    def semantic_loss(variables):
        # TODO @Alex: (alt) sample less along a ray / sample on a strided grid (make change on model call)
        # TODO @Alex: (alt) apply mixed precision
        src_ret = model.apply(variables, key_0, key_1, random_rays, False)
        src_image, _, _ = src_ret[-1]
        # reshape flat pixels into an image (assumes 3 channels and a square render)
        w = int(math.sqrt(src_image.shape[0]))
        src_image = src_image.reshape([-1, w, w, 3]).transpose(0, 3, 1, 2)
        src_image = preprocess_for_CLIP(src_image)
        src_embedding = clip_model.get_image_features(pixel_values=src_image)
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
        src_embedding = jnp.array(src_embedding)
        target_embedding = batch["embedding"]
        # semantic consistency loss: 0.5 * lambda * ||phi(src) - phi(target)||^2, averaged over the batch
        sc_loss = 0.5 * FLAGS.sc_loss_mult * jnp.sum(
            (src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
        return sc_loss * 1e-2  # extra constant down-weighting

    # differentiate w.r.t. the first replica's parameters
    sc_loss, grad = jax.value_and_grad(semantic_loss)(
        jax.tree_map(lambda x: x[0], state).optimizer.target)
    return sc_loss, grad
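
# Hedged usage sketch: `model`, `state`, and `batch` are assumptions about the
# caller (a Flax NeRF whose apply() takes (variables, key_0, key_1, rays, randomized),
# a replicated train state, and a dict holding "random_rays" plus a precomputed
# CLIP "embedding"); none of them are defined in this file.
#
#   clip_model = init_CLIP('float16', None)
#   rng = random.PRNGKey(0)
#   sc_loss, grad = update_semantic_loss(model, clip_model, rng, state, batch, lr)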

def trans_t(t):
    return jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)


def rot_phi(phi):
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), -jnp.sin(phi), 0],
        [0, jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)


def rot_theta(th):
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)


def pose_spherical(theta, phi, radius):
    # theta (azimuth) and phi (elevation) are in degrees;
    # returns a 4x4 camera-to-world matrix on a sphere of the given radius
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * jnp.pi) @ c2w
    c2w = rot_theta(theta / 180. * jnp.pi) @ c2w
    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w
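
# Example: camera at radius 4.0, elevation -30 degrees, azimuth 45 degrees
# (both angles in degrees):
#   c2w = pose_spherical(45.0, -30.0, 4.0)  # -> (4, 4) float32 camera-to-world matrix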


def random_pose(rng, bds):
    rng, *rng_inputs = jax.random.split(rng, 4)
    # sample an independent key per component; pose_spherical expects degrees
    radius = random.uniform(rng_inputs[0], minval=bds[0], maxval=bds[1])
    theta = random.uniform(rng_inputs[1], minval=0., maxval=360.)
    phi = random.uniform(rng_inputs[2], minval=0., maxval=90.)
    return pose_spherical(theta, phi, radius)
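
# Usage sketch (`bds` is assumed to be (min_radius, max_radius) bounds):
#   c2w = random_pose(random.PRNGKey(0), (2.0, 6.0))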


def preprocess_for_CLIP(image):
    """
    jax-based preprocessing for CLIP
    image  [B, 3, H, W]: batch image
    return [B, 3, 224, 224]: pre-processed image for CLIP
    """
    B, D, H, W = image.shape
    # resize to CLIP's input resolution (non-square inputs are stretched, not cropped)
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')
    # normalize with CLIP's channel-wise mean/std
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image
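
# Example with assumed shapes: a batch of two 128x128 RGB renders in [0, 1]:
#   imgs = jnp.zeros((2, 3, 128, 128))
#   clip_input = preprocess_for_CLIP(imgs)  # -> (2, 3, 224, 224), CLIP-normalized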


# TODO @Alex: VisionModel v.s. original CLIP? (differ by a projection matrix)
def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
    if dtype == 'float16':
        dtype = jnp.float16
    elif dtype == 'float32':
        dtype = jnp.float32
    else:
        raise ValueError(f"unsupported dtype {dtype!r}; expected 'float16' or 'float32'")

    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
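
# Example:
#   clip_model = init_CLIP('float16', None)  # defaults to openai/clip-vit-base-patch32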


# def SC_loss(rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l):
#     """
#     target_emb [1, D]: pre-computed target embedding vector \phi(I)
#     source_img [1, 3, H, W]: source image \hat{I}
#     l: loss weight lambda
#     return: SC_loss
#     """
#     # _,H,W,D = rays.shape
#     rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l = my_policy.cast_to_compute(
#         (rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l))
#     _, H, W, _ = rays.shape
#     source_img = jnp.clip(render_fn(rng_inputs, model, params, None,
#                                    np.reshape(rays, (2, -1, 3)),
#                                    bds[0], bds[1], 1, rand=False),
#                          0, 1)
#     # source_img = np.clip(render_rays(rng_inputs, model, params, None, np.reshape(rays, (2, -1, 3)), bds[0], bds[1], 1, rand=False), 0, 1)
#     source_img = np.reshape(source_img, [1, H, W, 3]).transpose(0, 3, 1, 2)
#     source_img = preprocess_for_CLIP(source_img)
#     source_emb = CLIP_model.get_image_features(pixel_values=source_img)
#     source_emb /= np.linalg.norm(source_emb, axis=-1, keepdims=True)
#     return l/2 * (np.sum((source_emb - target_emb) ** 2) / source_emb.shape[0])