SmilingWolf committed on
Commit 74bd9c8
1 Parent(s): f216079

Update CLIP-style model


Moreover:
- add .gitignore
- clean up .gitattributes
- add SigLIP-like model code
- clean up the Artoria example a bit

.gitattributes CHANGED
@@ -34,8 +34,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.index filter=lfs diff=lfs merge=lfs -text
-
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
.gitignore ADDED
@@ -0,0 +1,4 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
Models/CLIP.py CHANGED
@@ -28,12 +28,18 @@ class Text(flax.linen.Module):
 
 class CLIP(flax.linen.Module):
     out_units: int = 1024
-    logit_scale: float = 1.0
 
     def setup(self):
         self.image_enc = Image(self.out_units)
         self.text_enc = Text(self.out_units)
 
+        self.logit_scale = self.variable(
+            "params",
+            "logit_scale",
+            lambda x: jnp.log(10) * jnp.ones((1,)),
+            None,
+        ).value
+
     @flax.linen.compact
     def __call__(self, image, text, training=False):
         image_emb = self.image_enc(image, training=training)
@@ -43,8 +49,8 @@ class CLIP(flax.linen.Module):
         image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
         text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)
 
-        image_sim = self.logit_scale * image_emb @ text_emb.T
-        text_sim = self.logit_scale * text_emb @ image_emb.T
+        image_sim = jnp.exp(self.logit_scale) * image_emb @ text_emb.T
+        text_sim = jnp.exp(self.logit_scale) * text_emb @ image_emb.T
         return image_sim, text_sim
 
     def encode_text(self, text):
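Note on the change above: the fixed `logit_scale: float = 1.0` attribute is replaced by a learnable temperature stored in log space (initialized to log(10)) and exponentiated when the similarity matrices are built, as in the original CLIP recipe. The sketch below shows how the two returned matrices could feed the usual symmetric cross-entropy objective; it is not part of this commit, the names and the assumption that the encoders take precomputed feature batches are mine, and the import path is inferred from the file location.

# Sketch only (not in this commit): symmetric contrastive (InfoNCE) loss over
# the matrices returned by the updated CLIP head. `image_feats` / `text_feats`
# are assumed to be precomputed per-sample feature batches of equal length.
import jax
import jax.numpy as jnp

from Models.CLIP import CLIP  # path assumed from Models/CLIP.py

model = CLIP(out_units=1024)


def clip_loss(params, image_feats, text_feats):
    # image_sim[i, j] = exp(logit_scale) * <image_i, text_j>; text_sim is its transpose.
    image_sim, text_sim = model.apply({"params": params}, image_feats, text_feats)
    labels = jnp.arange(image_feats.shape[0])  # matching pairs sit on the diagonal
    log_p_img = jax.nn.log_softmax(image_sim, axis=-1)
    log_p_txt = jax.nn.log_softmax(text_sim, axis=-1)
    loss_img = -jnp.take_along_axis(log_p_img, labels[:, None], axis=-1).mean()
    loss_txt = -jnp.take_along_axis(log_p_txt, labels[:, None], axis=-1).mean()
    return 0.5 * (loss_img + loss_txt)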
Models/SigLIP.py ADDED
@@ -0,0 +1,63 @@
+import flax
+import jax.numpy as jnp
+
+
+class Image(flax.linen.Module):
+    out_units: int = 1024
+
+    @flax.linen.compact
+    def __call__(self, x, training=False):
+        x = flax.linen.Dropout(0.1)(x, deterministic=not training)
+        return x
+
+
+class Text(flax.linen.Module):
+    out_units: int = 1024
+
+    @flax.linen.compact
+    def __call__(self, x, training=False):
+        x = flax.linen.Dense(features=self.out_units)(x)
+
+        res = flax.linen.silu(x)
+        res = flax.linen.Dense(features=self.out_units)(res)
+        res = flax.linen.Dropout(0.1)(res, deterministic=not training)
+
+        x = x + res
+        return x
+
+
+class CLIP(flax.linen.Module):
+    out_units: int = 1024
+
+    def setup(self):
+        self.image_enc = Image(self.out_units)
+        self.text_enc = Text(self.out_units)
+
+        self.logit_scale = self.variable(
+            "params",
+            "logit_scale",
+            lambda x: jnp.log(10) * jnp.ones((1,)),
+            None,
+        ).value
+        self.logit_bias = self.variable(
+            "params",
+            "logit_bias",
+            lambda x: -10 * jnp.ones((1,)),
+            None,
+        ).value
+
+    @flax.linen.compact
+    def __call__(self, image, text, training=False):
+        image_emb = self.image_enc(image, training=training)
+        text_emb = self.text_enc(text, training=training)
+
+        # Normalize
+        image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
+        text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)
+
+        logits = image_emb @ text_emb.T * jnp.exp(self.logit_scale) + self.logit_bias
+        return logits
+
+    def encode_text(self, text):
+        text_emb = self.text_enc(text, training=False)
+        return text_emb
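The new SigLIP-like head keeps the same encoders but returns a single logits matrix: scaled cosine similarities plus a learnable bias, with logit_scale initialized to log(10) and logit_bias to -10. It is intended for a pairwise sigmoid objective rather than the softmax contrastive loss; the sketch below follows the SigLIP formulation, is not part of this commit, and its names and import path are assumptions.

# Sketch only (not in this commit): pairwise sigmoid loss in the SigLIP style,
# applied to the logits matrix returned by the head in Models/SigLIP.py.
import jax
import jax.numpy as jnp

from Models.SigLIP import CLIP as SigLIP  # path assumed from Models/SigLIP.py

model = SigLIP(out_units=1024)


def siglip_loss(params, image_feats, text_feats):
    logits = model.apply({"params": params}, image_feats, text_feats)
    n = logits.shape[0]
    targets = 2.0 * jnp.eye(n) - 1.0          # +1 for matching pairs, -1 otherwise
    pairwise = -jax.nn.log_sigmoid(targets * logits)
    return pairwise.sum() / n                 # sum over all pairs, mean over the batch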
app.py CHANGED
@@ -243,7 +243,7 @@ def main():
             None,
             None,
             "artoria_pendragon_(fate),solo",
-            "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
+            "green_eyes",
             "CLIP",
             ["General", "Sensitive"],
             5,
data/wd-v1-4-convnext-tagger-v2/clip.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0e32b62f6bee5e8db4b17a05d605435dcfa24dc99d0eb26582078f2181567031
-size 48689306
+oid sha256:699734f96af810340b397f1a4e30a2ec40771f4288fc5fb861e7437000984e6d
+size 48689338