Thouph committed
Commit
60c6529
1 Parent(s): 51636ce

Upload 3 files

Files changed (3):
  1. 7748tags.json +0 -0
  2. clip_for_image_classification.py +56 -0
  3. inference.py +38 -0
7748tags.json ADDED
The diff for this file is too large to render. See raw diff
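
The tag file's structure isn't visible here, but inference.py (below) indexes it as `allowed_tags[index]`, which suggests a flat JSON array of tag strings keyed by classifier output position. A minimal sanity check under that assumption — the 7748 count is inferred from the filename, not confirmed by the diff:

import json

with open("7748tags.json") as file:
    allowed_tags = json.load(file)

# Assumed layout: a flat list of tag strings, one per classifier logit.
assert isinstance(allowed_tags, list)
assert len(allowed_tags) == 7748  # count inferred from the filename only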
 
clip_for_image_classification.py ADDED
@@ -0,0 +1,56 @@
+ from transformers import CLIPVisionConfig, FlaxCLIPVisionPreTrainedModel
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
+ import jax.numpy as jnp
+ from flax import linen as nn
+ import jax
+ from transformers.modeling_flax_outputs import FlaxSequenceClassifierOutput
+
+
+ class FlaxCLIPForImageClassificationModule(nn.Module):
+     config: CLIPVisionConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         # CLIP vision tower plus a linear classification head over num_labels tags.
+         self.vit = FlaxCLIPVisionModule(config=self.config, dtype=self.dtype)
+         self.classifier = nn.Dense(
+             self.config.num_labels,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.variance_scaling(
+                 self.config.initializer_range ** 2, "fan_in", "truncated_normal"
+             ),
+         )
+
+     def __call__(
+         self,
+         pixel_values=None,
+         deterministic: bool = True,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.vit(
+             pixel_values,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # Classify from the [CLS] token embedding (position 0 of the last hidden state).
+         hidden_states = outputs[0]
+         logits = self.classifier(hidden_states[:, 0, :])
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return output
+
+         return FlaxSequenceClassifierOutput(
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class FlaxCLIPForImageClassification(FlaxCLIPVisionPreTrainedModel):
+     module_class = FlaxCLIPForImageClassificationModule
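
This file only defines the module and its pretrained-model wrapper. For context, a minimal sketch of how the wrapper could be instantiated from scratch on top of the base DataComp vision config rather than loaded from the fine-tuned checkpoint; `num_labels = 7748` is an assumption taken from the tag file's name:

from transformers import CLIPVisionConfig
from clip_for_image_classification import FlaxCLIPForImageClassification

# Hypothetical from-scratch init: pull the vision config from the base LAION
# checkpoint and attach an untrained 7748-way head (7748 assumed, not confirmed).
config = CLIPVisionConfig.from_pretrained("laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K")
config.num_labels = 7748
model = FlaxCLIPForImageClassification(config)

Note this yields randomly initialized weights throughout; inference.py below instead loads the fine-tuned weights via from_pretrained.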
inference.py ADDED
@@ -0,0 +1,38 @@
+ import json
+ from clip_for_image_classification import FlaxCLIPForImageClassification
+ from PIL import Image
+ import jax
+ import numpy as np
+ from transformers import CLIPImageProcessor
+ import os
+
+ # Let JAX allocate GPU memory on demand instead of grabbing most of it up front.
+ os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+
+ model = FlaxCLIPForImageClassification.from_pretrained("Thouph/clip-vit-l-224-patch14-datacomp-image-classification")
+ image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K")
+ image = Image.open("/your/image/here.jpg")
+ inputs = image_processor(images=image, return_tensors="jax")
+ outputs = model(**inputs)
+ logits = outputs.logits
+ # Multi-label tagging: an independent sigmoid per tag, not a softmax over tags.
+ probabilities = jax.nn.sigmoid(logits)
+ # Copy into a writable NumPy array; JAX arrays are immutable and topk_by_sort
+ # negates its input in place.
+ probabilities = np.asarray(probabilities).copy()
+
+
+ def topk_by_sort(input, k, axis=None, ascending=False):
+     # Negate in place so argsort's ascending order yields the largest entries first.
+     if not ascending:
+         input *= -1
+     ind = np.argsort(input, axis=axis)
+     ind = np.take(ind, np.arange(k), axis=axis)
+     if not ascending:
+         input *= -1  # restore the original values
+     val = np.take_along_axis(input, ind, axis=axis)
+     return ind, val
+
+
+ # With axis=None the (1, num_labels) array is flattened, so each index is a tag id.
+ indices, values = topk_by_sort(probabilities, 100)
+
+ with open("7748tags.json", "r") as file:
+     allowed_tags = json.load(file)
+
+ for index, value in zip(indices, values, strict=True):
+     print(allowed_tags[index], value)
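
As a side note, the full argsort above costs O(n log n) over all scores; np.argpartition would likely give the same top-k selection in O(n) plus an O(k log k) sort of the winners, without mutating its input. A sketch assuming the flattened score vector from the script above; `topk_by_partition` is a hypothetical name, not part of this upload:

import numpy as np

def topk_by_partition(scores, k):
    # Partition so the k largest entries land at the end, in arbitrary order...
    flat = scores.ravel()
    ind = np.argpartition(flat, -k)[-k:]
    # ...then sort just those k indices into descending score order.
    ind = ind[np.argsort(flat[ind])[::-1]]
    return ind, flat[ind]

indices, values = topk_by_partition(probabilities, 100)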