# Authors: Hui Ren (rhfeiyang.github.io)

import os

import numpy as np
from torchvision import transforms
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Function
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from collections import OrderedDict
from transformers import BatchFeature
import clip
import copy
import lpips
from transformers import ViTImageProcessor, ViTModel

## CSD_CLIP
def convert_weights_float(model: nn.Module):
    """Convert applicable model parameters to fp32"""

    def _convert_weights_to_fp32(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp32)

class ReverseLayerF(Function):
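    """Gradient reversal layer: identity in the forward pass; scales the gradient
    by -alpha in the backward pass (as in domain-adversarial training / DANN)."""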

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None


## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
class ProjectionHead(nn.Module):
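    """Two-layer projection with GELU, dropout, a residual connection, and LayerNorm."""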
    def __init__(
            self,
            embedding_dim,
            projection_dim,
            dropout=0
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

def convert_state_dict(state_dict):
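    # strip the "module." prefix that torch.nn.DataParallel adds to parameter names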
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k.replace("module.", "")
        new_state_dict[k] = v
    return new_state_dict
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.normal_(m.bias, std=1e-6)

class Metric(nn.Module):
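    """Base class for image metrics: loads images from a folder path, a list of
    paths / numpy arrays / PIL images, or a single file, and preprocesses them
    into a batched tensor via `self.image_preprocess`."""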
    def __init__(self):
        super().__init__()
        self.image_preprocess = None

    def load_image(self, image_path):
        with open(image_path, 'rb') as f:
            image = Image.open(f).convert("RGB")
        return image

    def load_image_path(self, image_path):
        if isinstance(image_path, str):
            # a string is treated as an image folder; sort for a deterministic order
            # (paired metrics match images between folders by position)
            images_file = sorted(os.listdir(image_path))
            image_path = [os.path.join(image_path, image) for image in images_file
                          if image.endswith((".jpg", ".png"))]
        if isinstance(image_path[0], str):
            images = [self.load_image(image) for image in image_path]
        elif isinstance(image_path[0], np.ndarray):
            images = [Image.fromarray(image) for image in image_path]
        elif isinstance(image_path[0], Image.Image):
            images = image_path
        else:
            raise Exception("Invalid input")
        return images

    def preprocess_image(self, image, **kwargs):
        # tensors are rejected up front: os.path.isfile() below would raise on them anyway
        if isinstance(image, torch.Tensor):
            raise Exception("Unsupported input")
        if (isinstance(image, str) and os.path.isdir(image)) or (
                isinstance(image, list) and (isinstance(image[0], (Image.Image, np.ndarray)) or os.path.isfile(image[0]))):
            input_data = self.load_image_path(image)
            input_data = [self.image_preprocess(img, **kwargs) for img in input_data]
            input_data = torch.stack(input_data)
        elif isinstance(image, str) and os.path.isfile(image):
            input_data = self.load_image(image)
            input_data = self.image_preprocess(input_data, **kwargs)
            input_data = input_data.unsqueeze(0)
        else:
            raise Exception("Invalid input")
        return input_data

class Clip_Basic_Metric(Metric):
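    """Shared CLIP preprocessing: `image_preprocess` for PIL images,
    `tensor_preprocess` for tensors already scaled to [-1, 1]."""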
    def __init__(self):
        super().__init__()
        self.tensor_preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            # rescale from [-1, 1] (e.g. diffusion-model output) to [0, 1]
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            # standard CLIP normalization
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.image_preprocess = transforms.Compose([
            transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

class Clip_metric(Clip_Basic_Metric):
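    """CLIP image-text similarity: style score against a target style prompt and
    content score against per-image captions."""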

    @torch.no_grad()
    def __init__(self, target_style_prompt: str = None, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
                 batch_size=8, alpha=0.5):
        super().__init__()
        self.device = device
        self.alpha = alpha
        self.model = (CLIPModel.from_pretrained(clip_model_name)).to(device)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.style_class_features = []
        self.model.eval()
        self.batch_size = batch_size
        if target_style_prompt is not None:
            self.ref_style_features = self.get_text_features(target_style_prompt)
        else:
            self.ref_style_features = None

        self.ref_image_style_prototype = None

    def get_text_features(self, text):
        prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
        prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
        prompt_features = F.normalize(prompt_features, p=2, dim=-1)
        return prompt_features

    def get_image_features(self, images):
        if isinstance(images, torch.Tensor):
            # tensors are assumed to be batched images scaled to [-1, 1]
            images = self.tensor_preprocess(images)
            image_inputs = BatchFeature(data={"pixel_values": images}, tensor_type="pt").to(self.device)
        else:
            images = self.load_image_path(images)
            image_inputs = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
                                                                                              non_blocking=True)
        image_features = self.model.get_image_features(**image_inputs)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features

    def img_text_similarity(self, image_features, text=None):
        if text is not None:
            prompt_feature = self.get_text_features(text)
            if isinstance(text, str):
                prompt_feature = prompt_feature.repeat(len(image_features), 1)
        else:
            prompt_feature = self.ref_style_features

        similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
        return similarity_each

    def forward(self, output_imgs, prompt=None):
        image_features = self.get_image_features(output_imgs)
        style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
        if prompt is not None:
            content_score = self.img_text_similarity(image_features, prompt)

            score = self.alpha * style_score + (1 - self.alpha) * content_score
            return {"score": score, "style_score": style_score, "content_score": content_score}
        else:
            return {"style_score": style_score}

    def content_score(self, output_imgs, prompt):
        self.to(self.device)
        image_features = self.get_image_features(output_imgs)
        content_score_details = self.img_text_similarity(image_features, prompt)
        self.to("cpu")
        return {"CLIP_content_score": content_score_details.mean().cpu(), "CLIP_content_score_details": content_score_details.cpu()}


class CSD_CLIP(Clip_Basic_Metric):
    """CSD style model: CLIP visual backbone with separate style and content projection heads,
    loaded from a CSD checkpoint."""
    def __init__(self, name='vit_large', content_proj_head='default', ckpt_path="data/weights/CSD-checkpoint.pth", device="cuda",
                 alpha=0.5, **kwargs):
        super(CSD_CLIP, self).__init__()
        self.alpha = alpha
        self.content_proj_head = content_proj_head
        self.device = device
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feat_dim = 512
        else:
            raise Exception('This model is not implemented')

        convert_weights_float(self.backbone)
        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        if content_proj_head == 'custom':
            # note: self.feat_dim is only set for 'vit_base', so the custom content head assumes that backbone
            self.last_layer_content = ProjectionHead(self.embedding_dim, self.feat_dim)
            self.last_layer_content.apply(init_weights)
        else:
            self.last_layer_content = copy.deepcopy(self.backbone.proj)

        self.backbone.proj = None
        self.backbone.requires_grad_(False)
        self.last_layer_style.requires_grad_(False)
        self.last_layer_content.requires_grad_(False)
        self.backbone.eval()

        if ckpt_path is not None:
            self.load_ckpt(ckpt_path)
        self.to("cpu")

    def load_ckpt(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = convert_state_dict(checkpoint['model_state_dict'])
        msg = self.load_state_dict(state_dict, strict=False)
        print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")

    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype

    def get_image_features(self, input_data, get_style=True, get_content=False, feature_alpha=None):
        if isinstance(input_data, torch.Tensor):
            input_data = self.tensor_preprocess(input_data)
        elif (isinstance(input_data, str) and os.path.isdir(input_data)) or (isinstance(input_data, list) and (isinstance(input_data[0], Image.Image) or isinstance(input_data[0], np.ndarray) or os.path.isfile(input_data[0]))):
            input_data = self.load_image_path(input_data)
            input_data = [self.image_preprocess(image) for image in input_data]
            input_data = torch.stack(input_data)
        elif os.path.isfile(input_data):
            input_data = self.load_image(input_data)
            input_data = self.image_preprocess(input_data)
            input_data = input_data.unsqueeze(0)
        input_data = input_data.to(self.device)
        style_output = None

        feature = self.backbone(input_data)
        if get_style:
            style_output = feature @ self.last_layer_style
            style_output = nn.functional.normalize(style_output, dim=-1, p=2)

        content_output = None
        if get_content:
            if feature_alpha is not None:
                # gradient reversal for domain-adversarial content features
                reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
            else:
                reverse_feature = feature
            if self.content_proj_head == 'custom':
                content_output = self.last_layer_content(reverse_feature)
            else:
                content_output = reverse_feature @ self.last_layer_content
            content_output = nn.functional.normalize(content_output, dim=-1, p=2)

        return feature, content_output, style_output


    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        self.to(self.device)
        _, _, self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")
    @torch.no_grad()
    def forward(self, styled_data):
        self.to(self.device)
        _, content_output, style_output = self.get_image_features(styled_data, get_content=False)
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()

        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()


        self.to("cpu")
        return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity, "CSD_similarity_mean_details": mean_style_similarities,
                "CSD_similarity_max_v_details": max_style_similarities_v, "CSD_similarity_max_id_details": max_style_similarities_id}

    def get_style_loss(self, styled_data):
        _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
        # ref_style_feature is (num_refs, dim): compare each styled image against every reference
        style_similarity = (style_output @ self.ref_style_feature.T).mean()
        return 1 - style_similarity
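
# Typical CSD usage (illustrative placeholder paths):
#   csd = CSD_CLIP(device="cuda")
#   csd.define_ref_image_style_prototype("path/to/style_reference_folder")
#   scores = csd("path/to/generated_images")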

class LPIPS_metric(Metric):
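    """LPIPS perceptual distance between paired image sets (lower means more similar content)."""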
    def __init__(self, type="vgg", device="cuda"):
        super(LPIPS_metric, self).__init__()
        self.lpips = lpips.LPIPS(net=type)
        self.device = device
        self.image_preprocess = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.to("cpu")

    @torch.no_grad()
    def forward(self, img1, img2):
        self.to(self.device)
        # resolve folder paths / path lists into PIL images before batching
        img1 = self.load_image_path(img1)
        img2 = self.load_image_path(img2)
        batch_size = 50
        differences = []
        for i in range(0, len(img1), batch_size):
            img1_batch = self.preprocess_image(img1[i:i + batch_size]).to(self.device)
            img2_batch = self.preprocess_image(img2[i:i + batch_size]).to(self.device)
            # flatten keeps a 1-d result even when a batch holds a single image
            differences.append(self.lpips(img1_batch, img2_batch).flatten())
        differences = torch.cat(differences)
        difference = differences.mean()
        self.to("cpu")
        return {"LPIPS_content_difference": difference, "LPIPS_content_difference_details": differences}

class Vit_metric(Metric):
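    """DINO ViT feature similarity: content similarity between paired images and
    style similarity against a reference prototype set."""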
    def __init__(self, device="cuda"):
        super(Vit_metric, self).__init__()
        self.device = device
        self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
        self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
        # tensor inputs in [-1, 1]: rescale to [0, 1], then apply DINO's ImageNet normalization
        self.tensor_preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.to("cpu")
    def get_image_features(self, images):
        if not isinstance(images, torch.Tensor):
            images = self.load_image_path(images)
        batch_size = 20
        all_image_features = []
        for i in range(0, len(images), batch_size):
            image_batch = images[i:i + batch_size]
            if isinstance(image_batch, torch.Tensor):
                # tensors are assumed to be batched images scaled to [-1, 1]
                image_batch = self.tensor_preprocess(image_batch)
                image_processed = BatchFeature(data={"pixel_values": image_batch}, tensor_type="pt").to(self.device)
            else:
                image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
            # flatten all patch-token embeddings into one feature vector per image
            image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
            image_features = F.normalize(image_features, p=2, dim=-1)
            all_image_features.append(image_features)
        return torch.cat(all_image_features)

    @torch.no_grad()
    def content_metric(self, img1, img2):
        self.to(self.device)
        if not (isinstance(img1, torch.Tensor) and img1.dim() == 2):
            img1 = self.get_image_features(img1)
        if not (isinstance(img2, torch.Tensor) and img2.dim() == 2):
            img2 = self.get_image_features(img2)
        similarities = torch.einsum("nc, nc -> n", img1, img2)
        similarity = similarities.mean()
        # self.to("cpu")
        return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}

    # style
    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        self.to(self.device)
        self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")
    @torch.no_grad()
    def style_metric(self, styled_data):
        self.to(self.device)
        if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
            style_output = styled_data
        else:
            style_output = self.get_image_features(styled_data)
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()

        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()

        # self.to("cpu")
        return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity, "Vit_style_similarity_mean_details": mean_style_similarities,
                "Vit_style_similarity_max_v_details": max_style_similarities_v, "Vit_style_similarity_max_id_details": max_style_similarities_id}
    @torch.no_grad()
    def forward(self, styled_data, original_data=None):
        self.to(self.device)
        styled_features = self.get_image_features(styled_data)
        ret = {}
        if original_data is not None:
            content_metric = self.content_metric(styled_features, original_data)
            ret["Vit_content"] = content_metric
        style_metric = self.style_metric(styled_features)
        ret["Vit_style"] = style_metric
        self.to("cpu")
        return ret



class StyleContentMetric(nn.Module):
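    """Aggregate style/content evaluation: CSD-CLIP and DINO ViT style similarity against a
    folder of reference style images, plus DINO ViT, LPIPS, and CLIP content scores."""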
    def __init__(self, style_ref_image_folder, device="cuda"):
        super(StyleContentMetric, self).__init__()
        self.device = device
        self.clip_style_metric = CSD_CLIP(device=device)
        self.ref_image_file = os.listdir(style_ref_image_folder)
        self.ref_image_file = [i for i in self.ref_image_file if i.endswith(".jpg") or i.endswith(".png")]
        self.ref_image_file.sort()
        self.ref_image_file = np.array(self.ref_image_file)
        ref_image_file_path = [os.path.join(style_ref_image_folder, i) for i in self.ref_image_file]

        self.clip_style_metric.define_ref_image_style_prototype(ref_image_file_path)
        self.vit_metric = Vit_metric(device=device)
        self.vit_metric.define_ref_image_style_prototype(ref_image_file_path)
        self.lpips_metric = LPIPS_metric(device=device)

        self.clip_content_metric = Clip_metric(alpha=0, target_style_prompt=None, device=device)

        self.to("cpu")

    def forward(self, styled_data, original_data=None, content_caption=None):
        ret = {}
        csd_score = self.clip_style_metric(styled_data)
        csd_score["max_query"] = self.ref_image_file[csd_score["CSD_similarity_max_id_details"].cpu()].tolist()
        torch.cuda.empty_cache()
        ret["Style_CSD"] = csd_score
        vit_score = self.vit_metric(styled_data, original_data)
        torch.cuda.empty_cache()
        vit_style = vit_score["Vit_style"]
        vit_style["max_query"] = self.ref_image_file[vit_style["Vit_style_similarity_max_id_details"].cpu()].tolist()
        ret["Style_VIT"] = vit_style

        if original_data is not None:
            vit_content = vit_score["Vit_content"]
            ret["Content_VIT"] = vit_content
            lpips_score = self.lpips_metric(styled_data, original_data)
            torch.cuda.empty_cache()
            ret["Content_LPIPS"] = lpips_score

        if content_caption is not None:
            clip_content = self.clip_content_metric.content_score(styled_data, content_caption)
            ret["Content_CLIP"] = clip_content
            torch.cuda.empty_cache()

        # round tensor outputs for readable reporting
        for type_key, type_value in ret.items():
            for key, value in type_value.items():
                if isinstance(value, torch.Tensor):
                    if value.numel() == 1:
                        ret[type_key][key] = round(value.item(), 4)
                    else:
                        ret[type_key][key] = [round(v, 4) for v in value.tolist()]

        self.to("cpu")
        ret["ref_image_file"] = self.ref_image_file.tolist()
        return ret


if __name__ == "__main__":
    with torch.no_grad():
        metric = StyleContentMetric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Art_styles/camille-pissarro/impressionism/split_5/paintings")
        score = metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500",
                       "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
        print(score)



        lpips_metric = LPIPS_metric()
        score = lpips_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
                             "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")

        print("lpips", score)


        clip_metric = CSD_CLIP()
        clip_metric.define_ref_image_style_prototype(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")

        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        print("subset3-subset3_sd14_converted", score)

        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        print("subset3-photo", score)



        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
        print("subset3-subset1", score)

        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/andy-warhol/pop_art/subset1/paintings")
        print("subset3-andy", score)
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings", "A painting")

        # print("subset3",score)
        # score_subset2 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset2/paintings")
        # print("subset2",score_subset2)
        # score_subset3 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
        # print("subset3",score_subset3)
        #
        # score_subset3_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        # print("subset3-subset3_sd14_converted" , score_subset3_converted)
        #
        # score_subset3_coco_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/coco_converted_photo/500")
        # print("subset3-subset3_coco_converted" , score_subset3_coco_converted)
        #
        # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500")
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        # print("photo500_1-sketch" ,score)
        #
        # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500_new")
        # print("photo500_1-photo500_2" ,score)
        # from custom_datasets.imagepair import ImageSet
        # import matplotlib.pyplot as plt
        # dataset = ImageSet(folder = "/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
        #                    caption_path="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/captions",
        #                     keep_in_mem=False)
        # for sample in dataset:
        #     score = clip_metric.content_score(sample["image"], sample["caption"][0])
        #     plt.imshow(sample["image"])
        #     plt.title(f"score: {round(score.item(),2)}\n prompt: {sample['caption'][0]}")
        #     plt.show()