SmilingWolf committed on
Commit
7f89261
1 Parent(s): 4a7ba8d

Push some files I forgot + new model weights

Browse files
Models/CLIP.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flax
2
+ import jax.numpy as jnp
3
+
4
+
5
class Image(flax.linen.Module):
    """Image-side branch: applies dropout to its input and nothing else.

    No parameterised layers are created here, so the output has the same
    shape as the input.
    """

    # Not read anywhere inside this class.
    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        """Return ``x`` after dropout (rate 0.1).

        Args:
            x: Input array.
            training: When true, dropout is active; otherwise the layer is
                run in deterministic mode.

        Returns:
            The output of the dropout layer.
        """
        dropout = flax.linen.Dropout(rate=0.1)
        return dropout(x, deterministic=not training)
12
+
13
+
14
class Text(flax.linen.Module):
    """Text-side branch: a dense projection followed by one residual block."""

    # Width of both dense layers.
    out_units: int = 1024

    @flax.linen.compact
    def __call__(self, x, training=False):
        """Project ``x`` and add a SiLU -> Dense -> Dropout residual.

        Args:
            x: Input array.
            training: When true, dropout (rate 0.1) is active on the
                residual branch; otherwise it runs deterministically.

        Returns:
            The projected input plus the residual branch output.
        """
        width = self.out_units

        # The layers are created in the same order as before so that the
        # automatically generated submodule names stay the same.
        projected = flax.linen.Dense(features=width)(x)

        branch = flax.linen.Dense(features=width)(flax.linen.silu(projected))
        branch = flax.linen.Dropout(rate=0.1)(branch, deterministic=not training)

        return projected + branch
27
+
28
+
29
def _l2_normalize(emb):
    """Scale ``emb`` to unit L2 norm along its last axis.

    A norm that is exactly zero is replaced by 1 before dividing, so an
    all-zero vector stays all-zero instead of turning into NaN (0 / 0).
    Vectors with a non-zero norm are divided exactly as before.
    """
    norm = jnp.linalg.norm(emb, axis=-1, keepdims=True)
    safe_norm = jnp.where(norm == 0, 1.0, norm)
    return emb / safe_norm


class CLIP(flax.linen.Module):
    """Pairs the ``Image`` and ``Text`` branches and scores them against each other."""

    # Passed to both branches as their ``out_units``.
    out_units: int = 1024
    # Constant multiplier applied to the similarity matrices.
    logit_scale: float = 1.0

    def setup(self):
        """Create the two branches; their attribute names become their scope names."""
        self.image_enc = Image(self.out_units)
        self.text_enc = Text(self.out_units)

    # All submodules are created in ``setup``, so this method does not need
    # to be a compact method.
    def __call__(self, image, text, training=False):
        """Compute scaled similarities between image and text embeddings.

        Args:
            image: Input for the image branch.
            text: Input for the text branch.
            training: Forwarded to both branches (controls dropout).

        Returns:
            A tuple ``(image_sim, text_sim)``: the image-by-text and
            text-by-image products of the L2-normalised embeddings, each
            multiplied by ``logit_scale``.

        Note:
            The use of ``.T`` and ``@`` suggests 2-D (batch, features)
            embeddings -- TODO confirm against callers.
        """
        image_emb = self.image_enc(image, training=training)
        text_emb = self.text_enc(text, training=training)

        # Normalize (zero-norm rows are left as zeros rather than NaN).
        image_emb = _l2_normalize(image_emb)
        text_emb = _l2_normalize(text_emb)

        image_sim = self.logit_scale * image_emb @ text_emb.T
        text_sim = self.logit_scale * text_emb @ image_emb.T
        return image_sim, text_sim

    def encode_text(self, text):
        """Return the text branch output for ``text`` with dropout disabled.

        The result is not L2-normalised here.
        """
        text_emb = self.text_enc(text, training=False)
        return text_emb
data/wd-v1-4-convnext-tagger-v2/clip.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3be3b97824313f01d9f1d74c43e441199b7ea485f5698d2008739f34c3e41200
3
  size 48689306
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e32b62f6bee5e8db4b17a05d605435dcfa24dc99d0eb26582078f2181567031
3
  size 48689306