SmilingWolf's picture
Update CLIP-style model
74bd9c8
raw
history blame contribute delete
No virus
1.76 kB
import flax
import jax.numpy as jnp
class Image(flax.linen.Module):
    """Image branch of the CLIP-style model.

    The only transformation applied is Dropout(0.1); the input itself is the
    embedding. ``out_units`` is accepted for symmetry with ``Text`` but is not
    read by ``__call__``.
    """

    # NOTE(review): unused by __call__; kept because callers construct
    # Image(out_units) positionally.
    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        """Return ``x`` passed through Dropout(0.1); dropout is active only when ``training``."""
        drop = flax.linen.Dropout(rate=0.1)
        return drop(x, deterministic=not training)
class Text(flax.linen.Module):
    """Text branch of the CLIP-style model.

    Projects the input to ``out_units`` features with a Dense layer, then adds
    a residual branch (SiLU -> Dense -> Dropout(0.1)) computed from that
    projection.
    """

    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        """Return ``proj + dropout(dense(silu(proj)))`` where ``proj = dense(x)``."""
        # Submodules are created in the same order as before (Dense, Dense,
        # Dropout) so the auto-generated parameter names are unchanged.
        projected = flax.linen.Dense(features=self.out_units)(x)
        branch = flax.linen.Dense(features=self.out_units)(flax.linen.silu(projected))
        branch = flax.linen.Dropout(0.1)(branch, deterministic=not training)
        return projected + branch
class CLIP(flax.linen.Module):
    """CLIP-style dual encoder that scores image/text pairs.

    ``__call__`` runs both encoders, L2-normalizes each embedding along its
    last axis, and returns the pairwise dot products scaled by
    ``exp(logit_scale)`` and shifted by ``logit_bias``. Both of these are
    learned arrays of shape (1,) stored in the "params" collection.

    Attributes:
        out_units: Width passed to both the image and the text encoder.
    """

    out_units: int = 1024

    def setup(self):
        self.image_enc = Image(self.out_units)
        self.text_enc = Text(self.out_units)
        # Learned log-temperature (initialised to log(10)) and bias
        # (initialised to -10). Each init fn receives the single extra init
        # arg ``None`` and ignores it. ``.value`` keeps the current array.
        self.logit_scale = self.variable(
            "params",
            "logit_scale",
            lambda _: jnp.log(10) * jnp.ones((1,)),
            None,
        ).value
        self.logit_bias = self.variable(
            "params",
            "logit_bias",
            lambda _: -10 * jnp.ones((1,)),
            None,
        ).value

    # NOTE(review): no inline submodules are created here, so @compact looks
    # redundant next to setup(); kept to avoid any change in module behaviour.
    @flax.linen.compact
    def __call__(self, image, text, training=False):
        """Return similarity logits between image and text embeddings.

        Args:
            image: Input to the image encoder, with the feature axis last.
            text: Input to the text encoder, with the feature axis last.
            training: When True, dropout inside both encoders is active.

        Returns:
            For 2-D embeddings of shapes (N, D) and (M, D), an (N, M) array of
            ``cos_sim * exp(logit_scale) + logit_bias``. Leading batch axes,
            if present, are broadcast by matmul.
        """
        image_emb = self.image_enc(image, training=training)
        text_emb = self.text_enc(text, training=training)

        # Floor on the L2 norm so an all-zero embedding normalizes to zeros
        # instead of NaN in the forward pass.
        eps = 1e-12

        def l2_normalize(v):
            return v / jnp.maximum(jnp.linalg.norm(v, axis=-1, keepdims=True), eps)

        image_emb = l2_normalize(image_emb)
        text_emb = l2_normalize(text_emb)

        # ``.T`` reverses *all* axes, which is only a matrix transpose for
        # arrays of rank <= 2. Swap just the last two axes so batched inputs
        # also work. For 1-D, ``.T`` was the identity, so that case is kept.
        if text_emb.ndim >= 2:
            text_emb_t = jnp.swapaxes(text_emb, -1, -2)
        else:
            text_emb_t = text_emb

        logits = image_emb @ text_emb_t * jnp.exp(self.logit_scale) + self.logit_bias
        return logits

    def encode_text(self, text):
        """Embed ``text`` with the text encoder, dropout disabled.

        Unlike ``__call__``, the returned embedding is NOT L2-normalized.
        """
        text_emb = self.text_enc(text, training=False)
        return text_emb