# Authors: Hui Ren (rhfeiyang.github.io)
import os
import numpy as np
from torchvision import transforms
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Function
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from collections import OrderedDict
from transformers import BatchFeature
import clip
import copy
import lpips
from transformers import ViTImageProcessor, ViTModel
## CSD_CLIP
def convert_weights_float(model: nn.Module):
"""Convert applicable model parameters to fp32"""
def _convert_weights_to_fp32(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.float()
if l.bias is not None:
l.bias.data = l.bias.data.float()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.float()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.float()
model.apply(_convert_weights_to_fp32)
class ReverseLayerF(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
output = grad_output.neg() * ctx.alpha
return output, None
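# Illustrative sketch (not part of the original pipeline): ReverseLayerF implements the
# gradient-reversal trick. It is an identity map on the forward pass and multiplies the
# incoming gradient by -alpha on the backward pass, e.g.:
#   feat = torch.randn(4, 16, requires_grad=True)
#   out = ReverseLayerF.apply(feat, 0.5)
#   out.sum().backward()
#   # feat.grad is now -0.5 * torch.ones_like(feat)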
## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
class ProjectionHead(nn.Module):
def __init__(
self,
embedding_dim,
projection_dim,
dropout=0
):
super().__init__()
self.projection = nn.Linear(embedding_dim, projection_dim)
self.gelu = nn.GELU()
self.fc = nn.Linear(projection_dim, projection_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(projection_dim)
def forward(self, x):
projected = self.projection(x)
x = self.gelu(projected)
x = self.fc(x)
x = self.dropout(x)
x = x + projected
x = self.layer_norm(x)
return x
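# Shape sketch (illustrative only, hypothetical dimensions): the head maps the backbone
# embedding to projection_dim and adds a residual around the GELU/Linear block before LayerNorm:
#   head = ProjectionHead(embedding_dim=1024, projection_dim=768)
#   head(torch.randn(2, 1024)).shape  # -> torch.Size([2, 768])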
def convert_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("module."):
k = k.replace("module.", "")
new_state_dict[k] = v
return new_state_dict
def init_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
class Metric(nn.Module):
def __init__(self):
super().__init__()
self.image_preprocess = None
def load_image(self, image_path):
with open(image_path, 'rb') as f:
image = Image.open(f).convert("RGB")
return image
    def load_image_path(self, image_path):
        if isinstance(image_path, str):
            # should be an image folder path: collect its .jpg/.png files (sorted for a deterministic order)
            images_file = sorted(os.listdir(image_path))
            image_path = [os.path.join(image_path, image) for image in images_file if
                          image.endswith(".jpg") or image.endswith(".png")]
        if isinstance(image_path[0], str):
            images = [self.load_image(image) for image in image_path]
        elif isinstance(image_path[0], np.ndarray):
            images = [Image.fromarray(image) for image in image_path]
        elif isinstance(image_path[0], Image.Image):
            images = image_path
        else:
            raise Exception("Invalid input")
        return images
    def preprocess_image(self, image, **kwargs):
        if isinstance(image, torch.Tensor):
            raise Exception("Unsupported input")
        if (isinstance(image, str) and os.path.isdir(image)) or (isinstance(image, list) and (isinstance(image[0], Image.Image) or isinstance(image[0], np.ndarray) or os.path.isfile(image[0]))):
            input_data = self.load_image_path(image)
            input_data = [self.image_preprocess(img, **kwargs) for img in input_data]
            input_data = torch.stack(input_data)
        elif isinstance(image, str) and os.path.isfile(image):
            input_data = self.load_image(image)
            input_data = self.image_preprocess(input_data, **kwargs)
            input_data = input_data.unsqueeze(0)
        else:
            raise Exception("Invalid input")
        return input_data
class Clip_Basic_Metric(Metric):
def __init__(self):
super().__init__()
        # for tensors already normalized to [-1, 1]: the first Normalize rescales them to [0, 1],
        # the second applies the standard CLIP mean/std normalization
        self.tensor_preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
self.image_preprocess = transforms.Compose([
transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
class Clip_metric(Clip_Basic_Metric):
@torch.no_grad()
    def __init__(self, target_style_prompt: str=None, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
                 batch_size=8, alpha=0.5):
super().__init__()
self.device = device
self.alpha = alpha
self.model = (CLIPModel.from_pretrained(clip_model_name)).to(device)
self.processor = CLIPProcessor.from_pretrained(clip_model_name)
self.tokenizer = self.processor.tokenizer
self.image_processor = self.processor.image_processor
        self.style_class_features = []
self.model.eval()
        self.batch_size = batch_size
if target_style_prompt is not None:
self.ref_style_features = self.get_text_features(target_style_prompt)
else:
self.ref_style_features = None
self.ref_image_style_prototype = None
def get_text_features(self, text):
prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
prompt_features = F.normalize(prompt_features, p=2, dim=-1)
return prompt_features
    def get_image_features(self, images):
        if isinstance(images, torch.Tensor):
            # tensors are assumed to be in [-1, 1]; rescale and normalize them for CLIP directly
            images = self.tensor_preprocess(images)
            image_features = BatchFeature(data={"pixel_values": images}, tensor_type="pt").to(self.device)
        else:
            # resolve folder paths / file lists / arrays into PIL images, then use the CLIP processor
            images = self.load_image_path(images)
            image_features = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
                                                                                                non_blocking=True)
        image_features = self.model.get_image_features(**image_features).to(self.device)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features
def img_text_similarity(self, image_features, text=None):
if text is not None:
prompt_feature = self.get_text_features(text)
if isinstance(text, str):
prompt_feature = prompt_feature.repeat(len(image_features), 1)
else:
prompt_feature = self.ref_style_features
similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
return similarity_each
def forward(self, output_imgs, prompt=None):
image_features = self.get_image_features(output_imgs)
# print(image_features)
style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
if prompt is not None:
content_score = self.img_text_similarity(image_features, prompt)
score = self.alpha * style_score + (1 - self.alpha) * content_score
return {"score": score, "style_score": style_score, "content_score": content_score}
else:
return {"style_score": style_score}
def content_score(self, output_imgs, prompt):
self.to(self.device)
image_features = self.get_image_features(output_imgs)
content_score_details = self.img_text_similarity(image_features, prompt)
self.to("cpu")
return {"CLIP_content_score": content_score_details.mean().cpu(), "CLIP_content_score_details": content_score_details.cpu()}
class CSD_CLIP(Clip_Basic_Metric):
"""backbone + projection head"""
def __init__(self, name='vit_large',content_proj_head='default', ckpt_path = "data/weights/CSD-checkpoint.pth", device="cuda",
alpha=0.5, **kwargs):
super(CSD_CLIP, self).__init__()
self.alpha = alpha
self.content_proj_head = content_proj_head
self.device = device
if name == 'vit_large':
clipmodel, _ = clip.load("ViT-L/14")
self.backbone = clipmodel.visual
self.embedding_dim = 1024
elif name == 'vit_base':
clipmodel, _ = clip.load("ViT-B/16")
self.backbone = clipmodel.visual
self.embedding_dim = 768
self.feat_dim = 512
else:
raise Exception('This model is not implemented')
convert_weights_float(self.backbone)
self.last_layer_style = copy.deepcopy(self.backbone.proj)
if content_proj_head == 'custom':
self.last_layer_content = ProjectionHead(self.embedding_dim,self.feat_dim)
self.last_layer_content.apply(init_weights)
else:
self.last_layer_content = copy.deepcopy(self.backbone.proj)
self.backbone.proj = None
self.backbone.requires_grad_(False)
self.last_layer_style.requires_grad_(False)
self.last_layer_content.requires_grad_(False)
self.backbone.eval()
if ckpt_path is not None:
self.load_ckpt(ckpt_path)
self.to("cpu")
def load_ckpt(self, ckpt_path):
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = convert_state_dict(checkpoint['model_state_dict'])
msg = self.load_state_dict(state_dict, strict=False)
print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")
@property
def dtype(self):
return self.backbone.conv1.weight.dtype
def get_image_features(self, input_data, get_style=True,get_content=False,feature_alpha=None):
if isinstance(input_data, torch.Tensor):
input_data = self.tensor_preprocess(input_data)
elif (isinstance(input_data, str) and os.path.isdir(input_data)) or (isinstance(input_data, list) and (isinstance(input_data[0], Image.Image) or isinstance(input_data[0], np.ndarray) or os.path.isfile(input_data[0]))):
input_data = self.load_image_path(input_data)
input_data = [self.image_preprocess(image) for image in input_data]
input_data = torch.stack(input_data)
elif os.path.isfile(input_data):
input_data = self.load_image(input_data)
input_data = self.image_preprocess(input_data)
input_data = input_data.unsqueeze(0)
input_data = input_data.to(self.device)
style_output = None
feature = self.backbone(input_data)
if get_style:
style_output = feature @ self.last_layer_style
# style_output = style_output.mean(dim=0)
style_output = nn.functional.normalize(style_output, dim=-1, p=2)
content_output=None
if get_content:
if feature_alpha is not None:
reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
else:
reverse_feature = feature
# if alpha is not None:
if self.content_proj_head == 'custom':
content_output = self.last_layer_content(reverse_feature)
else:
content_output = reverse_feature @ self.last_layer_content
content_output = nn.functional.normalize(content_output, dim=-1, p=2)
return feature, content_output, style_output
@torch.no_grad()
def define_ref_image_style_prototype(self, ref_image_path: str):
self.to(self.device)
_, _, self.ref_style_feature = self.get_image_features(ref_image_path)
self.to("cpu")
# self.ref_style_feature = self.ref_style_feature.mean(dim=0)
@torch.no_grad()
def forward(self, styled_data):
self.to(self.device)
# get_content_feature = original_data is not None
_, content_output, style_output = self.get_image_features(styled_data, get_content=False)
style_similarities = style_output @ self.ref_style_feature.T
mean_style_similarities = style_similarities.mean(dim=-1)
mean_style_similarity = mean_style_similarities.mean()
max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
max_style_similarity = max_style_similarities_v.mean()
self.to("cpu")
return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity, "CSD_similarity_mean_details": mean_style_similarities,
"CSD_similarity_max_v_details": max_style_similarities_v, "CSD_similarity_max_id_details": max_style_similarities_id}
    def get_style_loss(self, styled_data):
        _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
        # mean cosine similarity against all reference style embeddings
        style_similarity = (style_output @ self.ref_style_feature.T).mean()
        loss = 1 - style_similarity
        return loss
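# Usage sketch for CSD_CLIP (hypothetical paths; mirrors the __main__ example below and
# assumes the checkpoint exists at data/weights/CSD-checkpoint.pth):
#   csd = CSD_CLIP()
#   csd.define_ref_image_style_prototype("/path/to/reference_style_images")
#   scores = csd("/path/to/generated_images")
#   print(scores["CSD_similarity_mean"], scores["CSD_similarity_max"])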
class LPIPS_metric(Metric):
def __init__(self, type="vgg", device="cuda"):
super(LPIPS_metric, self).__init__()
self.lpips = lpips.LPIPS(net=type)
self.device = device
self.image_preprocess = transforms.Compose([
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.to("cpu")
    @torch.no_grad()
    def forward(self, img1, img2):
        self.to(self.device)
        # resolve folder paths / file lists into PIL images so that batching pairs them element-wise
        img1 = self.load_image_path(img1)
        img2 = self.load_image_path(img2)
        differences = []
        for i in range(0, len(img1), 50):
            img1_batch = self.preprocess_image(img1[i:i+50]).to(self.device)
            img2_batch = self.preprocess_image(img2[i:i+50]).to(self.device)
            # reshape(-1) keeps a 1-D result even for a single-image batch (squeeze would drop the dim)
            differences.append(self.lpips(img1_batch, img2_batch).reshape(-1))
        differences = torch.cat(differences)
        difference = differences.mean()
        self.to("cpu")
        return {"LPIPS_content_difference": difference, "LPIPS_content_difference_details": differences}
class Vit_metric(Metric):
def __init__(self, device="cuda"):
super(Vit_metric, self).__init__()
self.device = device
self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
self.to("cpu")
def get_image_features(self, images):
        # resolve folder paths / file lists / arrays into PIL images before batching
        images = self.load_image_path(images)
batch_size = 20
all_image_features = []
for i in range(0, len(images), batch_size):
image_batch = images[i:i+batch_size]
if isinstance(image_batch, torch.Tensor):
image_batch = self.tensor_preprocess(image_batch)
data = {"pixel_values": image_batch}
image_processed = BatchFeature(data=data, tensor_type="pt")
else:
image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
image_features = F.normalize(image_features, p=2, dim=-1)
all_image_features.append(image_features)
all_image_features = torch.cat(all_image_features)
return all_image_features
@torch.no_grad()
def content_metric(self, img1, img2):
self.to(self.device)
if not(isinstance(img1, torch.Tensor) and len(img1.shape) == 2):
img1 = self.get_image_features(img1)
if not(isinstance(img2, torch.Tensor) and len(img2.shape) == 2):
img2 = self.get_image_features(img2)
similarities = torch.einsum("nc, nc -> n", img1, img2)
similarity = similarities.mean()
# self.to("cpu")
return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}
# style
@torch.no_grad()
def define_ref_image_style_prototype(self, ref_image_path: str):
self.to(self.device)
self.ref_style_feature = self.get_image_features(ref_image_path)
self.to("cpu")
@torch.no_grad()
def style_metric(self, styled_data):
self.to(self.device)
if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
style_output = styled_data
else:
style_output = self.get_image_features(styled_data)
style_similarities = style_output @ self.ref_style_feature.T
mean_style_similarities = style_similarities.mean(dim=-1)
mean_style_similarity = mean_style_similarities.mean()
max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
max_style_similarity = max_style_similarities_v.mean()
# self.to("cpu")
return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity, "Vit_style_similarity_mean_details": mean_style_similarities,
"Vit_style_similarity_max_v_details": max_style_similarities_v, "Vit_style_similarity_max_id_details": max_style_similarities_id}
@torch.no_grad()
def forward(self, styled_data, original_data=None):
self.to(self.device)
styled_features = self.get_image_features(styled_data)
ret ={}
if original_data is not None:
content_metric = self.content_metric(styled_features, original_data)
ret["Vit_content"] = content_metric
style_metric = self.style_metric(styled_features)
ret["Vit_style"] = style_metric
self.to("cpu")
return ret
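# Usage sketch for Vit_metric (hypothetical paths): DINO ViT features are used both for style
# similarity against a reference set and for element-wise content similarity between styled and
# source images (same count and order assumed):
#   vit = Vit_metric()
#   vit.define_ref_image_style_prototype("/path/to/reference_style_images")
#   out = vit("/path/to/generated_images", original_data="/path/to/source_images")
#   print(out["Vit_style"]["Vit_style_similarity_mean"], out["Vit_content"]["Vit_content_similarity"])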
class StyleContentMetric(nn.Module):
def __init__(self, style_ref_image_folder, device="cuda"):
super(StyleContentMetric, self).__init__()
self.device = device
self.clip_style_metric = CSD_CLIP(device=device)
self.ref_image_file = os.listdir(style_ref_image_folder)
self.ref_image_file = [i for i in self.ref_image_file if i.endswith(".jpg") or i.endswith(".png")]
self.ref_image_file.sort()
self.ref_image_file = np.array(self.ref_image_file)
ref_image_file_path = [os.path.join(style_ref_image_folder, i) for i in self.ref_image_file]
self.clip_style_metric.define_ref_image_style_prototype(ref_image_file_path)
self.vit_metric = Vit_metric(device=device)
self.vit_metric.define_ref_image_style_prototype(ref_image_file_path)
self.lpips_metric = LPIPS_metric(device=device)
self.clip_content_metric = Clip_metric(alpha=0, target_style_prompt=None)
self.to("cpu")
def forward(self, styled_data, original_data=None, content_caption=None):
ret ={}
csd_score = self.clip_style_metric(styled_data)
csd_score["max_query"] = self.ref_image_file[csd_score["CSD_similarity_max_id_details"].cpu()].tolist()
torch.cuda.empty_cache()
ret["Style_CSD"] = csd_score
vit_score = self.vit_metric(styled_data, original_data)
torch.cuda.empty_cache()
vit_style = vit_score["Vit_style"]
vit_style["max_query"] = self.ref_image_file[vit_style["Vit_style_similarity_max_id_details"].cpu()].tolist()
ret["Style_VIT"] = vit_style
if original_data is not None:
vit_content = vit_score["Vit_content"]
ret["Content_VIT"] = vit_content
lpips_score = self.lpips_metric(styled_data, original_data)
torch.cuda.empty_cache()
ret["Content_LPIPS"] = lpips_score
if content_caption is not None:
clip_content = self.clip_content_metric.content_score(styled_data, content_caption)
ret["Content_CLIP"] = clip_content
torch.cuda.empty_cache()
for type_key, type_value in ret.items():
for key, value in type_value.items():
if isinstance(value, torch.Tensor):
if value.numel() == 1:
ret[type_key][key] = round(value.item(), 4)
else:
ret[type_key][key] = value.tolist()
ret[type_key][key] = [round(v, 4) for v in ret[type_key][key]]
self.to("cpu")
ret["ref_image_file"] = self.ref_image_file.tolist()
return ret
if __name__ == "__main__":
with torch.no_grad():
metric = StyleContentMetric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Art_styles/camille-pissarro/impressionism/split_5/paintings")
score = metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500",
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
print(score)
        lpips_metric = LPIPS_metric()  # avoid shadowing the imported `lpips` module
        score = lpips_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
                             "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        print("lpips", score)
clip_metric = CSD_CLIP()
clip_metric.define_ref_image_style_prototype(
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
score = clip_metric(
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
print("subset3-subset3_sd14_converted", score)
score = clip_metric(
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
print("subset3-photo", score)
score = clip_metric(
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
print("subset3-subset1", score)
score = clip_metric(
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/andy-warhol/pop_art/subset1/paintings")
print("subset3-andy", score)
# score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings", "A painting")
# print("subset3",score)
# score_subset2 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset2/paintings")
# print("subset2",score_subset2)
# score_subset3 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
# print("subset3",score_subset3)
#
# score_subset3_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
# print("subset3-subset3_sd14_converted" , score_subset3_converted)
#
# score_subset3_coco_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/coco_converted_photo/500")
# print("subset3-subset3_coco_converted" , score_subset3_coco_converted)
#
# clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500")
# score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
# print("photo500_1-sketch" ,score)
#
# clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
# score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500_new")
# print("photo500_1-photo500_2" ,score)
# from custom_datasets.imagepair import ImageSet
# import matplotlib.pyplot as plt
# dataset = ImageSet(folder = "/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
# caption_path="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/captions",
# keep_in_mem=False)
# for sample in dataset:
# score = clip_metric.content_score(sample["image"], sample["caption"][0])
# plt.imshow(sample["image"])
# plt.title(f"score: {round(score.item(),2)}\n prompt: {sample['caption'][0]}")
# plt.show()