"""Tag a single image with a Flax CLIP image-classification model.

Loads the classifier and its CLIP image processor, runs one image through
the model, and prints the highest-probability tags with their sigmoid
scores.
"""

import json
import os

import numpy as np

MODEL_NAME = "Thouph/clip-vit-l-224-patch14-datacomp-image-classification"
PROCESSOR_NAME = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
IMAGE_PATH = "/your/image/here.jpg"
TAGS_PATH = "7748tags.json"
TOP_K = 100


def topk_by_sort(input, k, axis=None, ascending=False):  # noqa: A002 - name kept for keyword callers
    """Return the indices and values of the top ``k`` entries of ``input``.

    The input is never modified, so read-only arrays (for example
    ``np.asarray`` views of JAX arrays) are accepted.

    Args:
        input: Array-like of numbers. Descending order is computed by
            negating the values, so it assumes a signed numeric dtype.
        k: Number of entries to keep. Values larger than the length of the
            sorted axis are clamped to that length.
        axis: Axis to sort along. ``None`` (the default) sorts the flattened
            array, so the returned indices are flat indices.
        ascending: If False (default), return the largest values first;
            otherwise return the smallest values first.

    Returns:
        A tuple ``(indices, values)``. ``indices`` is an integer array of
        positions along ``axis`` (flat positions when ``axis`` is None).
        ``values`` holds the corresponding entries of ``input``.
    """
    arr = np.asarray(input)
    # Sort a negated copy instead of negating in place: no side effect on the
    # caller's array and no failure on read-only buffers.
    sort_key = arr if ascending else -arr
    # Stable sort makes the order of tied scores deterministic (lower index first).
    ind = np.argsort(sort_key, axis=axis, kind="stable")
    # With axis=None, argsort returns a 1-D result over the flattened array.
    axis_len = ind.size if axis is None else ind.shape[axis]
    ind = np.take(ind, np.arange(min(k, axis_len)), axis=axis)
    val = np.take_along_axis(arr, ind, axis=axis)
    return ind, val


def main():
    """Classify IMAGE_PATH and print its TOP_K most probable tags."""
    # Must be set before JAX initialises its backend, so it comes before the
    # jax import. Heavy ML dependencies are imported here so that importing
    # this module (e.g. for topk_by_sort) does not require them.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    import jax
    from PIL import Image
    from transformers import CLIPImageProcessor

    from clip_for_image_classification import FlaxCLIPForImageClassification

    # Load the tag list first, so a missing file fails before the slow model load.
    # NOTE(review): assumed to be a JSON list indexed by label id; confirm.
    with open(TAGS_PATH, "r", encoding="utf-8") as file:
        allowed_tags = json.load(file)

    model = FlaxCLIPForImageClassification.from_pretrained(MODEL_NAME)
    image_processor = CLIPImageProcessor.from_pretrained(PROCESSOR_NAME)

    # The context manager releases the file handle once preprocessing is done.
    with Image.open(IMAGE_PATH) as image:
        inputs = image_processor(images=image, return_tensors="jax")

    outputs = model(**inputs)
    # presumably logits are (1, num_tags) for a single image; TODO confirm.
    # Since axis=None flattens, the indices map directly onto tag ids.
    probabilities = np.asarray(jax.nn.sigmoid(outputs.logits))

    indices, values = topk_by_sort(probabilities, TOP_K)
    for index, value in zip(indices, values, strict=True):
        print(allowed_tags[index], value)


if __name__ == "__main__":
    main()