Spaces:
Runtime error
Runtime error
from PIL import Image | |
import numpy as np | |
import cv2 as cv2 | |
import torch | |
import requests | |
import gradio as gr | |
import gem | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# OpenCLIP | |
model_name = 'ViT-B-16-quickgelu' | |
pretrained = 'metaclip_400m' | |
preprocess = gem.get_gem_img_transform() | |
# global gem_model | |
gem_model = gem.create_gem_model(model_name=model_name, pretrained=pretrained, device=device) | |
image_source = "image" | |
_MODELS = { | |
"OpenAI": ('ViT-B-16', 'openai'), | |
"MetaCLIP": ('ViT-B-16-quickgelu', 'metaclip_400m'), | |
"OpenCLIP": ('ViT-B-16', 'laion400m_e32') | |
} | |
def change_weights(pretrained_weights): | |
""" Handle changing model's weights triggered by a Dropdown module change.""" | |
curr_model = pretrained_weights | |
_new_model = _MODELS[pretrained_weights] | |
print(_new_model) | |
global gem_model | |
gem_model = gem.create_gem_model(model_name=_new_model[0], pretrained=_new_model[1], device=device) | |
def change_to_url(url): | |
img_pil = Image.open(requests.get(url, stream=True).raw).convert('RGB') | |
return img_pil | |
def viz_func(url, image, text, model_weights): | |
image_torch = preprocess(image).unsqueeze(0).to(device) | |
with torch.no_grad(): | |
logits = gem_model(image_torch, [text]) | |
logits = logits[0].detach().cpu().numpy() | |
img_cv = cv2.cvtColor(np.array(image.resize((448, 448))), cv2.COLOR_RGB2BGR) | |
logit_cs_viz = (logits * 255).astype('uint8') | |
heat_maps_cs = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logit_cs_viz] | |
vizs = [0.4 * img_cv + 0.6 * heat_map for heat_map in heat_maps_cs] | |
vizs = [cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB) for viz in vizs] | |
return vizs[0] | |
inputs = [ | |
gr.Textbox(label="url to the image", ), | |
gr.Image(type="pil"), | |
gr.Textbox(label="Text Prompt"), | |
gr.Dropdown(["OpenAI", "MetaCLIP", "OpenCLIP"], label="Pretrained Weights", value="MetaCLIP", | |
info='It can take a few second for the model to be updated.'), | |
] | |
with gr.Blocks() as demo: | |
inputs[-1].change(fn=change_weights, inputs=[inputs[-1]]) | |
inputs[0].change(fn=change_to_url, outputs=inputs[1], inputs=inputs[0]) | |
interact = gr.Interface( | |
title="GEM: Grounding Everything Module (link to paper/code)", | |
description="Grounding Everything: Emerging Localization Properties in Vision-Language Transformers", | |
fn=viz_func, | |
inputs=inputs, | |
outputs=["image"], | |
) | |
gr.Examples( | |
[ | |
["assets/cats_remote_control.jpeg", "cat"], | |
["assets/cats_remote_control.jpeg", "remote control"], | |
["assets/elon_jeff_mark.jpeg", "elon musk"], | |
["assets/elon_jeff_mark.jpeg", "mark zuckerberg"], | |
["assets/elon_jeff_mark.jpeg", "jeff bezos"], | |
], | |
[inputs[1], inputs[2]] | |
) | |
# demo.launch(server_port=5152) | |
demo.launch() | |