File size: 1,658 Bytes
7f89261 74bd9c8 7f89261 74bd9c8 7f89261 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import flax
import jax.numpy as jnp
class Image(flax.linen.Module):
    """Image branch of the CLIP model: applies 10% dropout to its input.

    No projection is performed, so the input's last dimension is returned
    unchanged. ``out_units`` is accepted as a field but is not read by
    ``__call__``. NOTE(review): since the CLIP model dot-products these
    features with text-encoder outputs, inputs presumably already have
    ``out_units`` features -- confirm with callers.
    """

    out_units: int = 1024  # not referenced by __call__

    @flax.linen.compact
    def __call__(self, x, training=False):
        # Dropout only perturbs the input when training is requested.
        drop = flax.linen.Dropout(rate=0.1)
        return drop(x, deterministic=not training)
class Text(flax.linen.Module):
    """Residual MLP text encoder producing ``out_units``-wide features.

    Computes ``h = Dense(x)`` and returns ``h + Dropout(Dense(silu(h)))``,
    where dropout (rate 0.1) is active only when ``training`` is True.
    """

    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        # First projection doubles as the skip connection.
        hidden = flax.linen.Dense(features=self.out_units)(x)
        # Residual branch: SiLU -> second projection -> dropout.
        branch = flax.linen.Dense(features=self.out_units)(flax.linen.silu(hidden))
        branch = flax.linen.Dropout(rate=0.1)(branch, deterministic=not training)
        return hidden + branch
class CLIP(flax.linen.Module):
    """Dual-encoder contrastive model built from ``Image`` and ``Text``.

    Both encoder outputs are L2-normalized along the last axis and compared
    with a dot product scaled by ``exp(logit_scale)``, a learnable parameter
    initialised to ``log(10)``.

    Attributes:
        out_units: Width passed to both the image and the text encoder.
    """

    out_units: int = 1024

    def setup(self):
        self.image_enc = Image(self.out_units)
        self.text_enc = Text(self.out_units)
        # Learnable log-temperature stored as params/logit_scale with shape
        # (1,); the init function ignores its RNG because the start value
        # is deterministic.
        self.logit_scale = self.param(
            "logit_scale",
            lambda _rng: jnp.log(10) * jnp.ones((1,)),
        )

    def __call__(self, image, text, training=False):
        """Return temperature-scaled similarity logits between both modalities.

        Args:
            image: Input forwarded to the image encoder.
            text: Input forwarded to the text encoder.
            training: Enables dropout in both encoders when True.

        Returns:
            ``(image_sim, text_sim)``: ``exp(logit_scale) * image_emb @ text_emb.T``
            and ``exp(logit_scale) * text_emb @ image_emb.T`` for the normalized
            embeddings. NOTE(review): assumes 2-D ``(batch, out_units)``
            embeddings -- confirm with callers.
        """
        eps = 1e-12

        def l2_normalize(v):
            # Clamp the squared norm (rather than the norm) so that an
            # all-zero embedding yields zeros instead of NaN in the forward
            # pass AND in the gradient (sqrt has an infinite slope at 0).
            sq_norm = jnp.sum(v * v, axis=-1, keepdims=True)
            return v / jnp.sqrt(jnp.maximum(sq_norm, eps * eps))

        image_emb = l2_normalize(self.image_enc(image, training=training))
        text_emb = l2_normalize(self.text_enc(text, training=training))

        scale = jnp.exp(self.logit_scale)
        image_sim = scale * image_emb @ text_emb.T
        text_sim = scale * text_emb @ image_emb.T
        return image_sim, text_sim

    def encode_text(self, text):
        """Embed ``text`` with the text encoder in inference mode.

        Dropout is disabled. The result is NOT L2-normalized, unlike the
        embeddings used inside ``__call__``.
        """
        return self.text_enc(text, training=False)
|