File size: 1,230 Bytes
60c6529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json
from clip_for_image_classification import FlaxCLIPForImageClassification
from PIL import Image
import jax
import numpy as np
from transformers import CLIPImageProcessor
import os

# NOTE(review): this is set after `import jax` above. Presumably JAX only reads
# XLA_PYTHON_CLIENT_PREALLOCATE when its GPU backend is first initialized (which
# should not have happened yet), but it is conventionally set before importing
# jax — consider moving it above the imports to remove any doubt.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

model = FlaxCLIPForImageClassification.from_pretrained("Thouph/clip-vit-l-224-patch14-datacomp-image-classification")
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K")

# Open the image in a context manager so the underlying file handle is closed
# as soon as the processor has turned the pixels into model inputs.
with Image.open("/your/image/here.jpg") as image:
    inputs = image_processor(images=image, return_tensors="jax")

outputs = model(**inputs)
logits = outputs.logits
# Sigmoid (not softmax): every output gets an independent probability, i.e. the
# outputs are scored as a multi-label problem.
probabilities = jax.nn.sigmoid(logits)
# np.asarray on a JAX array may return a read-only view; copy it into a
# writable NumPy array so downstream code is free to modify it in place.
probabilities = np.asarray(probabilities).copy()


def topk_by_sort(input, k, axis=None, ascending=False):
    """Return the indices and values of the ``k`` largest (or smallest) entries.

    Selection uses a full ``np.argsort`` (O(n log n)), so the returned entries
    are fully ordered.

    Args:
        input: Array-like of numbers. It is never modified, so read-only arrays
            (e.g. ``np.asarray`` views of JAX arrays) and plain lists are
            accepted.
        k: Number of entries to return. Values larger than the length of the
            selected axis (or of the flattened array when ``axis`` is None) are
            clamped to that length instead of raising ``IndexError``.
        axis: Axis to select along. ``None`` flattens the array first, so the
            returned indices refer to positions in the flattened array.
        ascending: If False (default), return the largest entries, largest
            first; if True, return the smallest entries, smallest first.

    Returns:
        ``(indices, values)``, where ``values`` are the entries of ``input`` at
        ``indices``, taken from the original (un-negated) data.
    """
    arr = np.asarray(input)
    # Sort on a negated *copy* for descending order rather than negating the
    # caller's data in place: in-place negation fails on read-only arrays,
    # empties a list (`list *= -1`), and leaves the input negated if argsort
    # raises. NaN stays NaN under negation, so NaNs still sort last.
    # NOTE(review): negation assumes a float or signed-integer input; it is
    # meaningless for unsigned ints and raises for bools.
    keys = arr if ascending else -arr
    order = np.argsort(keys, axis=axis)
    # argsort with axis=None returns a flat array, so its length is shape[0].
    length = order.shape[0] if axis is None else order.shape[axis]
    ind = np.take(order, np.arange(min(k, length)), axis=axis)
    val = np.take_along_axis(arr, ind, axis=axis)
    return ind, val


# Keep the 100 highest-probability entries.
# NOTE(review): topk_by_sort is called with axis=None, so the indices refer to
# the flattened probability array. They line up with tag positions only if
# probabilities is presumably (1, num_tags), i.e. a single image — confirm
# before running this on a batch.
indices, values = topk_by_sort(probabilities, 100)

# JSON text is UTF-8; say so explicitly instead of relying on the
# platform-dependent default text encoding.
with open("7748tags.json", "r", encoding="utf-8") as file:
    allowed_tags = json.load(file)

# presumably allowed_tags is a list of tag names indexed by label position —
# verify against the tag file.
for index, value in zip(indices, values, strict=True):
    print(allowed_tags[index], value)