|
import flax |
|
import jax.numpy as jnp |
|
|
|
|
|
class Image(flax.linen.Module):
    """Image branch: applies dropout to its input and nothing else.

    Attributes:
        out_units: Declared width of the branch (default 1024). It is not
            read anywhere in this module; the input is returned with its
            shape unchanged.
    """

    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        """Return ``x`` after dropout with rate 0.1.

        Args:
            x: Input array.
            training: When true, dropout is active; when false (the
                default) the layer is deterministic and passes ``x``
                through untouched.

        Returns:
            The (possibly dropped-out) input.
        """
        dropout = flax.linen.Dropout(rate=0.1)
        return dropout(x, deterministic=not training)
|
|
|
|
|
class Text(flax.linen.Module):
    """Text branch: a dense projection followed by one residual block.

    Attributes:
        out_units: Feature width of both dense layers (default 1024).
    """

    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        """Project ``x`` to ``out_units`` features and add a residual branch.

        Args:
            x: Input array; its last axis is consumed by the first dense
                layer.
            training: When true, dropout on the residual branch is active;
                when false (the default) it is deterministic.

        Returns:
            ``projected + branch``, where ``branch`` is
            SiLU -> Dense -> Dropout(0.1) applied to the projection.
        """
        # Creation order fixes Flax's auto-generated submodule names, so the
        # projection is built before the Dense inside the residual branch.
        projected = flax.linen.Dense(features=self.out_units)(x)

        activated = flax.linen.silu(projected)
        branch = flax.linen.Dense(features=self.out_units)(activated)
        branch = flax.linen.Dropout(rate=0.1)(
            branch, deterministic=not training
        )

        return projected + branch
|
|
|
|
|
class CLIP(flax.linen.Module):
    """Two-branch model that returns scaled cosine-similarity logits.

    Both encoders are created in ``setup`` so that they can be shared by
    ``__call__`` and ``encode_text``. Because nothing is defined inline,
    ``__call__`` is a plain method rather than a ``compact`` one.

    Attributes:
        out_units: Width passed to both encoders (default 1024).
        logit_scale: Constant multiplier applied to the similarity
            matrices (default 1.0).
    """

    out_units: int = 1024
    logit_scale: float = 1.0

    def setup(self):
        """Instantiate the image and text encoders."""
        self.image_enc = Image(self.out_units)
        self.text_enc = Text(self.out_units)

    def __call__(self, image, text, training=False):
        """Compute image-to-text and text-to-image similarity logits.

        Args:
            image: Input to the image encoder.
            text: Input to the text encoder.
            training: Forwarded to both encoders to toggle dropout.

        Returns:
            A tuple ``(image_sim, text_sim)``: the scaled dot products of
            the L2-normalized image embeddings against the text embeddings,
            and vice versa.
        """
        image_emb = self.image_enc(image, training=training)
        text_emb = self.text_enc(text, training=training)

        image_norm = jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
        text_norm = jnp.linalg.norm(text_emb, axis=-1, keepdims=True)

        # Clamp the denominators to the smallest positive normal value of
        # their dtype: an all-zero row then normalizes to zeros instead of
        # 0/0 = NaN, while rows with a normal-range norm are unaffected.
        image_floor = jnp.finfo(image_norm.dtype).tiny
        text_floor = jnp.finfo(text_norm.dtype).tiny
        image_emb = image_emb / jnp.maximum(image_norm, image_floor)
        text_emb = text_emb / jnp.maximum(text_norm, text_floor)

        # `*` and `@` share precedence and associate left-to-right, so the
        # scale is applied to the left operand before the matmul.
        # NOTE(review): `.T` reverses every axis; this presumably expects
        # 2-D (batch, features) embeddings -- confirm against callers.
        image_sim = self.logit_scale * image_emb @ text_emb.T
        text_sim = self.logit_scale * text_emb @ image_emb.T
        return image_sim, text_sim

    def encode_text(self, text):
        """Run the text encoder alone with dropout disabled.

        Unlike ``__call__``, the result is the encoder output as-is; it is
        not L2-normalized.

        Args:
            text: Input to the text encoder.

        Returns:
            The text encoder's output for ``text``.
        """
        return self.text_enc(text, training=False)
|
|