import io
from time import time

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont

import matplotlib

# Select the non-interactive Agg backend before importing pyplot, since this
# demo only renders figures into in-memory PNG buffers.
matplotlib.use("Agg")
from matplotlib import pyplot as plt

from CLIP.clip import ClipWrapper, saliency_configs
tag = """
"""


def plot_to_png(fig):
    """Rasterize a matplotlib figure into an RGBA uint8 image array."""
    buf = io.BytesIO()
    # Save the figure that was passed in, not the implicit "current" figure.
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = np.array(Image.open(buf)).astype(np.uint8)
    return img


def add_text_to_image(
    image: np.ndarray,
    text: str,
    position,
    color="rgb(255, 255, 255)",
    fontsize=60,
):
    """Draw `text` onto `image` at `position`, falling back to PIL's default
    font if the Lato TrueType font is not installed on this system."""
    image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/lato/Lato-Medium.ttf", fontsize
        )
    except OSError:
        font = ImageFont.load_default()
    draw.text(position, text, fill=color, font=font)
    return np.array(image)
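

# Gradio callback: renders one relevancy-map overlay per requested label.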
def generate_relevancy(
    img: np.ndarray, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
):
    # Parse the comma-separated label string, trimming stray whitespace and
    # capping the list at 32 labels.
    labels = [label.strip() for label in labels.split(",")][:32]
prompts = [prompt]
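    # Resize so the longest side is 4 * 224 = 896 px, preserving the aspect
    # ratio (CLIP's native input resolution is 224 x 224).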
resize_shape = np.array(img.shape[:2])
resize_shape = tuple(
((resize_shape / resize_shape.max()) * 224 * 4).astype(int).tolist()
)
    # PIL's resize() takes (width, height), whereas numpy shapes are (height, width).
    img = np.asarray(Image.fromarray(img).resize(resize_shape[::-1]))
assert img.dtype == np.uint8
h, w, c = img.shape
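    # Run the multi-scale relevancy extractor with the selected configuration.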
start = time()
try:
grads = ClipWrapper.get_clip_saliency(
img=img,
text_labels=np.array(labels),
prompts=prompts,
**saliency_configs[saliency_config](h),
)[0]
except Exception as e:
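        # Fall back to returning the unannotated image if inference fails.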
print(e)
return (
[img],
tag,
)
print("inference took", float(time() - start))
if subtract_mean:
grads -= grads.mean(axis=0)
    grads = grads.cpu().numpy()
    # Fixed normalization window for the "jet" colormap, so overlays are
    # visually comparable across labels.
    vmin = 0.002
    vmax = 0.008
    cmap = plt.get_cmap("jet")
returns = []
for label_grad, label in zip(grads, labels):
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.axis("off")
ax.imshow(img)
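        # Normalize relevancy into [0, 1] using the fixed vmin/vmax window.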
grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
colored_grad = cmap(grad)
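        # Invert the alpha channel: high-relevancy regions stay transparent so
        # the image shows through, while low-relevancy regions get tinted.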
grad = 1 - grad
colored_grad[..., -1] = grad * 0.7
colored_grad = add_text_to_image(
(colored_grad * 255).astype(np.uint8), text=label, position=(0, 0)
)
colored_grad = colored_grad.astype(float) / 255
ax.imshow(colored_grad)
plt.tight_layout(pad=0)
returns.append(plot_to_png(fig))
plt.close(fig)
return (
returns,
tag,
)
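

# Example of calling the extractor directly, without the Gradio UI. This is a
# sketch: the image path "example.png" is a placeholder, not part of the demo.
#
#   img = np.asarray(Image.open("example.png").convert("RGB"))
#   overlays, _ = generate_relevancy(
#       img, "sofa,plant", "a photograph of a {} in a home.", "ours_fast", True
#   )
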
iface = gr.Interface(
title="Semantic Abstraction Multi-scale Relevancy Extractor",
description="""A CPU-only demo of [Semantic Abstraction](https://semantic-abstraction.cs.columbia.edu/)'s Multi-Scale Relevancy Extractor. To run GPU inference locally, use the [official codebase release](https://github.com/columbia-ai-robotics/semantic-abstraction).
This relevancy extractor builds heavily on [Chefer et al.'s codebase](https://github.com/hila-chefer/Transformer-MM-Explainability) and [CLIP on Wheels' codebase](https://cow.cs.columbia.edu/).""",
fn=generate_relevancy,
cache_examples=True,
inputs=[
gr.Image(type="numpy", label="Image"),
gr.Textbox(label="Labels (comma separated without spaces in between)"),
        gr.Textbox(
            label="Prompt (must include '{}' as a placeholder for the label, as in the examples below)"
        ),
gr.Dropdown(
value="ours",
choices=["ours", "ours_fast", "chefer_et_al"],
label="Relevancy Configuration",
),
        gr.Checkbox(value=True, label="Subtract Mean"),
],
outputs=[
gr.Gallery(label="Relevancy Maps", type="numpy"),
gr.HTML(value=tag),
],
examples=[
[
"https://semantic-abstraction.cs.columbia.edu/downloads/gameroom.png",
"basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
"a photograph of a {} in a home.",
"ours_fast",
True,
],
[
"https://semantic-abstraction.cs.columbia.edu/downloads/livingroom.png",
"monopoly boardgame set,door knob,sofa,coffee table,plant,carpet,wall",
"a photograph of a {} in a home.",
"ours_fast",
True,
],
[
"https://semantic-abstraction.cs.columbia.edu/downloads/fireplace.png",
"fireplace,beige armchair,candle,large indoor plant in a pot,forest painting,cheetah-patterned pillow,floor,carpet,wall",
"a photograph of a {} in a home.",
"ours_fast",
True,
],
[
"https://semantic-abstraction.cs.columbia.edu/downloads/walle.png",
"WALL-E,a fire extinguisher",
"a 3D render of {}.",
"ours_fast",
True,
],
],
)
iface.launch()