# Uploaded by Thouph in commit 60c6529 ("Upload 3 files").
import json
from clip_for_image_classification import FlaxCLIPForImageClassification
from PIL import Image
import jax
import numpy as np
from transformers import CLIPImageProcessor
import os
# Disable JAX's default up-front GPU memory preallocation so the model only
# claims what it needs. JAX consults this when it initialises its backend, so
# it has to be in the environment before the first device computation below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

model = FlaxCLIPForImageClassification.from_pretrained(
    "Thouph/clip-vit-l-224-patch14-datacomp-image-classification"
)
# Image preprocessing settings come from the LAION DataComp CLIP ViT-L/14
# checkpoint. NOTE(review): presumably the base model of the classifier above;
# confirm the two expect the same preprocessing.
image_processor = CLIPImageProcessor.from_pretrained(
    "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
)

# Replace this placeholder with the path of the image to tag.
with Image.open("/your/image/here.jpg") as image:
    # Preprocess inside the `with` so the file handle is released
    # deterministically once the pixels have been read.
    inputs = image_processor(images=image, return_tensors="jax")

outputs = model(**inputs)
# Assumes a batch of one image, i.e. logits of shape (1, num_tags).
# TODO confirm against the model's config.
logits = outputs.logits
# Sigmoid gives an independent probability per tag (not a softmax distribution
# over tags).
probabilities = jax.nn.sigmoid(logits)
# Writable host-side NumPy copy: np.asarray on a JAX array yields a read-only
# array, so copy before doing any in-place NumPy work on it.
probabilities = np.asarray(probabilities).copy()
def topk_by_sort(input, k, axis=None, ascending=False):
    """Return the indices and values of the ``k`` largest (or smallest) entries.

    Args:
        input: Array-like of numbers. It is never modified. (The name shadows
            the ``input`` builtin; it is kept for keyword callers.)
        k: Number of entries to keep. If it exceeds the length of the ranked
            axis, every entry along that axis is returned.
        axis: Axis to rank along. ``None`` (the default) ranks the flattened
            array, so the returned indices are flat indices.
        ascending: If ``False`` (default) the largest values come first;
            if ``True`` the smallest do.

    Returns:
        ``(indices, values)``. ``indices`` are positions in ``input`` (flat
        when ``axis`` is ``None``) in ranked order, and ``values`` are the
        matching entries of ``input``. Ties come out in an unspecified order,
        because :func:`numpy.argsort` defaults to a non-stable sort.
    """
    arr = np.asarray(input)
    # Rank a transformed temporary instead of negating the caller's data in
    # place. This works on read-only arrays, never leaves the input negated,
    # and avoids ``list *= -1`` silently turning a Python list into ``[]``.
    if ascending:
        keys = arr
    elif np.issubdtype(arr.dtype, np.unsignedinteger):
        # ~x == max - x for unsigned ints: it reverses the order without the
        # wrap-around (or NumPy 2 OverflowError) that negation would cause.
        keys = ~arr
    else:
        keys = -arr
    order = np.argsort(keys, axis=axis)
    # argsort flattens when axis is None, so the ranked axis is then axis 0.
    length = order.shape[0 if axis is None else axis]
    ind = np.take(order, np.arange(min(k, length)), axis=axis)
    val = np.take_along_axis(arr, ind, axis=axis)
    return ind, val
# Rank every tag by probability and keep the 100 most likely. With the default
# axis=None the array is flattened, which maps back to tag indices only
# because the batch holds a single image.
indices, values = topk_by_sort(probabilities, 100)

# Tag names, looked up by the model's output index. JSON text is UTF-8, so
# decode it explicitly instead of relying on the platform's locale encoding.
# NOTE(review): assumes the file holds a list ordered by class id; confirm.
with open("7748tags.json", "r", encoding="utf-8") as file:
    allowed_tags = json.load(file)

for index, value in zip(indices, values, strict=True):
    print(allowed_tags[index], value)