Upload 3 files
Browse files- 7748tags.json +0 -0
- clip_for_image_classification.py +56 -0
- inference.py +38 -0
7748tags.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
clip_for_image_classification.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPVisionConfig, FlaxCLIPVisionPreTrainedModel
|
2 |
+
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
|
3 |
+
import jax.numpy as jnp
|
4 |
+
from flax import linen as nn
|
5 |
+
import jax
|
6 |
+
from transformers.modeling_flax_outputs import FlaxSequenceClassifierOutput
|
7 |
+
|
8 |
+
|
9 |
+
class FlaxCLIPForImageClassificationModule(nn.Module):
    """Flax module pairing a CLIP vision encoder with a linear classification head.

    NOTE(review): the submodule attribute names chosen in ``setup`` ("vit",
    "classifier") define the parameter-tree keys, so they must match the
    checkpoint this module is loaded from — do not rename them.
    """

    # Vision-tower configuration; num_labels and initializer_range are read below.
    config: CLIPVisionConfig
    # Computation dtype for both the encoder and the head.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # CLIP vision backbone (kept under the name "vit" in the param tree).
        self.vit = FlaxCLIPVisionModule(config=self.config, dtype=self.dtype)
        # Linear head mapping the pooled token embedding to num_labels logits.
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                # Variance-scaling with scale initializer_range**2 ≈ truncated
                # normal with stddev initializer_range (fan_in scaling).
                self.config.initializer_range ** 2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and classify from the first token position.

        Returns a ``FlaxSequenceClassifierOutput`` (no loss) when
        ``return_dict`` is truthy, otherwise a plain tuple
        ``(logits, hidden_states?, attentions?)``.
        """
        # Fall back to the config-wide default when the caller doesn't specify.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[0] is the last hidden state; classify from position 0
        # (presumably CLIP's class-embedding token — confirm against backbone).
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states[:, 0, :])

        if not return_dict:
            # Skip outputs[1] (pooled output) to mirror the HF tuple convention.
            output = (logits,) + outputs[2:]
            return output

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
53 |
+
|
54 |
+
|
55 |
+
class FlaxCLIPForImageClassification(FlaxCLIPVisionPreTrainedModel):
    """Pretrained-model wrapper binding FlaxCLIPForImageClassificationModule.

    Inherits ``from_pretrained`` / ``save_pretrained`` and weight handling from
    ``FlaxCLIPVisionPreTrainedModel``; only the module class is overridden.
    """

    module_class = FlaxCLIPForImageClassificationModule
|
inference.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os

# Disable XLA's GPU memory preallocation. The variable must be in the
# environment before JAX initializes its backend, so set it before anything
# imports jax (backend init is lazy, but relying on that ordering is fragile).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import json

import jax
import numpy as np
from PIL import Image
from transformers import CLIPImageProcessor

from clip_for_image_classification import FlaxCLIPForImageClassification

# Classifier checkpoint plus the preprocessing config of its CLIP backbone.
model = FlaxCLIPForImageClassification.from_pretrained(
    "Thouph/clip-vit-l-224-patch14-datacomp-image-classification"
)
image_processor = CLIPImageProcessor.from_pretrained(
    "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
)

# Replace with a real image path before running.
image = Image.open("/your/image/here.jpg")
inputs = image_processor(images=image, return_tensors="jax")
outputs = model(**inputs)
logits = outputs.logits

# Independent sigmoid per class (not softmax) — presumably a multi-label
# tagging setup; each probability is scored on its own.
probabilities = jax.nn.sigmoid(logits)
# Pull the result to a writable host-side NumPy array.
probabilities = np.asarray(probabilities).copy()
|
19 |
+
|
20 |
+
|
21 |
+
def topk_by_sort(input, k, axis=None, ascending=False):
    """Return the indices and values of the top-k entries of ``input``.

    Args:
        input: array-like of scores. (Name kept for caller compatibility,
            although it shadows the ``input`` builtin.)
        k: number of entries to select.
        axis: axis to sort along; ``None`` sorts the flattened array.
        ascending: if False (default) select the k largest entries,
            otherwise the k smallest. Results are in sorted order.

    Returns:
        ``(indices, values)``: positions of the selected entries and the
        entries themselves.
    """
    data = np.asarray(input)
    # Sort a negated *copy* for descending order. The previous implementation
    # did ``input *= -1`` twice to sort and then restore, which mutated the
    # caller's array in place (permanently, if an exception hit in between)
    # and failed on read-only arrays.
    order = np.argsort(data if ascending else -data, axis=axis)
    ind = np.take(order, np.arange(k), axis=axis)
    val = np.take_along_axis(data, ind, axis=axis)
    return ind, val
|
30 |
+
|
31 |
+
|
32 |
+
# Keep only the 100 highest-scoring tag positions and their probabilities.
indices, values = topk_by_sort(probabilities, 100)

# Tag vocabulary: index in this list corresponds to a logit position.
with open("7748tags.json", "r") as tags_file:
    allowed_tags = json.load(tags_file)

# strict=True makes the pairing fail loudly if the two arrays ever
# disagree in length instead of silently truncating.
for tag_id, score in zip(indices, values, strict=True):
    print(allowed_tags[tag_id], score)
|