GEM / app.py
WalidBouss's picture
Initial commit :tada:
be1ec96
from PIL import Image
import numpy as np
import cv2 as cv2
import torch
import requests
import gradio as gr
import gem
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dropdown label -> (model name, pretrained tag) as passed to
# gem.create_gem_model. Presumably these are OpenCLIP identifiers.
_MODELS = {
    "OpenAI": ('ViT-B-16', 'openai'),
    "MetaCLIP": ('ViT-B-16-quickgelu', 'metaclip_400m'),
    "OpenCLIP": ('ViT-B-16', 'laion400m_e32'),
}

# Weights loaded at start-up; same entry the dropdown shows as its default.
model_name, pretrained = _MODELS["MetaCLIP"]

# Image transform applied to every input picture before it reaches the model.
preprocess = gem.get_gem_img_transform()

# Module-level model instance; change_weights() rebinds it when the user
# selects different weights.
gem_model = gem.create_gem_model(model_name=model_name, pretrained=pretrained, device=device)

# NOTE(review): not referenced anywhere in the visible part of this file.
image_source = "image"
def change_weights(pretrained_weights):
    """Rebuild the module-level GEM model for the weights picked in the dropdown.

    Args:
        pretrained_weights: A key of ``_MODELS`` (the dropdown label).

    Raises:
        KeyError: If ``pretrained_weights`` is not a key of ``_MODELS``.

    Side effects:
        Rebinds the global ``gem_model`` and prints the selected
        ``(model name, pretrained tag)`` tuple to stdout.
    """
    global gem_model

    selected = _MODELS[pretrained_weights]
    print(selected)

    arch, weights_tag = selected
    gem_model = gem.create_gem_model(model_name=arch, pretrained=weights_tag, device=device)
def change_to_url(url):
    """Download the image at ``url`` and return it as an RGB PIL image.

    Args:
        url: Address typed into the URL textbox. ``None``, an empty string
            or a whitespace-only string means "no URL".

    Returns:
        The downloaded picture converted to RGB, or ``None`` when no URL was
        given. Callers that feed the result to an image component should
        expect ``None`` in that case.

    Raises:
        requests.RequestException: On network failure, timeout, an invalid
            URL, or a non-2xx HTTP status.
        PIL.UnidentifiedImageError: If the payload is not a readable image.
    """
    # Local import keeps this fix self-contained; the file's import block
    # does not need to change.
    from io import BytesIO

    if url is None or not url.strip():
        # The change event also fires when the textbox is cleared; there is
        # nothing to fetch then.
        return None

    # A timeout stops an unresponsive host from blocking the handler forever.
    response = requests.get(url.strip(), timeout=10)
    # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()

    # ``response.content`` is the fully read, content-decoded body. Reading
    # it also releases the connection, unlike the undecoded ``.raw`` stream.
    return Image.open(BytesIO(response.content)).convert('RGB')
def viz_func(url, image, text, model_weights):
    """Overlay the GEM heat map for ``text`` on ``image``.

    Args:
        url: Unused here. Present because the interface passes the value of
            every input component to this function.
        image: PIL image to analyse, or ``None`` when no image was provided.
        text: Text prompt to localise in the image.
        model_weights: Unused here. The model is switched by the dropdown's
            own change handler, not by this function.

    Returns:
        An RGB ``uint8`` array blending the resized image (40%) with the
        JET-coloured heat map (60%), or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to visualise; avoids a crash inside ``preprocess``.
        return None

    # The colour conversions below need exactly three channels.
    image = image.convert('RGB')

    image_torch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = gem_model(image_torch, [text])

    # One image and one prompt were passed in, so take the first entry along
    # each of the two leading axes.
    heat = logits[0].detach().cpu().numpy()[0]

    # Assumes the model output is already scaled to [0, 1] -- TODO confirm.
    # Clipping keeps any out-of-range value from being cast to uint8.
    heat_u8 = (np.clip(heat, 0.0, 1.0) * 255).astype('uint8')
    heat_bgr = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)

    # Assumes the heat map is 448x448 so that it lines up with the resized
    # picture -- TODO confirm against the GEM transform.
    img_bgr = cv2.cvtColor(np.array(image.resize((448, 448))), cv2.COLOR_RGB2BGR)

    blended = 0.4 * img_bgr + 0.6 * heat_bgr
    return cv2.cvtColor(blended.astype('uint8'), cv2.COLOR_BGR2RGB)
# Input components are created up front so the event handlers below can refer
# to them. Their order in ``inputs`` is the positional argument order of
# viz_func(url, image, text, model_weights).
url_box = gr.Textbox(label="url to the image", )
image_box = gr.Image(type="pil")
prompt_box = gr.Textbox(label="Text Prompt")
weights_dropdown = gr.Dropdown(
    ["OpenAI", "MetaCLIP", "OpenCLIP"],
    label="Pretrained Weights",
    value="MetaCLIP",
    info='It can take a few second for the model to be updated.',
)
inputs = [url_box, image_box, prompt_box, weights_dropdown]

# NOTE(review): the original indentation of this section was lost; the
# nesting under ``gr.Blocks()`` below is reconstructed -- confirm.
with gr.Blocks() as demo:
    # Picking other weights rebuilds the global model.
    weights_dropdown.change(fn=change_weights, inputs=[weights_dropdown])
    # Editing the URL downloads the picture into the image component.
    url_box.change(fn=change_to_url, outputs=image_box, inputs=url_box)

    interact = gr.Interface(
        title="GEM: Grounding Everything Module (link to paper/code)",
        description="Grounding Everything: Emerging Localization Properties in Vision-Language Transformers",
        fn=viz_func,
        inputs=inputs,
        outputs=["image"],
    )

    # Each example fills the image and the text prompt only.
    example_rows = [
        ["assets/cats_remote_control.jpeg", "cat"],
        ["assets/cats_remote_control.jpeg", "remote control"],
        ["assets/elon_jeff_mark.jpeg", "elon musk"],
        ["assets/elon_jeff_mark.jpeg", "mark zuckerberg"],
        ["assets/elon_jeff_mark.jpeg", "jeff bezos"],
    ]
    gr.Examples(example_rows, [image_box, prompt_box])

# Launches with default settings; pass ``server_port=5152`` to pin the port.
demo.launch()