import io

import jax
import matplotlib.pyplot as plt
import requests
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
# Hugging Face checkpoint used for both the model weights and the processor.
MODEL_ID = "flaviagiammarino/pubmed-clip-vit-base-patch32"
# Sample image shipped alongside the checkpoint.
IMAGE_URL = "https://huggingface.co/flaviagiammarino/pubmed-clip-vit-base-patch32/resolve/main/scripts/input.jpeg"
# Candidate captions the image is scored against.
LABELS = ["Chest X-Ray", "Brain MRI", "Abdominal CT Scan"]
# Upper bound (seconds) on the image download so the script cannot hang forever.
REQUEST_TIMEOUT_SECONDS = 30


def format_title(labels, probs):
    """Return one ``"<label>: <probability>"`` line per label/probability pair.

    Args:
        labels: Iterable of caption strings.
        probs: Iterable of scalar probabilities, paired with ``labels`` in
            order. Each value must be convertible with ``float()``.

    Returns:
        A string with each probability rendered as a percentage with four
        decimal places; every line, including the last, ends with a newline.
    """
    return "".join(
        f"{label}: {float(prob):.4%}\n" for label, prob in zip(labels, probs)
    )


def main():
    """Download the sample image, score it against ``LABELS`` and plot it."""
    model = FlaxCLIPModel.from_pretrained(MODEL_ID)
    processor = CLIPProcessor.from_pretrained(MODEL_ID)

    # Fetch the whole body inside the context manager so the connection is
    # released, and fail loudly on HTTP errors instead of passing an error
    # page to PIL.
    with requests.get(IMAGE_URL, timeout=REQUEST_TIMEOUT_SECONDS) as response:
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))

    inputs = processor(text=LABELS, images=image, return_tensors="jax", padding=True)
    # Softmax over the last axis of the image-to-text logits, then flatten so
    # the values can be paired with LABELS.
    logits = model(**inputs).logits_per_image
    probs = jax.nn.softmax(logits, axis=-1).flatten()

    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(format_title(LABELS, probs))
    ax.axis("off")
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()