from PIL import Image from gem import create_gem_model, get_gem_img_transform, visualize, available_models import torch import requests print(available_models()) device = "cuda:0" if torch.cuda.is_available() else "cpu" model_name = 'ViT-B-16-quickgelu' pretrained = 'metaclip_400m' gem_model = create_gem_model(model_name=model_name, pretrained=pretrained, device=device) gem_model.eval() ########################### # Single Image ########################### url = "http://images.cocodataset.org/val2017/000000039769.jpg" # cat & remote control text = ['remote control', 'cat'] # image_path = 'path/to/image' #, <-- uncomment to use path image_pil = Image.open(requests.get(url, stream=True).raw) # image_pil = Image.open(image_path) # <-- uncomment to use path gem_img_transform = get_gem_img_transform() image = gem_img_transform(image_pil).unsqueeze(0).to(device) with torch.no_grad(): logits = gem_model(image, text) visualize(image, text, logits) print(logits.shape) # torch.Size([1, 2, 448, 448]) # visualize(image_pil, text, logits) # <-- works with torch.Tensor and PIL.Image ########################### # Batch of Images ########################### urls = [ "http://images.cocodataset.org/val2017/000000039769.jpg", "https://cdn.vietnambiz.vn/171464876016439296/2021/7/11/headshots16170695297430-1626006880779826347793.jpg", "https://preview.redd.it/do-you-think-joker-should-be-unpredictable-enough-to-put-up-v0-6a2ax4ngtlaa1.jpg?auto=webp&s=f8762e6a1b40642bcae5900bac184fc597131503", ] texts = [ ['remote control', 'cat'], ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'], ['batman', 'joker', 'shoe', 'belt', 'purple suit'], ] # note that the number of prompt per image can be different # download images + convert to PIL.Image images_pil = [Image.open(requests.get(url, stream=True).raw) for url in urls] images = torch.stack([gem_img_transform(img) for img in images_pil]).to(device) with torch.no_grad(): # return list with logits of size [1, num_prompt, W, H] logits_list = gem_model.batched_forward(images, texts) print(logits_list[0].shape) # torch.Size([2, 448, 448]) print(logits_list[1].shape) # torch.Size([4, 448, 448]) print(logits_list[2].shape) # torch.Size([5, 448, 448]) for i, _logits in enumerate(logits_list): visualize(images[i], texts[i], _logits) # (optional visualization)