from __future__ import annotations import clip import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange class ClipSimilarity(nn.Module): def __init__(self, name: str = "ViT-L/14"): super().__init__() assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224) self.model, _ = clip.load(name, device="cpu", download_root="./") self.model.eval().requires_grad_(False) self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073))) self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711))) def encode_text(self, text: list[str]) -> torch.Tensor: text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device) text_features = self.model.encode_text(text) text_features = text_features / text_features.norm(dim=1, keepdim=True) return text_features def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1]. image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False) image = image - rearrange(self.mean, "c -> 1 c 1 1") image = image / rearrange(self.std, "c -> 1 c 1 1") image_features = self.model.encode_image(image) image_features = image_features / image_features.norm(dim=1, keepdim=True) return image_features def forward( self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: image_features_0 = self.encode_image(image_0) image_features_1 = self.encode_image(image_1) text_features_0 = self.encode_text(text_0) text_features_1 = self.encode_text(text_1) sim_0 = F.cosine_similarity(image_features_0, text_features_0) sim_1 = F.cosine_similarity(image_features_1, text_features_1) sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0) sim_image = F.cosine_similarity(image_features_0, image_features_1) return sim_0, sim_1, sim_direction, sim_image