# WalidBouss's picture
# Initial commit :tada:
from PIL import Image
from gem import create_gem_model, get_gem_img_transform, visualize, available_models
import torch
import requests
# Model selection: architecture name and pretrained-weights tag passed to
# create_gem_model below.
model_name = 'ViT-B-16-quickgelu'
pretrained = 'metaclip_400m'
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
gem_model = create_gem_model(
    model_name=model_name,
    pretrained=pretrained,
    device=device,
)
# Single Image
# Fetch one image over HTTP, preprocess it with the GEM transform, and run the
# model on it together with a list of text prompts.
# NOTE(review): the URL literal is empty in this copy of the file (presumably
# lost during extraction); set it to a reachable image URL before running,
# otherwise requests.get() fails on the empty string.
url = ""  # cat & remote control
text = ['remote control', 'cat']
# image_path = 'path/to/image'  # <-- uncomment to use path
# stream=True exposes the undecoded response body as a file-like object
# (.raw), which PIL can open directly.
image_pil = Image.open(requests.get(url, stream=True).raw)
# image_pil = Image.open(image_path)  # <-- uncomment to use path
gem_img_transform = get_gem_img_transform()
# unsqueeze(0) adds a leading batch dimension before moving to the device.
image = gem_img_transform(image_pil).unsqueeze(0).to(device)
# Inference only: no gradients are needed for the forward pass.
with torch.no_grad():
    logits = gem_model(image, text)
visualize(image, text, logits)
print(logits.shape)  # torch.Size([1, 2, 448, 448])
# visualize(image_pil, text, logits)  # <-- works with torch.Tensor and PIL.Image
# Batch of Images
# Fetch several images, stack them into one batch tensor, and run the model
# with a separate prompt list per image.
# NOTE(review): the entries of `urls` are missing from this copy of the file
# (presumably stripped during extraction). The placeholders below assume one
# image URL per prompt list in `texts`, in the same order -- fill them in
# before running and confirm the count.
urls = [
    "",  # image for ['remote control', 'cat']
    "",  # image for ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates']
    "",  # image for ['batman', 'joker', 'shoe', 'belt', 'purple suit']
]
texts = [
    ['remote control', 'cat'],
    ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'],
    ['batman', 'joker', 'shoe', 'belt', 'purple suit'],
]  # note that the number of prompts per image can be different
# download images + convert to PIL.Image
images_pil = [Image.open(requests.get(url, stream=True).raw) for url in urls]
# Reuses gem_img_transform from the single-image section above.
images = torch.stack([gem_img_transform(img) for img in images_pil]).to(device)
# Inference only: no gradients are needed for the forward pass.
with torch.no_grad():
    # return list with logits of size [1, num_prompt, W, H]
    # NOTE(review): the shape comments on the prints below omit the leading 1
    # ([num_prompt, W, H]) -- confirm which of the two is accurate.
    logits_list = gem_model.batched_forward(images, texts)
print(logits_list[0].shape)  # torch.Size([2, 448, 448])
print(logits_list[1].shape)  # torch.Size([4, 448, 448])
print(logits_list[2].shape)  # torch.Size([5, 448, 448])
for i, _logits in enumerate(logits_list):
    visualize(images[i], texts[i], _logits)  # (optional visualization)