# Authors: Hui Ren (rhfeiyang.github.io)
import os
import copy
from collections import OrderedDict

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torchvision import transforms

import clip
import lpips
from transformers import CLIPProcessor, CLIPModel, BatchFeature, ViTImageProcessor, ViTModel
## CSD_CLIP
def convert_weights_float(model: nn.Module):
    """Convert applicable model parameters to fp32"""

    def _convert_weights_to_fp32(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp32)
class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None
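
# Note: ReverseLayerF is a gradient reversal layer (as used in domain-adversarial
# training): the forward pass is the identity, while the backward pass multiplies the
# incoming gradient by -alpha. A minimal usage sketch with hypothetical tensors and
# modules (not part of this file):
#   features = encoder(images)                          # (N, D) image embeddings
#   reversed_feat = ReverseLayerF.apply(features, 1.0)  # identity in forward
#   domain_logits = domain_head(reversed_feat)          # gradients into encoder are negated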
## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim,
        dropout=0,
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x
def convert_state_dict(state_dict):
    # strip the "module." prefix left by nn.DataParallel checkpoints
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k.replace("module.", "")
        new_state_dict[k] = v
    return new_state_dict


def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.normal_(m.bias, std=1e-6)
class Metric(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_preprocess = None

    def load_image(self, image_path):
        with open(image_path, 'rb') as f:
            image = Image.open(f).convert("RGB")
        return image

    def load_image_path(self, image_path):
        if isinstance(image_path, str):
            # a single string should be an image folder path: expand it into a list of image files
            images_file = os.listdir(image_path)
            image_path = [os.path.join(image_path, image) for image in images_file
                          if image.endswith(".jpg") or image.endswith(".png")]
        if isinstance(image_path[0], str):
            images = [self.load_image(image) for image in image_path]
        elif isinstance(image_path[0], np.ndarray):
            images = [Image.fromarray(image) for image in image_path]
        elif isinstance(image_path[0], Image.Image):
            images = image_path
        else:
            raise Exception("Invalid input")
        return images
    def preprocess_image(self, image, **kwargs):
        if (isinstance(image, str) and os.path.isdir(image)) or (isinstance(image, list) and (isinstance(image[0], Image.Image) or isinstance(image[0], np.ndarray) or os.path.isfile(image[0]))):
            input_data = self.load_image_path(image)
            input_data = [self.image_preprocess(image, **kwargs) for image in input_data]
            input_data = torch.stack(input_data)
        elif isinstance(image, str) and os.path.isfile(image):
            input_data = self.load_image(image)
            input_data = self.image_preprocess(input_data, **kwargs)
            input_data = input_data.unsqueeze(0)
        elif isinstance(image, torch.Tensor):
            raise Exception("Unsupported input")
        else:
            raise Exception("Invalid input")
        return input_data
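
# Note: the Metric helpers above accept either an image folder path, a list of file
# paths, or a list of np.ndarray / PIL.Image objects; pre-normalized tensor batches are
# handled separately by the CLIP-based metrics below via tensor_preprocess.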
class Clip_Basic_Metric(Metric):
    def __init__(self):
        super().__init__()
        # for tensors already normalized to [-1, 1]: map back to [0, 1], then apply CLIP normalization
        self.tensor_preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        # for PIL images
        self.image_preprocess = transforms.Compose([
            transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
class Clip_metric(Clip_Basic_Metric):
    def __init__(self, target_style_prompt: str = None, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
                 batch_size=8, alpha=0.5):
        super().__init__()
        self.device = device
        self.alpha = alpha
        self.model = CLIPModel.from_pretrained(clip_model_name).to(device)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.style_class_features = []
        self.model.eval()
        self.batch_size = batch_size
        if target_style_prompt is not None:
            self.ref_style_features = self.get_text_features(target_style_prompt)
        else:
            self.ref_style_features = None
        self.ref_image_style_prototype = None

    def get_text_features(self, text):
        prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
        prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
        prompt_features = F.normalize(prompt_features, p=2, dim=-1)
        return prompt_features
    def get_image_features(self, images):
        images = self.load_image_path(images)
        if isinstance(images, torch.Tensor):
            images = self.tensor_preprocess(images)
            data = {"pixel_values": images}
            image_features = BatchFeature(data=data, tensor_type="pt")
        else:
            image_features = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
                                                                                                 non_blocking=True)
        image_features = self.model.get_image_features(**image_features).to(self.device)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features
    def img_text_similarity(self, image_features, text=None):
        if text is not None:
            prompt_feature = self.get_text_features(text)
            if isinstance(text, str):
                prompt_feature = prompt_feature.repeat(len(image_features), 1)
        else:
            prompt_feature = self.ref_style_features
        # per-sample cosine similarity (both features are L2-normalized)
        similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
        return similarity_each

    def forward(self, output_imgs, prompt=None):
        image_features = self.get_image_features(output_imgs)
        style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
        if prompt is not None:
            content_score = self.img_text_similarity(image_features, prompt)
            score = self.alpha * style_score + (1 - self.alpha) * content_score
            return {"score": score, "style_score": style_score, "content_score": content_score}
        else:
            return {"style_score": style_score}

    def content_score(self, output_imgs, prompt):
        self.to(self.device)
        image_features = self.get_image_features(output_imgs)
        content_score_details = self.img_text_similarity(image_features, prompt)
        self.to("cpu")
        return {"CLIP_content_score": content_score_details.mean().cpu(),
                "CLIP_content_score_details": content_score_details.cpu()}
class CSD_CLIP(Clip_Basic_Metric):
    """backbone + projection head"""
    def __init__(self, name='vit_large', content_proj_head='default', ckpt_path="data/weights/CSD-checkpoint.pth",
                 device="cuda", alpha=0.5, **kwargs):
        super(CSD_CLIP, self).__init__()
        self.alpha = alpha
        self.content_proj_head = content_proj_head
        self.device = device
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feat_dim = 512
        else:
            raise Exception('This model is not implemented')
        convert_weights_float(self.backbone)

        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        if content_proj_head == 'custom':
            self.last_layer_content = ProjectionHead(self.embedding_dim, self.feat_dim)
            self.last_layer_content.apply(init_weights)
        else:
            self.last_layer_content = copy.deepcopy(self.backbone.proj)
        self.backbone.proj = None

        self.backbone.requires_grad_(False)
        self.last_layer_style.requires_grad_(False)
        self.last_layer_content.requires_grad_(False)
        self.backbone.eval()
        if ckpt_path is not None:
            self.load_ckpt(ckpt_path)
        self.to("cpu")

    def load_ckpt(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = convert_state_dict(checkpoint['model_state_dict'])
        msg = self.load_state_dict(state_dict, strict=False)
        print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")

    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype
    def get_image_features(self, input_data, get_style=True, get_content=False, feature_alpha=None):
        if isinstance(input_data, torch.Tensor):
            input_data = self.tensor_preprocess(input_data)
        elif (isinstance(input_data, str) and os.path.isdir(input_data)) or (isinstance(input_data, list) and (isinstance(input_data[0], Image.Image) or isinstance(input_data[0], np.ndarray) or os.path.isfile(input_data[0]))):
            input_data = self.load_image_path(input_data)
            input_data = [self.image_preprocess(image) for image in input_data]
            input_data = torch.stack(input_data)
        elif os.path.isfile(input_data):
            input_data = self.load_image(input_data)
            input_data = self.image_preprocess(input_data)
            input_data = input_data.unsqueeze(0)
        input_data = input_data.to(self.device)

        feature = self.backbone(input_data)

        style_output = None
        if get_style:
            style_output = feature @ self.last_layer_style
            style_output = nn.functional.normalize(style_output, dim=-1, p=2)

        content_output = None
        if get_content:
            if feature_alpha is not None:
                reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
            else:
                reverse_feature = feature
            if self.content_proj_head == 'custom':
                content_output = self.last_layer_content(reverse_feature)
            else:
                content_output = reverse_feature @ self.last_layer_content
            content_output = nn.functional.normalize(content_output, dim=-1, p=2)

        return feature, content_output, style_output
    def define_ref_image_style_prototype(self, ref_image_path: str):
        self.to(self.device)
        _, _, self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")

    def forward(self, styled_data):
        self.to(self.device)
        _, content_output, style_output = self.get_image_features(styled_data, get_content=False)
        # similarity of each styled image against every reference style image
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()
        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()
        self.to("cpu")
        return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity,
                "CSD_similarity_mean_details": mean_style_similarities,
                "CSD_similarity_max_v_details": max_style_similarities_v,
                "CSD_similarity_max_id_details": max_style_similarities_id}

    def get_style_loss(self, styled_data):
        _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
        style_similarity = (style_output @ self.ref_style_feature).mean()
        loss = 1 - style_similarity
        return loss.mean()
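
# Example (a minimal sketch; the folder paths are hypothetical and the CSD checkpoint
# must exist at ckpt_path):
#   csd = CSD_CLIP(ckpt_path="data/weights/CSD-checkpoint.pth")
#   csd.define_ref_image_style_prototype("refs/style_paintings")   # folder of reference style images
#   scores = csd("outputs/stylized")                               # folder of generated images
#   print(scores["CSD_similarity_mean"], scores["CSD_similarity_max"])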
class LPIPS_metric(Metric):
    def __init__(self, type="vgg", device="cuda"):
        super(LPIPS_metric, self).__init__()
        self.lpips = lpips.LPIPS(net=type)
        self.device = device
        self.image_preprocess = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # LPIPS expects inputs in [-1, 1]
        ])
        self.to("cpu")

    def forward(self, img1, img2):
        self.to(self.device)
        # expand folder paths / file lists into lists of PIL images so the batching below works
        img1 = self.load_image_path(img1)
        img2 = self.load_image_path(img2)
        differences = []
        # process in chunks of 50 image pairs to bound GPU memory
        for i in range(0, len(img1), 50):
            img1_batch = self.preprocess_image(img1[i:i + 50]).to(self.device)
            img2_batch = self.preprocess_image(img2[i:i + 50]).to(self.device)
            differences.append(self.lpips(img1_batch, img2_batch).squeeze())
        differences = torch.cat(differences)
        difference = differences.mean()
        self.to("cpu")
        return {"LPIPS_content_difference": difference, "LPIPS_content_difference_details": differences}
class Vit_metric(Metric):
    def __init__(self, device="cuda"):
        super(Vit_metric, self).__init__()
        self.device = device
        self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
        self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
        self.to("cpu")

    def get_image_features(self, images):
        images = self.load_image_path(images)
        batch_size = 20
        all_image_features = []
        for i in range(0, len(images), batch_size):
            image_batch = images[i:i + batch_size]
            if isinstance(image_batch, torch.Tensor):
                image_batch = self.tensor_preprocess(image_batch)
                data = {"pixel_values": image_batch}
                image_processed = BatchFeature(data=data, tensor_type="pt")
            else:
                image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
            image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
            image_features = F.normalize(image_features, p=2, dim=-1)
            all_image_features.append(image_features)
        all_image_features = torch.cat(all_image_features)
        return all_image_features
    def content_metric(self, img1, img2):
        self.to(self.device)
        if not (isinstance(img1, torch.Tensor) and len(img1.shape) == 2):
            img1 = self.get_image_features(img1)
        if not (isinstance(img2, torch.Tensor) and len(img2.shape) == 2):
            img2 = self.get_image_features(img2)
        # pairwise cosine similarity between corresponding styled/original images
        similarities = torch.einsum("nc, nc -> n", img1, img2)
        similarity = similarities.mean()
        return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}

    # style
    def define_ref_image_style_prototype(self, ref_image_path: str):
        self.to(self.device)
        self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")

    def style_metric(self, styled_data):
        self.to(self.device)
        if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
            style_output = styled_data
        else:
            style_output = self.get_image_features(styled_data)
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()
        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()
        return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity,
                "Vit_style_similarity_mean_details": mean_style_similarities,
                "Vit_style_similarity_max_v_details": max_style_similarities_v,
                "Vit_style_similarity_max_id_details": max_style_similarities_id}

    def forward(self, styled_data, original_data=None):
        self.to(self.device)
        styled_features = self.get_image_features(styled_data)
        ret = {}
        if original_data is not None:
            content_metric = self.content_metric(styled_features, original_data)
            ret["Vit_content"] = content_metric
        style_metric = self.style_metric(styled_features)
        ret["Vit_style"] = style_metric
        self.to("cpu")
        return ret
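
# Note: Vit_metric relies on frozen DINO ViT-B/8 features (flattened last hidden states,
# L2-normalized) both for content similarity (pairwise, styled vs. original image) and
# for style similarity (each styled image against the reference-set features).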
class StyleContentMetric(nn.Module):
    def __init__(self, style_ref_image_folder, device="cuda"):
        super(StyleContentMetric, self).__init__()
        self.device = device
        self.clip_style_metric = CSD_CLIP(device=device)
        self.ref_image_file = os.listdir(style_ref_image_folder)
        self.ref_image_file = [i for i in self.ref_image_file if i.endswith(".jpg") or i.endswith(".png")]
        self.ref_image_file.sort()
        self.ref_image_file = np.array(self.ref_image_file)
        ref_image_file_path = [os.path.join(style_ref_image_folder, i) for i in self.ref_image_file]
        self.clip_style_metric.define_ref_image_style_prototype(ref_image_file_path)

        self.vit_metric = Vit_metric(device=device)
        self.vit_metric.define_ref_image_style_prototype(ref_image_file_path)
        self.lpips_metric = LPIPS_metric(device=device)
        self.clip_content_metric = Clip_metric(alpha=0, target_style_prompt=None)
        self.to("cpu")

    def forward(self, styled_data, original_data=None, content_caption=None):
        ret = {}
        csd_score = self.clip_style_metric(styled_data)
        csd_score["max_query"] = self.ref_image_file[csd_score["CSD_similarity_max_id_details"].cpu()].tolist()
        torch.cuda.empty_cache()
        ret["Style_CSD"] = csd_score

        vit_score = self.vit_metric(styled_data, original_data)
        torch.cuda.empty_cache()
        vit_style = vit_score["Vit_style"]
        vit_style["max_query"] = self.ref_image_file[vit_style["Vit_style_similarity_max_id_details"].cpu()].tolist()
        ret["Style_VIT"] = vit_style
        if original_data is not None:
            vit_content = vit_score["Vit_content"]
            ret["Content_VIT"] = vit_content
            lpips_score = self.lpips_metric(styled_data, original_data)
            torch.cuda.empty_cache()
            ret["Content_LPIPS"] = lpips_score
        if content_caption is not None:
            clip_content = self.clip_content_metric.content_score(styled_data, content_caption)
            ret["Content_CLIP"] = clip_content
            torch.cuda.empty_cache()

        # round results for readability: scalar tensors become floats, detail vectors become lists
        for type_key, type_value in ret.items():
            for key, value in type_value.items():
                if isinstance(value, torch.Tensor):
                    if value.numel() == 1:
                        ret[type_key][key] = round(value.item(), 4)
                    else:
                        ret[type_key][key] = [round(v, 4) for v in value.tolist()]
        self.to("cpu")
        ret["ref_image_file"] = self.ref_image_file.tolist()
        return ret
if __name__ == "__main__":
    with torch.no_grad():
        metric = StyleContentMetric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Art_styles/camille-pissarro/impressionism/split_5/paintings")
        score = metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500",
                       "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
        print(score)

        lpips_metric = LPIPS_metric()
        score = lpips_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
                             "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        print("lpips", score)

        clip_metric = CSD_CLIP()
        clip_metric.define_ref_image_style_prototype(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        print("subset3-subset3_sd14_converted", score)
        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        print("subset3-photo", score)
        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
        print("subset3-subset1", score)
        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/andy-warhol/pop_art/subset1/paintings")
        print("subset3-andy", score)
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings", "A painting")
        # print("subset3", score)
        # score_subset2 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset2/paintings")
        # print("subset2", score_subset2)
        # score_subset3 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
        # print("subset3", score_subset3)

        # score_subset3_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        # print("subset3-subset3_sd14_converted", score_subset3_converted)

        # score_subset3_coco_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/coco_converted_photo/500")
        # print("subset3-subset3_coco_converted", score_subset3_coco_converted)

        # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500")
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        # print("photo500_1-sketch", score)

        # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500_new")
        # print("photo500_1-photo500_2", score)

        # from custom_datasets.imagepair import ImageSet
        # import matplotlib.pyplot as plt
        # dataset = ImageSet(folder="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
        #                    caption_path="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/captions",
        #                    keep_in_mem=False)
        # for sample in dataset:
        #     score = clip_metric.content_score(sample["image"], sample["caption"][0])
        #     plt.imshow(sample["image"])
        #     plt.title(f"score: {round(score.item(), 2)}\n prompt: {sample['caption'][0]}")
        #     plt.show()