# Scraped from the hosting site's file view (kept as comments so the file parses):
# author: Harshithtd - commit b1c60a9 "Create app.py" (verified) - 1.84 kB
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
import cv2
# Load CLIP weights and the matching processor (tokenizer + image preprocessing)
# from one checkpoint id so the two can never drift apart.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
def apply_gradcam(image, text):
    """Overlay a Grad-CAM style relevance map for ``text`` on ``image``.

    CLIP's image encoder is a Vision Transformer, so there is no convolutional
    feature map to hook. Instead, the patch-token hidden states that feed the
    final encoder layer act as the "activations". Their gradients with respect
    to the image/text cosine similarity are averaged over patches to get
    per-channel weights. The weighted channels are summed per patch, and the
    resulting patch grid is upsampled to the image size.

    Args:
        image: PIL image (any mode; it is converted to RGB first).
        text: Prompt whose matching region should be highlighted.

    Returns:
        ``numpy.ndarray`` of dtype uint8, shape (height, width, 3), RGB order:
        the input image blended 60/40 with a JET-coloured heatmap.

    Raises:
        ValueError: if ``image`` is None or the vision tower does not yield a
            square patch grid.
    """
    if image is None:
        raise ValueError("apply_gradcam requires an image")
    # RGBA / greyscale uploads would otherwise mismatch the 3-channel heatmap.
    rgb_image = image.convert("RGB")

    cam = _clip_patch_cam(rgb_image, text)
    return _overlay_heatmap(rgb_image, cam)


def _clip_patch_cam(rgb_image, text):
    """Return a (grid, grid) float32 relevance map, ReLU'd and max-normalised to [0, 1]."""
    inputs = processor(text=[text], images=rgb_image, return_tensors="pt", padding=True)
    with torch.enable_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # In transformers' CLIP vision model the pooled image embedding is taken
        # from the CLS token of the last layer only, so the last layer's patch
        # tokens receive no gradient. hidden_states[-2] is the input to the last
        # encoder layer; its patch tokens reach CLS through self-attention.
        hidden = outputs.vision_model_output.hidden_states[-2]
        score = torch.nn.functional.cosine_similarity(
            outputs.image_embeds, outputs.text_embeds
        ).sum()
        # autograd.grad returns the gradient directly instead of writing .grad
        # into every model parameter (which .backward() would accumulate
        # across requests).
        (grads,) = torch.autograd.grad(score, hidden)

    # Index 0 along the token axis is CLS; keep only the patch tokens.
    activations = hidden[0, 1:, :].detach()
    weights = grads[0, 1:, :].mean(dim=0)  # one importance weight per channel
    cam = torch.relu(activations @ weights)  # one score per patch

    num_patches = cam.shape[0]
    grid = int(round(num_patches ** 0.5))
    if grid * grid != num_patches:
        raise ValueError(f"expected a square patch grid, got {num_patches} patches")
    cam = cam.reshape(grid, grid).cpu().numpy().astype(np.float32)

    # Guard against an all-zero map (the original divided by zero here).
    peak = float(cam.max())
    if peak > 0:
        cam /= peak
    return cam


def _overlay_heatmap(rgb_image, cam):
    """Upsample ``cam`` to ``rgb_image``'s size and blend it in as a JET colour map."""
    width, height = rgb_image.size
    # cv2.resize takes (width, height), the same order as PIL's .size.
    heatmap = cv2.resize(cam, (width, height), interpolation=cv2.INTER_LINEAR)
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; the PIL-derived array is RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(np.asarray(rgb_image), 0.6, heatmap, 0.4, 0)
def highlight_image(image, text):
    """Gradio callback: highlight the part of ``image`` described by ``text``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; None when
            nothing has been uploaded.
        text: Prompt from the "Text Description" textbox.

    Returns:
        A PIL image built from the array returned by ``apply_gradcam``.

    Raises:
        gr.Error: when the image or the prompt is missing, so the UI shows a
            readable message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not text or not text.strip():
        raise gr.Error("Please enter a text description.")
    highlighted_image = apply_gradcam(image, text)
    return Image.fromarray(highlighted_image)
# Gradio UI: an uploaded image plus a free-text prompt in, the highlighted
# image out. Components are built up front and wired into the Interface.
_image_in = gr.Image(type="pil")
_prompt_in = gr.Textbox(label="Text Description")
_image_out = gr.Image(type="pil")

iface = gr.Interface(
    fn=highlight_image,
    inputs=[_image_in, _prompt_in],
    outputs=_image_out,
    title="Image Text Highlight",
    description="Upload an image and provide a text description to highlight the relevant part of the image.",
)

# Start the web server when this module runs (no __main__ guard, as before).
iface.launch()