alexlau committed on
Commit 4ef918b
1 Parent(s): b87426b

fix pose rendering

Files changed (1)
  jaxnerf/nerf/clip_utils.py +43 -59
jaxnerf/nerf/clip_utils.py CHANGED
@@ -9,38 +9,53 @@ import jax.numpy as jnp
 import numpy as np
 from transformers import FlaxCLIPModel
 
+from nerf import utils
+
 FLAGS = flags.FLAGS
-# import jmp
-# my_policy = jmp.Policy(compute_dtype=np.float16,
-#                        param_dtype=np.float16,
-#                        output_dtype=np.float16)
 
+@partial(jax.jit, static_argnums=[0])
+def semantic_loss(clip_model, src_image, target_embedding):
+    src_image = utils.unshard(src_image)
+    w = int(math.sqrt(src_image.size // 3))
+    src_image = src_image.reshape([w, w, 3])
+
+    src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(src_image, 0).transpose(0, 3, 1, 2)))
+    src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
+    src_embedding = jnp.array(src_embedding)
+    sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
+    return sc_loss, src_image
+
+def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
+    random_rays = jax.tree_map(lambda x: utils.shard(x).astype(jnp.float16), batch["random_rays"])
+    target_embedding = batch["embedding"].astype(jnp.float16)
+    rng, key_0, key_1 = random.split(rng, 3)
+
+    def loss_fn(variables):
+        src_image = render_pfn(variables, key_0, key_1, random_rays)
+        sc_loss, src_image = semantic_loss(clip_model, src_image, target_embedding)
+        return sc_loss * FLAGS.sc_loss_mult, src_image
+    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(jax.device_get(jax.tree_map(lambda x: x[0].astype(jnp.float16), state)).optimizer.target)
+    return sc_loss, grad, src_image
 
 @partial(jax.jit, static_argnums=[0, 1])
-def update_semantic_loss(model, clip_model, rng, state, batch, lr):
+def semantic_step_single(model, clip_model, rng, state, batch, lr):
     # the batch is without shard
     random_rays = batch["random_rays"]
-    # rng, key_0, key_1 = rng
    rng, key_0, key_1 = random.split(rng, 3)
 
     def semantic_loss(variables):
-        # TODO @Alex: (alt) sample less along a ray / sample on a strided grid (make change on model call)
-        # TODO @Alex: (alt) apply mixed precision
-        src_ret = model.apply(variables, key_0, key_1, random_rays, False)
-        src_image, _, _ = src_ret[-1]
+        src_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only=True)
         # reshape flat pixels to an image (assume 3 channels & square shape)
         w = int(math.sqrt(src_image.shape[0]))
-        src_image = src_image.reshape([-1, w, w, 3]).transpose(0, 3, 1, 2)
-        src_image = preprocess_for_CLIP(src_image)
-        src_embedding = clip_model.get_image_features(pixel_values=src_image)
+        src_image = src_image.reshape([w, w, 3])
+        src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(src_image, 0).transpose(0, 3, 1, 2)))
         src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
         src_embedding = jnp.array(src_embedding)
         target_embedding = batch["embedding"]
-        sc_loss = 0.5 * FLAGS.sc_loss_mult * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
-        return sc_loss * 1e-2
-
-    sc_loss, grad = jax.value_and_grad(semantic_loss)(jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target)
-    return sc_loss, grad
+        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
+        return sc_loss * FLAGS.sc_loss_mult, src_image
+    (sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux=True)(jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target)
+    return sc_loss, grad, src_image
 
 def trans_t(t):
     return jnp.array([
@@ -49,39 +64,34 @@ def trans_t(t):
         [0, 0, 1, t],
         [0, 0, 0, 1]], dtype=jnp.float32)
 
-
 def rot_phi(phi):
     return jnp.array([
         [1, 0, 0, 0],
-        [0, jnp.cos(phi), -np.sin(phi), 0],
-        [0, jnp.sin(phi), jnp.cos(phi), 0],
+        [0, jnp.cos(phi), jnp.sin(phi), 0],
+        [0, -jnp.sin(phi), jnp.cos(phi), 0],
         [0, 0, 0, 1]], dtype=jnp.float32)
 
-
 def rot_theta(th):
     return jnp.array([
-        [np.cos(th), 0, -np.sin(th), 0],
+        [jnp.cos(th), 0, -jnp.sin(th), 0],
         [0, 1, 0, 0],
-        [np.sin(th), 0, jnp.cos(th), 0],
+        [jnp.sin(th), 0, jnp.cos(th), 0],
         [0, 0, 0, 1]], dtype=jnp.float32)
 
-
-def pose_spherical(theta, phi, radius):
+def pose_spherical(radius, theta, phi):
     c2w = trans_t(radius)
-    c2w = rot_phi(phi / 180. * jnp.pi) @ c2w
-    c2w = rot_theta(theta / 180. * jnp.pi) @ c2w
+    c2w = rot_phi(phi) @ c2w
+    c2w = rot_theta(theta) @ c2w
     c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
     return c2w
 
-
 def random_pose(rng, bds):
     rng, *rng_inputs = jax.random.split(rng, 3)
     radius = random.uniform(rng_inputs[1], minval=bds[0], maxval=bds[1])
-    theta = random.uniform(rng_inputs[1], minval=0, maxval=2 * jnp.pi)
-    phi = random.uniform(rng_inputs[1], minval=0, maxval=np.pi / 2)
+    theta = random.uniform(rng_inputs[1], minval=-jnp.pi, maxval=jnp.pi)
+    phi = random.uniform(rng_inputs[1], minval=0, maxval=jnp.pi / 2)
     return pose_spherical(radius, theta, phi)
 
-
 def preprocess_for_CLIP(image):
     """
     jax-based preprocessing for CLIP
@@ -89,14 +99,12 @@ def preprocess_for_CLIP(image):
     return [B, 3, 224, 224]: pre-processed image for CLIP
     """
     B, D, H, W = image.shape
-    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')  # assume that images have rectangle shape.
     mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
     std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
+    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')  # assume that images have rectangle shape.
     image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
     return image
 
-
-# TODO @Alex: VisionModel v.s. original CLIP? (differ by a projection matrix)
 def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
     if dtype == 'float16':
         dtype = jnp.float16
@@ -108,27 +116,3 @@ def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
     if model_name is None:
         model_name = 'openai/clip-vit-base-patch32'
     return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
-
-
-# def SC_loss(rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l):
-#     """
-#     target_emb [1, D]: pre-computed target embedding vector \phi(I)
-#     source_img [1, 3, H, W]: source image \hat{I}
-#     l: loss weight lambda
-#     return: SC_loss
-#     """
-#     # _, H, W, D = rays.shape
-#     rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l = my_policy.cast_to_compute(
-#         (rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l))
-#     _, H, W, _ = rays.shape
-#     source_img = jnp.clip(render_fn(rng_inputs, model, params, None,
-#                                     np.reshape(rays, (2, -1, 3)),
-#                                     bds[0], bds[1], 1, rand=False),
-#                           0, 1)
-#     # source_img = np.clip(render_rays(rng_inputs, model, params, None, np.reshape(rays, (2, -1, 3)), bds[0], bds[1], 1, rand=False), 0, 1)
-#     source_img = np.reshape(source_img, [1, H, W, 3]).transpose(0, 3, 1, 2)
-#     source_img = preprocess_for_CLIP(source_img)
-#     source_emb = CLIP_model.get_image_features(pixel_values=source_img)
-#     source_emb /= np.linalg.norm(source_emb, axis=-1, keepdims=True)
-#     return l/2 * (np.sum((source_emb - target_emb) ** 2) / source_emb.shape[0])
-
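
The bug behind the commit message is visible in the second hunk: random_pose already called pose_spherical(radius, theta, phi), but the old signature was pose_spherical(theta, phi, radius), so the sampled radius was consumed as an angle and then shrunk by the degree-to-radian conversion. The patch reorders the signature, treats the sampled angles as radians directly, transposes rot_phi, and replaces the stray np.* calls in rot_theta with jnp.*. Below is a minimal sanity check of the patched pose path; the first two rows of trans_t fall outside the hunk and are assumed to be identity rows, as in the usual NeRF implementation:

```python
# Sanity check for the patched pose functions (copied from the diff above).
import jax.numpy as jnp

def trans_t(t):
    return jnp.array([
        [1, 0, 0, 0],   # first two rows not shown in the hunk;
        [0, 1, 0, 0],   # assumed identity, as in standard NeRF code
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)

def rot_phi(phi):
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), jnp.sin(phi), 0],
        [0, -jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def rot_theta(th):
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def pose_spherical(radius, theta, phi):
    c2w = trans_t(radius)
    c2w = rot_phi(phi) @ c2w
    c2w = rot_theta(theta) @ c2w
    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w

c2w = pose_spherical(4.0, jnp.pi / 6, jnp.pi / 4)
R, t = c2w[:3, :3], c2w[:3, 3]
assert jnp.allclose(R @ R.T, jnp.eye(3), atol=1e-5)      # rotation part is orthonormal
assert jnp.allclose(jnp.linalg.norm(t), 4.0, atol=1e-5)  # camera sits at the requested radius
```

With the pre-patch argument order the second assert fails, since the radius ends up inside the rotation matrices instead of the translation.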
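Both semantic_loss variants score the render with 0.5 * sum((src_embedding - target_embedding) ** 2) / B after L2-normalizing the embeddings. For unit vectors, 0.5 * ||a - b||^2 = 1 - <a, b>, so sc_loss is the mean cosine distance, scaled by FLAGS.sc_loss_mult at the call site. A self-contained numeric check, with illustrative batch size and embedding width:

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (4, 512))  # stand-in for src_embedding
b = jax.random.normal(key_b, (4, 512))  # stand-in for target_embedding
a /= jnp.linalg.norm(a, axis=-1, keepdims=True)
b /= jnp.linalg.norm(b, axis=-1, keepdims=True)

sc_loss = 0.5 * jnp.sum((a - b) ** 2) / a.shape[0]     # as computed in semantic_loss
cosine_dist = jnp.mean(1.0 - jnp.sum(a * b, axis=-1))  # mean (1 - cosine similarity)
assert jnp.allclose(sc_loss, cosine_dist, atol=1e-5)
```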
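The preprocess_for_CLIP hunk only moves the jax.image.resize call below the mean/std constants; the computation is unchanged. A quick shape check of the same pipeline, with an arbitrary dummy render size:

```python
import jax
import jax.numpy as jnp

def preprocess_for_CLIP(image):
    # as in the patched file: resize to CLIP's 224x224 input,
    # then normalize with CLIP's channel-wise mean/std
    B, D, H, W = image.shape
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image

dummy = jax.random.uniform(jax.random.PRNGKey(0), (1, 3, 100, 100))  # dummy [B, 3, H, W] render
assert preprocess_for_CLIP(dummy).shape == (1, 3, 224, 224)
```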
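For completeness, loading the model the way init_CLIP does reduces to a single transformers call; the checkpoint name below is the module's default, and the 512-wide output assumes the ViT-B/32 projection head:

```python
import jax.numpy as jnp
from transformers import FlaxCLIPModel

# Same call init_CLIP makes with dtype='float16' and model_name=None.
clip_model = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32', dtype=jnp.float16)

# get_image_features expects [B, 3, 224, 224], i.e. the output of preprocess_for_CLIP.
emb = clip_model.get_image_features(pixel_values=jnp.zeros((1, 3, 224, 224), jnp.float16))
print(emb.shape)  # (1, 512) for the ViT-B/32 checkpoint
```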