"""Evaluates cross-modal correspondence of CLIP on PNG images.""" | |
import os
import sys
from os.path import join, exists
import warnings
warnings.filterwarnings('ignore')

from clip_grounding.utils.paths import REPO_PATH
sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))

import torch
import CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
from torchmetrics import JaccardIndex
from collections import defaultdict
from IPython.core.display import display, HTML
from skimage import filters

from CLIP_explainability.utils import interpret, show_img_heatmap, show_txt_heatmap, color, _tokenizer

from clip_grounding.datasets.png import PNG
from clip_grounding.utils.image import pad_to_square
from clip_grounding.utils.visualize import show_grid_of_images
from clip_grounding.utils.log import tqdm_iterator, print_update
# globals used by the helper functions below

# specify device
device = "cuda" if torch.cuda.is_available() else "cpu"

# load CLIP model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
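# NOTE: `interpret` (imported from CLIP_explainability.utils) appears to wrap the
# Transformer-MM-Explainability relevance propagation for CLIP, returning per-token
# (R_text) and per-patch (R_image) relevance scores; the module-level `model` and
# `preprocess` above are re-created in the __main__ block before evaluation runs.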
def show_cam(mask):
    """Converts a [0, 1] relevance mask into a normalized JET-colormap visualization."""
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap
    cam = cam / np.max(cam)
    return cam
def interpret_and_generate(model, img, texts, orig_image, return_outputs=False, show=True):
    """Computes text and image relevance scores for (image, text) pairs and optionally shows the heatmaps."""
    text = clip.tokenize(texts).to(device)
    R_text, R_image = interpret(model=model, image=img, texts=text, device=device)
    batch_size = text.shape[0]

    outputs = []
    for i in range(batch_size):
        text_scores, text_tokens_decoded = show_txt_heatmap(texts[i], text[i], R_text[i], show=show)
        image_relevance = show_img_heatmap(R_image[i], img, orig_image=orig_image, device=device, show=show)
        plt.show()
        outputs.append(
            {
                "text_scores": text_scores,
                "image_relevance": image_relevance,
                "tokens_decoded": text_tokens_decoded,
            }
        )

    if return_outputs:
        return outputs
def process_entry_text_to_image(entry, unimodal=False):
    """Prepares a dataset entry for text-to-image evaluation: padded CLIP image tensor + query text."""
    image = entry['image']
    text_mask = entry['text_mask']
    text = entry['text']

    orig_image = pad_to_square(image)
    img = preprocess(orig_image).unsqueeze(0).to(device)

    # pick the text segment highlighted by the (binary) text mask; '' for the unimodal baseline
    text_index = text_mask.argmax()
    texts = [text[text_index]] if not unimodal else ['']

    return img, texts, orig_image
def preprocess_ground_truth_mask(mask, resize_shape):
    """Pads a binary ground-truth mask to a square and resizes it to match the relevance map."""
    mask = Image.fromarray(mask.astype(np.uint8) * 255)
    mask = pad_to_square(mask, color=0)
    mask = mask.resize(resize_shape)
    mask = np.asarray(mask) / 255.
    return mask
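# Note on `resize_shape`: PIL's Image.resize expects a (width, height) tuple, while the
# callers below pass a numpy-style relevance_map.shape. The relevance maps here appear to
# be square (images are padded to a square before CLIP preprocessing), so the two
# conventions coincide and the resize is unaffected.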
def apply_otsu_threshold(relevance_map):
    """Binarizes a continuous relevance map using Otsu's threshold."""
    threshold = filters.threshold_otsu(relevance_map)
    otsu_map = (relevance_map > threshold).astype(np.uint8)
    return otsu_map
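# Illustrative sketch of how a relevance map gets binarized before the IoU computation
# (values made up purely for illustration; not executed at import time):
#   >>> rel = np.array([[0.1, 0.9], [0.2, 0.8]])
#   >>> apply_otsu_threshold(rel)
#   array([[0, 1],
#          [0, 1]], dtype=uint8)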
def evaluate_text_to_image(method, dataset, debug=False):
    """Evaluates text-to-image grounding: IoU between thresholded image relevance and ground-truth masks."""
    instance_level_metrics = defaultdict(list)
    entry_level_metrics = defaultdict(list)

    jaccard = JaccardIndex(num_classes=2)
    jaccard = jaccard.to(device)

    num_iter = len(dataset)
    if debug:
        num_iter = 100

    iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
    for idx in iterator:
        instance = dataset[idx]

        instance_iou = 0.
        for entry in instance:
            # preprocess the image and text
            unimodal = (method == "clip-unimodal")
            test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=unimodal)

            if method in ["clip", "clip-unimodal"]:
                # compute the relevance scores
                outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
                # use the image relevance score to compute IoU w.r.t. ground truth segmentation masks
                # NOTE: since we pass a single entry (batch of size 1), outputs[0] contains the required outputs
                relevance_map = outputs[0]["image_relevance"]
            elif method == "random":
                relevance_map = np.random.uniform(low=0., high=1., size=tuple(test_img.shape[2:]))

            otsu_relevance_map = apply_otsu_threshold(relevance_map)

            ground_truth_mask = entry["image_mask"]
            ground_truth_mask = preprocess_ground_truth_mask(ground_truth_mask, relevance_map.shape)

            entry_iou = jaccard(
                torch.from_numpy(otsu_relevance_map).to(device),
                torch.from_numpy(ground_truth_mask.astype(np.uint8)).to(device),
            )
            entry_iou = entry_iou.item()

            instance_iou += (entry_iou / len(entry))
            entry_level_metrics["iou"].append(entry_iou)

        # capture instance (image-sentence pair) level IoU
        instance_level_metrics["iou"].append(instance_iou)

    average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}

    return (
        average_metrics,
        instance_level_metrics,
        entry_level_metrics,
    )
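# Example (mirrors the __main__ block below; `dataset` is assumed to be a loaded PNG instance):
#   avg, per_instance, per_entry = evaluate_text_to_image("random", dataset, debug=True)
#   print(avg["iou"])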
def process_entry_image_to_text(entry, unimodal=False):
    """Prepares a dataset entry for image-to-text evaluation: masked (or blank) image + full caption."""
    if not unimodal:
        # zero out everything outside the ground-truth region so that only the object stays visible
        if len(np.asarray(entry["image"]).shape) == 3:
            mask = np.repeat(np.expand_dims(entry['image_mask'], -1), 3, axis=-1)
        else:
            mask = np.asarray(entry['image_mask'])
        masked_image = (mask * np.asarray(entry['image'])).astype(np.uint8)
        masked_image = Image.fromarray(masked_image)

        orig_image = pad_to_square(masked_image)
        img = preprocess(orig_image).unsqueeze(0).to(device)
    else:
        # unimodal baseline: feed a blank (all-zeros) square image of matching size
        orig_image_shape = max(np.asarray(entry['image']).shape[:2])
        orig_image = Image.fromarray(np.zeros((orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
        # orig_image = Image.fromarray(np.random.randint(0, 256, (orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
        img = preprocess(orig_image).unsqueeze(0).to(device)

    texts = [' '.join(entry['text'])]
    return img, texts, orig_image
def process_text_mask(text, text_mask, tokens):
    """Converts the segment-level text mask into a token-level (0/1) mask aligned with `tokens`."""
    token_level_mask = np.zeros(len(tokens))

    for label, subtext in zip(text_mask, text):
        subtext_tokens = _tokenizer.encode(subtext)
        subtext_tokens_decoded = [_tokenizer.decode([a]) for a in subtext_tokens]
        if label == 1:
            start = tokens.index(subtext_tokens_decoded[0])
            end = tokens.index(subtext_tokens_decoded[-1])
            token_level_mask[start:end + 1] = 1

    return token_level_mask
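# Illustration (hypothetical segments): for entry text ["a black cat", "sitting on a couch"]
# with text_mask [1, 0], the tokens spanning "a black cat" in the decoded caption are set
# to 1 and the remaining tokens stay 0.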
def evaluate_image_to_text(method, dataset, debug=False, clamp_sentence_len=70):
    """Evaluates image-to-text grounding: IoU between thresholded token relevance and ground-truth token masks."""
    instance_level_metrics = defaultdict(list)
    entry_level_metrics = defaultdict(list)

    # entries whose text exceeds CLIP's 77-token context limit are skipped
    num_entries_skipped = 0
    num_total_entries = 0

    num_iter = len(dataset)
    if debug:
        num_iter = 100

    jaccard_image_to_text = JaccardIndex(num_classes=2).to(device)

    iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
    for idx in iterator:
        instance = dataset[idx]

        instance_iou = 0.
        for entry in instance:
            num_total_entries += 1

            # preprocess the image and text
            unimodal = (method == "clip-unimodal")
            img, texts, orig_image = process_entry_image_to_text(entry, unimodal=unimodal)

            # word count is used as a cheap proxy for token count (clamp_sentence_len < 77 keeps a margin)
            appx_total_sent_len = np.sum([len(x.split(" ")) for x in texts])
            if appx_total_sent_len > clamp_sentence_len:
                # print(f"Skipping an entry since its text has approximately {appx_total_sent_len} words"
                #       f" while CLIP cannot process beyond {clamp_sentence_len}")
                num_entries_skipped += 1
                continue

            # compute the relevance scores
            if method in ["clip", "clip-unimodal"]:
                try:
                    outputs = interpret_and_generate(model, img, texts, orig_image, return_outputs=True, show=False)
                except Exception:
                    # e.g. clip.tokenize() raises when the caption still exceeds the context length
                    num_entries_skipped += 1
                    continue
            elif method == "random":
                text = texts[0]
                text_tokens = _tokenizer.encode(text)
                text_tokens_decoded = [_tokenizer.decode([a]) for a in text_tokens]
                outputs = [
                    {
                        "text_scores": np.random.uniform(low=0., high=1., size=len(text_tokens_decoded)),
                        "tokens_decoded": text_tokens_decoded,
                    }
                ]

            # use the text relevance score to compute IoU w.r.t. ground truth text masks
            # NOTE: since we pass a single entry (batch of size 1), outputs[0] contains the required outputs
            token_relevance_scores = outputs[0]["text_scores"]
            if isinstance(token_relevance_scores, torch.Tensor):
                token_relevance_scores = token_relevance_scores.cpu().numpy()
            token_relevance_scores = apply_otsu_threshold(token_relevance_scores)

            token_ground_truth_mask = process_text_mask(entry["text"], entry["text_mask"], outputs[0]["tokens_decoded"])

            entry_iou = jaccard_image_to_text(
                torch.from_numpy(token_relevance_scores).to(device),
                torch.from_numpy(token_ground_truth_mask.astype(np.uint8)).to(device),
            )
            entry_iou = entry_iou.item()

            instance_iou += (entry_iou / len(entry))
            entry_level_metrics["iou"].append(entry_iou)

        # capture instance (image-sentence pair) level IoU
        instance_level_metrics["iou"].append(instance_iou)

    print(
        f"CAUTION: Skipped {(num_entries_skipped / num_total_entries) * 100:.2f}% of entries"
        " since their text exceeded the 77-token CLIP limit."
    )

    average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}

    return (
        average_metrics,
        instance_level_metrics,
        entry_level_metrics,
    )
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Evaluate image-to-text & text-to-image grounding of CLIP")
    parser.add_argument(
        "--eval_method", type=str, default="clip",
        choices=["clip", "random", "clip-unimodal"],
        help="Evaluation method to use",
    )
    parser.add_argument(
        "--ignore_cache", action="store_true",
        help="Ignore cache and force re-generation of the results",
    )
    parser.add_argument(
        "--debug", action="store_true",
        help="Run evaluation on a small subset of the dataset",
    )
    args = parser.parse_args()

    print_update("Using evaluation method: {}".format(args.eval_method))

    clip.clip._MODELS = {
        "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
        "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    }

    # specify device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load CLIP model
    print_update("Loading CLIP model...")
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    print()

    # load PNG dataset
    print_update("Loading PNG dataset...")
    dataset = PNG(dataset_root=join(REPO_PATH, "data", "panoptic_narrative_grounding"), split="val2017")
    print()

    # evaluate and save metrics
    metrics_dir = join(REPO_PATH, "outputs")
    os.makedirs(metrics_dir, exist_ok=True)

    # text-to-image grounding
    metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_text2image_metrics.pt")
    if (not exists(metrics_path)) or args.ignore_cache:
        print_update("Computing metrics for text-to-image grounding")
        average_metrics, instance_level_metrics, entry_level_metrics = evaluate_text_to_image(
            args.eval_method, dataset, debug=args.debug,
        )
        metrics = {
            "average_metrics": average_metrics,
            "instance_level_metrics": instance_level_metrics,
            "entry_level_metrics": entry_level_metrics,
        }
        torch.save(metrics, metrics_path)
        print("TEXT2IMAGE METRICS SAVED TO:", metrics_path)
    else:
        print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
        metrics = torch.load(metrics_path)
        average_metrics = metrics["average_metrics"]

    print("TEXT2IMAGE METRICS:", np.round(average_metrics["iou"], 4))
    print()

    # image-to-text grounding
    metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_image2text_metrics.pt")
    if (not exists(metrics_path)) or args.ignore_cache:
        print_update("Computing metrics for image-to-text grounding")
        average_metrics, instance_level_metrics, entry_level_metrics = evaluate_image_to_text(
            args.eval_method, dataset, debug=args.debug,
        )
        torch.save(
            {
                "average_metrics": average_metrics,
                "instance_level_metrics": instance_level_metrics,
                "entry_level_metrics": entry_level_metrics,
            },
            metrics_path,
        )
        print("IMAGE2TEXT METRICS SAVED TO:", metrics_path)
    else:
        print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
        metrics = torch.load(metrics_path)
        average_metrics = metrics["average_metrics"]

    print("IMAGE2TEXT METRICS:", np.round(average_metrics["iou"], 4))