import io
from time import time

import matplotlib

matplotlib.use("Agg")  # headless backend; render figures without a display

import gradio as gr
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont

from CLIP.clip import ClipWrapper, saliency_configs

# HTML appended below the gallery output (left empty here).
tag = """ """


def plot_to_png(fig):
    """Render a matplotlib figure to an RGBA uint8 array via an in-memory PNG."""
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    img = np.array(Image.open(buf)).astype(np.uint8)
    return img


def add_text_to_image(
    image: np.ndarray,
    text,
    position,
    color="rgb(255, 255, 255)",
    fontsize=60,
):
    """Draw a text label onto an image array and return the annotated array."""
    image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    draw.text(
        position,
        text,
        fill=color,
        font=ImageFont.truetype(
            "/usr/share/fonts/truetype/lato/Lato-Medium.ttf", fontsize
        ),
    )
    return np.array(image)


def generate_relevancy(
    img: np.ndarray,
    labels: str,
    prompt: str,
    saliency_config: str,
    subtract_mean: bool,
):
    """Compute a multi-scale CLIP relevancy map for each comma-separated label."""
    labels = labels.split(",")
    # Cap the label count to keep CPU-only inference time manageable.
    if len(labels) > 32:
        labels = labels[:32]
    prompts = [prompt]
    # Scale the image so its longer side is 224 * 4 px, preserving aspect ratio.
    resize_shape = np.array(img.shape[:2])
    resize_shape = tuple(
        ((resize_shape / resize_shape.max()) * 224 * 4).astype(int).tolist()
    )
    # PIL's resize expects (width, height), so reverse the (height, width) pair.
    img = np.asarray(Image.fromarray(img).resize(resize_shape[::-1]))
    assert img.dtype == np.uint8
    h, w, c = img.shape
    start = time()
    try:
        grads = ClipWrapper.get_clip_saliency(
            img=img,
            text_labels=np.array(labels),
            prompts=prompts,
            **saliency_configs[saliency_config](h),
        )[0]
    except Exception as e:
        print(e)
        return [img], tag
    print("inference took", float(time() - start))
    if subtract_mean:
        # Subtract the per-pixel mean over labels to suppress regions that are
        # salient for every label.
        grads -= grads.mean(axis=0)
    grads = grads.cpu().numpy()
    # Fixed normalization range for the heatmap overlay.
    vmin = 0.002
    vmax = 0.008
    cmap = plt.get_cmap("jet")
    returns = []
    for label_grad, label in zip(grads, labels):
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
        ax.axis("off")
        ax.imshow(img)
        grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
        colored_grad = cmap(grad)
        # Invert for the alpha channel: high-relevancy regions become more
        # transparent so the underlying image shows through the heatmap.
        grad = 1 - grad
        colored_grad[..., -1] = grad * 0.7
        colored_grad = add_text_to_image(
            (colored_grad * 255).astype(np.uint8), text=label, position=(0, 0)
        )
        colored_grad = colored_grad.astype(float) / 255
        ax.imshow(colored_grad)
        plt.tight_layout(pad=0)
        returns.append(plot_to_png(fig))
        plt.close(fig)
    return returns, tag


iface = gr.Interface(
    title="Semantic Abstraction Multi-scale Relevancy Extractor",
    description="""A CPU-only demo of [Semantic Abstraction](https://semantic-abstraction.cs.columbia.edu/)'s Multi-Scale Relevancy Extractor. To run GPU inference locally, use the [official codebase release](https://github.com/columbia-ai-robotics/semantic-abstraction).

This relevancy extractor builds heavily on [Chefer et al.'s codebase](https://github.com/hila-chefer/Transformer-MM-Explainability) and [CLIP on Wheels' codebase](https://cow.cs.columbia.edu/).""",
    fn=generate_relevancy,
    cache_examples=True,
    inputs=[
        gr.Image(type="numpy", label="Image"),
        gr.Textbox(label="Labels (comma-separated, without spaces in between)"),
        gr.Textbox(
            label="Prompt (make sure to include '{}' in the prompt, as in the examples below)"
        ),
        gr.Dropdown(
            value="ours",
            choices=["ours", "ours_fast", "chefer_et_al"],
            label="Relevancy Configuration",
        ),
        gr.Checkbox(value=True, label="subtract mean"),
    ],
    outputs=[
        gr.Gallery(label="Relevancy Maps", type="numpy"),
        gr.HTML(value=tag),
    ],
    examples=[
        [
            "https://semantic-abstraction.cs.columbia.edu/downloads/gameroom.png",
            "basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
            "a photograph of a {} in a home.",
            "ours_fast",
            True,
        ],
        [
            "https://semantic-abstraction.cs.columbia.edu/downloads/livingroom.png",
            "monopoly boardgame set,door knob,sofa,coffee table,plant,carpet,wall",
            "a photograph of a {} in a home.",
            "ours_fast",
            True,
        ],
        [
            "https://semantic-abstraction.cs.columbia.edu/downloads/fireplace.png",
            "fireplace,beige armchair,candle,large indoor plant in a pot,forest painting,cheetah-patterned pillow,floor,carpet,wall",
            "a photograph of a {} in a home.",
            "ours_fast",
            True,
        ],
        [
            "https://semantic-abstraction.cs.columbia.edu/downloads/walle.png",
            "WALL-E,a fire extinguisher",
            "a 3D render of {}.",
            "ours_fast",
            True,
        ],
    ],
)

iface.launch()
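# A minimal sketch of calling the extractor without the Gradio UI (assumes a
# local RGB image; "room.png", the labels, and the output filename below are
# illustrative, not part of the demo):
#
#   img = np.asarray(Image.open("room.png").convert("RGB"))
#   maps, _ = generate_relevancy(
#       img, "sofa,plant,wall", "a photograph of a {} in a home.", "ours_fast", True
#   )
#   Image.fromarray(maps[0]).save("sofa_relevancy.png")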