SmilingWolf's picture
Push some files I forgot + new model weights
7f89261
raw
history blame
1.49 kB
import flax
import jax.numpy as jnp
class Image(flax.linen.Module):
    """Image branch of the CLIP model.

    Applies dropout with rate 0.1 to its input while training.  With
    ``training=False`` the dropout is deterministic, so the input passes
    through unchanged.

    NOTE(review): ``out_units`` is declared but never read here; presumably
    it mirrors ``Text``'s constructor and the incoming features already have
    the shared embedding width — confirm against the caller.
    """

    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        drop = flax.linen.Dropout(rate=0.1)
        return drop(x, deterministic=not training)
class Text(flax.linen.Module):
    """Text branch of the CLIP model: a projection plus a residual MLP.

    The input is projected to ``out_units`` features.  A second Dense layer
    is applied to the SiLU of that projection, passed through dropout
    (rate 0.1, active only when ``training`` is True), and added back onto
    the projection.
    """

    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        width = self.out_units
        # The two Dense layers are auto-named Dense_0 / Dense_1 in creation
        # order; that order fixes the parameter-tree layout, so keep it.
        projected = flax.linen.Dense(features=width)(x)
        branch = flax.linen.Dense(features=width)(flax.linen.silu(projected))
        branch = flax.linen.Dropout(0.1)(branch, deterministic=not training)
        return projected + branch
def _l2_normalize(x, eps=1e-12):
    """Scale ``x`` to unit L2 norm along its last axis.

    The squared norm is clamped from below by ``eps**2`` before the square
    root.  Rows with ``||x|| >= eps`` give exactly ``x / ||x||``, the same
    result as dividing by ``jnp.linalg.norm``.  An all-zero row maps to zeros
    instead of NaN (0 / 0), and its gradient stays finite.

    NOTE(review): ``eps**2`` underflows to 0 in float16; the guard is meant
    for float32/bfloat16 activations.
    """
    sq_norm = jnp.sum(x * x, axis=-1, keepdims=True)
    return x / jnp.sqrt(jnp.maximum(sq_norm, eps * eps))


class CLIP(flax.linen.Module):
    """Two-tower contrastive model built from the ``Image`` and ``Text`` encoders.

    Attributes:
        out_units: Feature width handed to both encoders.
        logit_scale: Constant multiplier applied to the cosine similarities
            (a fixed hyperparameter here, not a learned parameter).
    """

    out_units: int = 1024
    logit_scale: float = 1.0

    def setup(self):
        self.image_enc = Image(self.out_units)
        self.text_enc = Text(self.out_units)

    # All submodules come from setup(), so __call__ does not need @compact.
    # Mixing the two styles only obscured where parameters are created, and
    # removing it leaves the parameter tree (image_enc / text_enc) unchanged.
    def __call__(self, image, text, training=False):
        """Return the scaled image/text similarity matrices.

        Args:
            image: Input batch for the ``Image`` encoder.
            text: Input batch for the ``Text`` encoder.
            training: Enables dropout in both encoders (Flax then needs a
                ``'dropout'`` RNG stream passed to ``apply``).

        Returns:
            ``(image_sim, text_sim)``.  ``image_sim[i, j]`` is ``logit_scale``
            times the cosine similarity of image ``i`` and text ``j``, and
            ``text_sim`` is its transpose.  Encoder outputs are assumed to be
            2-D ``(batch, features)``, since ``.T`` is used as a matrix
            transpose.
        """
        image_emb = _l2_normalize(self.image_enc(image, training=training))
        text_emb = _l2_normalize(self.text_enc(text, training=training))
        image_sim = self.logit_scale * image_emb @ text_emb.T
        # text_emb @ image_emb.T is the transpose of the product above; reuse
        # it instead of paying for a second matmul.
        text_sim = image_sim.T
        return image_sim, text_sim

    def encode_text(self, text):
        """Return raw (un-normalized) text embeddings with dropout disabled."""
        return self.text_enc(text, training=False)