from PIL import Image
from gem import create_gem_model, get_gem_img_transform, visualize, available_models
import torch
import requests

print(available_models())

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = 'ViT-B-16-quickgelu'
pretrained = 'metaclip_400m'

gem_model = create_gem_model(model_name=model_name, pretrained=pretrained, device=device)
gem_model.eval()
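
# Hedged aside: the (model_name, pretrained) pair follows OpenCLIP naming, and
# available_models() above prints the combinations GEM supports. To try another
# backbone, swap in any pair from that list, e.g. (assuming it is listed):
# gem_model = create_gem_model(model_name='ViT-B-16', pretrained='openai', device=device)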

###########################
# Single Image
###########################
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # cat & remote control
text = ['remote control', 'cat']
# image_path = 'path/to/image'  # <-- uncomment to load from a local path
image_pil = Image.open(requests.get(url, stream=True).raw)
# image_pil = Image.open(image_path)  # <-- uncomment to load from a local path
gem_img_transform = get_gem_img_transform()  # resizes + normalizes to the model's 448x448 input
image = gem_img_transform(image_pil).unsqueeze(0).to(device)
with torch.no_grad():
    logits = gem_model(image, text)

visualize(image, text, logits)
print(logits.shape)  # torch.Size([1, 2, 448, 448])
# visualize(image_pil, text, logits)  # <-- works with both torch.Tensor and PIL.Image
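
###########################
# (Sketch) Turning logits into a hard mask
###########################
# Not part of the GEM API -- a minimal post-processing sketch. It assumes the
# logits tensor holds one similarity map per prompt, shape [1, num_prompts, 448, 448]
# as printed above, so an argmax over the prompt axis yields a per-pixel label map.
label_map = logits[0].argmax(dim=0)  # [448, 448], each value indexes into `text`
for idx, prompt in enumerate(text):
    coverage = (label_map == idx).float().mean().item()
    print(f"{prompt}: {coverage:.1%} of pixels")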

###########################
# Batch of Images
###########################
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://cdn.vietnambiz.vn/171464876016439296/2021/7/11/headshots16170695297430-1626006880779826347793.jpg",
    "https://preview.redd.it/do-you-think-joker-should-be-unpredictable-enough-to-put-up-v0-6a2ax4ngtlaa1.jpg?auto=webp&s=f8762e6a1b40642bcae5900bac184fc597131503",
]
texts = [
    ['remote control', 'cat'],
    ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'],
    ['batman', 'joker', 'shoe', 'belt', 'purple suit'],
]  # note that the number of prompts per image can differ
# download the images and convert them to PIL.Image
images_pil = [Image.open(requests.get(url, stream=True).raw) for url in urls]
images = torch.stack([gem_img_transform(img) for img in images_pil]).to(device)
with torch.no_grad():
    # returns a list with one logits tensor per image, each of shape [num_prompts, 448, 448]
    logits_list = gem_model.batched_forward(images, texts)

print(logits_list[0].shape)  # torch.Size([2, 448, 448])
print(logits_list[1].shape)  # torch.Size([4, 448, 448])
print(logits_list[2].shape)  # torch.Size([5, 448, 448])

for i, _logits in enumerate(logits_list):
    visualize(images[i], texts[i], _logits)  # optional visualization
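
###########################
# (Sketch) Back to the original resolution
###########################
# Also a sketch, not part of the GEM API: the maps come back at the transform's
# 448x448 working resolution, so overlaying them on the source images assumes
# bilinear upsampling to each original PIL size is acceptable.
import torch.nn.functional as F

for img_pil, _logits in zip(images_pil, logits_list):
    w, h = img_pil.size  # PIL reports (width, height)
    upsampled = F.interpolate(_logits.unsqueeze(0), size=(h, w),
                              mode='bilinear', align_corners=False)[0]
    print(upsampled.shape)  # [num_prompts, h, w]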