# image_modification/metrics/clip_similarity.py
# Originally from InstructPix2Pix (commit 2afcb7e, "Add InstructPix2Pix", author: timbrooks).
from __future__ import annotations
import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class ClipSimilarity(nn.Module):
    """CLIP-based similarity scores for an image pair and its caption pair.

    ``forward`` embeds two image batches and two caption batches with a frozen CLIP
    model and returns four row-wise cosine similarities:

    * ``sim_0``         -- image_0 vs. text_0
    * ``sim_1``         -- image_1 vs. text_1
    * ``sim_direction`` -- (image_1 - image_0) vs. (text_1 - text_0), i.e. whether the
      change in image embedding points the same way as the change in text embedding
    * ``sim_image``     -- image_0 vs. image_1
    """

    # CLIP checkpoint names accepted by ``__init__``.
    _MODEL_NAMES = (
        "RN50", "RN101", "RN50x4", "RN50x16", "RN50x64",
        "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px",
    )
    # Square input resolution per checkpoint; any name not listed here uses 224.
    _INPUT_SIZES = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}

    def __init__(self, name: str = "ViT-L/14", *, download_root: str = "./"):
        """Load a frozen CLIP model on CPU.

        Args:
            name: CLIP checkpoint name; must be one of ``_MODEL_NAMES``.
            download_root: Directory passed to ``clip.load`` for downloading/caching the
                checkpoint. Defaults to the current working directory (previous behavior).

        Raises:
            ValueError: If ``name`` is not a supported checkpoint.
        """
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if name not in self._MODEL_NAMES:
            raise ValueError(
                f"Unsupported CLIP model {name!r}; expected one of: {', '.join(self._MODEL_NAMES)}"
            )
        self.size = self._INPUT_SIZES.get(name, 224)
        self.model, _ = clip.load(name, device="cpu", download_root=download_root)
        # CLIP weights are put in eval mode and frozen (requires_grad=False).
        self.model.eval().requires_grad_(False)
        # Per-channel normalization constants of OpenAI CLIP's preprocessing. Registered as
        # buffers so they move with the module on ``.to(device)``.
        self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073)))
        self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711)))

    def encode_text(self, text: list[str]) -> torch.Tensor:
        """Embed captions with CLIP's text encoder.

        Args:
            text: Batch of captions. Over-long captions are truncated by
                ``clip.tokenize(..., truncate=True)`` instead of raising.

        Returns:
            Text embeddings, L2-normalized along dim 1 (one row per caption).
        """
        # Tokens must live on the same device as the model's weights.
        text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device)
        text_features = self.model.encode_text(text)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return text_features

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Embed images with CLIP's image encoder.

        Args:
            image: Batch of shape (N, 3, H, W) with values in [0, 1]. It is resized to
                ``self.size`` x ``self.size`` with bicubic interpolation (aspect ratio is
                not preserved, no center crop), then normalized with CLIP's mean/std.

        Returns:
            Image embeddings, L2-normalized along dim 1 (one row per image).
        """
        image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False)
        image = image - rearrange(self.mean, "c -> 1 c 1 1")
        image = image / rearrange(self.std, "c -> 1 c 1 1")
        image_features = self.model.encode_image(image)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        return image_features

    def forward(
        self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute CLIP similarities for an image pair and its caption pair.

        Args:
            image_0: Image batch in [0, 1]; see ``encode_image``.
            image_1: Image batch in [0, 1], compared against ``image_0``.
            text_0: Captions paired with ``image_0``.
            text_1: Captions paired with ``image_1``.

        Returns:
            ``(sim_0, sim_1, sim_direction, sim_image)``: row-wise cosine similarities of
            image_0/text_0, image_1/text_1, the image-embedding delta vs. the
            text-embedding delta, and image_0/image_1.
        """
        image_features_0 = self.encode_image(image_0)
        image_features_1 = self.encode_image(image_1)
        text_features_0 = self.encode_text(text_0)
        text_features_1 = self.encode_text(text_1)
        sim_0 = F.cosine_similarity(image_features_0, text_features_0)
        sim_1 = F.cosine_similarity(image_features_1, text_features_1)
        # A zero delta (identical embeddings) yields 0 rather than NaN, because
        # F.cosine_similarity guards the norm product with a small eps.
        sim_direction = F.cosine_similarity(
            image_features_1 - image_features_0, text_features_1 - text_features_0
        )
        sim_image = F.cosine_similarity(image_features_0, image_features_1)
        return sim_0, sim_1, sim_direction, sim_image