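"""Compare CLIP-Score implementations across backends and devices.

Scores a prompt/image dataset three ways -- through torchmetrics' CLIPScore
internals, the raw transformers CLIP API, and OpenAI's clip package -- and
prints the per-image values side by side so cpu/mps discrepancies show up.
"""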
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from torchmetrics.multimodal import CLIPScore
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_from_disk
import os
import clip

def compute_clip_score(promptbook, device, drop_negative=False):
    # if 'clip_score' in promptbook.columns:
    #     print('==> Skipping CLIP-Score computation')
    #     return
    print('==> CLIP-Score computation started')
    clip_scores = []
    to_tensor = transforms.ToTensor()
    # metric = CLIPScore(model_name_or_path='openai/clip-vit-base-patch16').to(DEVICE)
    metric = CLIPScore(model_name_or_path='openai/clip-vit-large-patch14').to(device)
    # BATCH_SIZE is defined in the __main__ block below
    for i in tqdm(range(0, len(promptbook), BATCH_SIZE)):
        images = []
        prompts = list(promptbook.prompt.values[i:i+BATCH_SIZE])
        for image in promptbook.image.values[i:i+BATCH_SIZE]:
            images.append(to_tensor(image))
        with torch.no_grad():
            # drive CLIPScore's processor and model by hand to get per-image scores
            x = metric.processor(text=prompts, images=images, return_tensors='pt', padding=True)
            img_features = metric.model.get_image_features(x['pixel_values'].to(device))
            img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
            txt_features = metric.model.get_text_features(x['input_ids'].to(device), x['attention_mask'].to(device))
            txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
            # CLIP score = 100 * cosine similarity of the normalized embeddings
            scores = 100 * (img_features * txt_features).sum(dim=-1).detach().cpu()
        if drop_negative:
            scores = torch.max(scores, torch.zeros_like(scores))
        clip_scores += [round(s.item(), 4) for s in scores]
    promptbook['clip_score'] = np.asarray(clip_scores)
    print('==> CLIP-Score computation completed')
    return promptbook
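
# A minimal sketch of the simpler one-call torchmetrics path (assumes a
# torchmetrics version whose CLIPScore.forward accepts a batch of images):
#
#   metric = CLIPScore(model_name_or_path='openai/clip-vit-large-patch14').to(device)
#   mean_score = metric(images, prompts)  # reduces to the batch mean
#
# compute_clip_score drives the processor and model directly instead because
# the comparison below needs one score per image, not the mean.
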
def compute_clip_score_hmd(promptbook):
    # re-score each row on cpu and mps and compare against the stored value
    metric_cpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('cpu')
    metric_gpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('mps')
    os.makedirs('./tmp', exist_ok=True)  # image.save below fails without the directory
    for idx in promptbook.index:
        clip_score_hm = promptbook.loc[idx, 'clip_score']
        with torch.no_grad():
            image = promptbook.loc[idx, 'image']
            image.save(f"./tmp/{promptbook.loc[idx, 'image_id']}.png")
            image = transforms.ToTensor()(image)
            image_cpu = torch.unsqueeze(image, dim=0).to('cpu')
            image_gpu = torch.unsqueeze(image, dim=0).to('mps')
            prompts = [promptbook.loc[idx, 'prompt']]
            clip_score_cpu = metric_cpu(image_cpu, prompts)
            clip_score_gpu = metric_gpu(image_gpu, prompts)
            print(f'==> clip_score_hm: {clip_score_hm:.4f}, clip_score_cpu: {clip_score_cpu:.4f}, '
                  f'clip_score_gpu: {clip_score_gpu:.4f}')
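
# Caveat (an assumption worth verifying against your transformers version):
# CLIPProcessor rescales pixel values by 1/255 by default, while
# transforms.ToTensor() already maps images to [0, 1], so float tensors may be
# rescaled twice. Passing PIL images straight to the processor, as
# compute_clip_score_transformers does, sidesteps the question entirely.
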
def compute_clip_score_transformers(promptbook, device='cpu'):
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    with torch.no_grad():
        inputs = processor(text=promptbook.prompt.tolist(), images=promptbook.image.tolist(), return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        # logits_per_image is [n_images, n_prompts]; the diagonal pairs each
        # image with its own prompt ([:, 0] would score every image against
        # the first prompt only)
        promptbook.loc[:, 'clip_score'] = logits_per_image.diag().cpu().tolist()
    return promptbook
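
# For reference: logits_per_image is the image/text cosine similarity scaled by
# the model's learned logit_scale (roughly 100 for the released CLIP weights),
# so its diagonal is directly comparable to the 100 * cosine values produced by
# compute_clip_score above.
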
def compute_clip_score_clip(promptbook, device='cpu'):
    model, preprocess = clip.load("ViT-B/32", device=device)
    with torch.no_grad():
        # preprocess handles one PIL image at a time, so stack per-image tensors
        image_inputs = torch.stack([preprocess(image) for image in promptbook.image.tolist()]).to(device)
        text_inputs = clip.tokenize(promptbook.prompt.tolist()).to(device)
        logits_per_image, _ = model(image_inputs, text_inputs)
        # as above, the diagonal pairs each image with its own prompt
        promptbook.loc[:, 'clip_score'] = logits_per_image.diag().cpu().tolist()
    return promptbook
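
# Note: "ViT-B/32" is a smaller backbone than the ViT-L/14 checkpoint used by
# the other functions, so its absolute scores are not directly comparable.
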
if __name__ == "__main__":
    BATCH_SIZE = 200
    # DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(torch.__version__)

    images_ds = load_from_disk(os.path.join(os.pardir, 'data', 'promptbook'))
    images_ds = images_ds.sort(['prompt_id', 'modelVersion_id'])
    print(images_ds)
    print(type(images_ds[0]['image']))

    promptbook_hmd = pd.DataFrame(images_ds[:20])
    promptbook_new = promptbook_hmd.drop(columns=['clip_score'])
    promptbook_cpu = compute_clip_score(promptbook_new.copy(deep=True), device='cpu')
    promptbook_mps = compute_clip_score(promptbook_new.copy(deep=True), device='mps')
    promptbook_tra_cpu = compute_clip_score_transformers(promptbook_new.copy(deep=True))
    promptbook_tra_mps = compute_clip_score_transformers(promptbook_new.copy(deep=True), device='mps')
    # print the per-image scores from each implementation side by side
    for idx in promptbook_mps.index:
        print(
            'image id: ', promptbook_mps['image_id'][idx],
            'mps: ', promptbook_mps['clip_score'][idx],
            'cpu: ', promptbook_cpu['clip_score'][idx],
            'tra cpu: ', promptbook_tra_cpu['clip_score'][idx],
            'tra mps: ', promptbook_tra_mps['clip_score'][idx],
            'hmd: ', promptbook_hmd['clip_score'][idx]
        )

    # compute_clip_score_hmd(promptbook_hmd)
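
    # Optional sanity check: quantify the cpu/mps drift directly instead of
    # eyeballing the printout above (a minimal sketch using the frames built here)
    diff = (promptbook_cpu['clip_score'] - promptbook_mps['clip_score']).abs()
    print(f'max |cpu - mps| clip_score difference: {diff.max():.4f}')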