# Based on a file from https://github.com/rinongal/StyleGAN-nada.
# ==========================================================================================
#
# Adobe’s modifications are Copyright 2023 Adobe Research. All rights reserved.
# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
# LICENSE.md.
#
# ==========================================================================================

import clip
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from expansion_utils.text_templates import imagenet_templates, part_templates


# TODO: get rid of unused stuff in this class
class CLIPLoss(torch.nn.Module):
    def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0., lambda_manifold=0.,
                 lambda_texture=0., patch_loss_type='mae', direction_loss_type='cosine', clip_model='ViT-B/32'):
        super(CLIPLoss, self).__init__()

        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)

        self.clip_preprocess = clip_preprocess

        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
            clip_preprocess.transforms[:2] +                                        # to match CLIP input scale assumptions
            clip_preprocess.transforms[4:])                                         # + skip convert PIL to tensor

        self.target_directions_cache = {}
        self.patch_text_directions = None

        self.patch_loss = DirectionLoss(patch_loss_type)
        self.direction_loss = DirectionLoss(direction_loss_type)
        self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)

        self.lambda_global = lambda_global
        self.lambda_patch = lambda_patch
        self.lambda_direction = lambda_direction
        self.lambda_manifold = lambda_manifold
        self.lambda_texture = lambda_texture

        self.src_text_features = None
        self.target_text_features = None
        self.angle_loss = torch.nn.L1Loss()

        self.model_cnn, preprocess_cnn = clip.load("RN50", device=self.device)
        self.preprocess_cnn = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
            preprocess_cnn.transforms[:2] +                                         # to match CLIP input scale assumptions
            preprocess_cnn.transforms[4:])                                          # + skip convert PIL to tensor

        self.model.requires_grad_(False)
        self.model_cnn.requires_grad_(False)

        self.texture_loss = torch.nn.MSELoss()

    def tokenize(self, strings: list):
        return clip.tokenize(strings).to(self.device)

    def encode_text(self, tokens: list) -> torch.Tensor:
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def encode_images_with_cnn(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess_cnn(images).to(self.device)
        return self.model_cnn.encode_image(images)

    def distance_with_templates(self, img: torch.Tensor, class_str: str, templates=imagenet_templates) -> torch.Tensor:
        text_features = self.get_text_features(class_str, templates)
        image_features = self.get_image_features(img)

        similarity = image_features @ text_features.T

        return 1. - similarity
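
    # A minimal usage sketch for the helpers above (not in the original file).
    # `fake_imgs` is a hypothetical generator batch; the only real assumption is
    # the one the preprocessing in __init__ already makes, namely NCHW images
    # in [-1, 1]:
    #
    #   clip_loss = CLIPLoss(device="cuda")
    #   fake_imgs = torch.rand(4, 3, 256, 256, device="cuda") * 2 - 1
    #   dists = clip_loss.distance_with_templates(fake_imgs, "dog")
    #   # dists: (4, len(imagenet_templates)); smaller means closer to the
    #   # templated prompts ("a photo of a dog", etc.).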

    def get_text_features(self, class_str: str, templates=imagenet_templates, norm: bool = True) -> torch.Tensor:
        template_text = self.compose_text_with_templates(class_str, templates)

        tokens = clip.tokenize(template_text).to(self.device)

        text_features = self.encode_text(tokens).detach()

        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor:
        with torch.no_grad():
            source_features = self.get_text_features(source_class)
            target_features = self.get_text_features(target_class)

            text_direction = (target_features - source_features).mean(axis=0, keepdim=True)
            text_direction /= text_direction.norm(dim=-1, keepdim=True)

        return text_direction

    def compute_img2img_direction(self, source_images: torch.Tensor, target_images: list) -> torch.Tensor:
        with torch.no_grad():
            src_encoding = self.get_image_features(source_images)
            src_encoding = src_encoding.mean(dim=0, keepdim=True)

            target_encodings = []
            for target_img in target_images:
                preprocessed = self.clip_preprocess(Image.open(target_img)).unsqueeze(0).to(self.device)

                encoding = self.model.encode_image(preprocessed)
                encoding /= encoding.norm(dim=-1, keepdim=True)

                target_encodings.append(encoding)

            target_encoding = torch.cat(target_encodings, axis=0)
            target_encoding = target_encoding.mean(dim=0, keepdim=True)

            direction = target_encoding - src_encoding
            direction /= direction.norm(dim=-1, keepdim=True)

        return direction

    def set_text_features(self, source_class: str, target_class: str) -> None:
        source_features = self.get_text_features(source_class).mean(axis=0, keepdim=True)
        self.src_text_features = source_features / source_features.norm(dim=-1, keepdim=True)

        target_features = self.get_text_features(target_class).mean(axis=0, keepdim=True)
        self.target_text_features = target_features / target_features.norm(dim=-1, keepdim=True)

    def clip_angle_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        if self.src_text_features is None:
            self.set_text_features(source_class, target_class)

        cos_text_angle = self.target_text_features @ self.src_text_features.T
        text_angle = torch.acos(cos_text_angle)

        src_img_features = self.get_image_features(src_img).unsqueeze(2)
        target_img_features = self.get_image_features(target_img).unsqueeze(1)

        cos_img_angle = torch.clamp(target_img_features @ src_img_features, min=-1.0, max=1.0)
        img_angle = torch.acos(cos_img_angle)

        text_angle = text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
        cos_text_angle = cos_text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)

        return self.angle_loss(cos_img_angle, cos_text_angle)

    def compose_text_with_templates(self, text: str, templates=imagenet_templates) -> list:
        return [template.format(text) for template in templates]
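
    # The method below implements the directional loss from StyleGAN-NADA:
    # with E_I the CLIP image encoder and E_T the (template-averaged) text
    # encoder, it penalizes
    #
    #   1 - cos(E_I(x_target) - E_I(x_source), E_T(t_target) - E_T(t_source)),
    #
    # i.e. the CLIP-space movement of each image pair should be parallel to the
    # movement between its source/target prompts. Text directions are cached
    # per (source, target) prompt pair so each pair is encoded only once.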

    def clip_directional_loss(self, src_img: torch.Tensor, source_classes: np.ndarray, target_img: torch.Tensor, target_classes: np.ndarray) -> torch.Tensor:
        target_directions = []
        for key in zip(source_classes, target_classes):
            if key not in self.target_directions_cache:
                self.target_directions_cache[key] = self.compute_text_direction(*key)
            target_directions.append(self.target_directions_cache[key])
        target_directions = torch.cat(target_directions)

        src_encoding = self.get_image_features(src_img)
        target_encoding = self.get_image_features(target_img)

        edit_direction = (target_encoding - src_encoding)
        if edit_direction.sum() == 0:
            target_encoding = self.get_image_features(target_img + 1e-6)
            edit_direction = (target_encoding - src_encoding)

        edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True))

        return self.direction_loss(edit_direction, target_directions).sum()

    def global_clip_loss(self, img: torch.Tensor, text) -> torch.Tensor:
        if not isinstance(text, list):
            text = [text]

        tokens = clip.tokenize(text).to(self.device)
        image = self.preprocess(img)

        logits_per_image, _ = self.model(image, tokens)

        return (1. - logits_per_image / 100).mean()

    def random_patch_centers(self, img_shape, num_patches, size):
        batch_size, channels, height, width = img_shape

        half_size = size // 2
        patch_centers = np.concatenate(
            [np.random.randint(half_size, width - half_size, size=(batch_size * num_patches, 1)),
             np.random.randint(half_size, height - half_size, size=(batch_size * num_patches, 1))], axis=1)

        return patch_centers

    def generate_patches(self, img: torch.Tensor, patch_centers, size):
        batch_size = img.shape[0]
        num_patches = len(patch_centers) // batch_size
        half_size = size // 2

        patches = []
        for batch_idx in range(batch_size):
            for patch_idx in range(num_patches):
                center_x = patch_centers[batch_idx * num_patches + patch_idx][0]
                center_y = patch_centers[batch_idx * num_patches + patch_idx][1]

                patch = img[batch_idx:batch_idx + 1, :,
                            center_y - half_size:center_y + half_size,
                            center_x - half_size:center_x + half_size]

                patches.append(patch)

        patches = torch.cat(patches, axis=0)

        return patches

    def patch_scores(self, img: torch.Tensor, class_str: str, patch_centers, patch_size: int) -> torch.Tensor:
        parts = self.compose_text_with_templates(class_str, part_templates)
        tokens = clip.tokenize(parts).to(self.device)
        text_features = self.encode_text(tokens).detach()

        patches = self.generate_patches(img, patch_centers, patch_size)
        image_features = self.get_image_features(patches)

        similarity = image_features @ text_features.T

        return similarity

    def clip_patch_similarity(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        patch_size = 196  # TODO remove magic number
        patch_centers = self.random_patch_centers(src_img.shape, 4, patch_size)  # TODO remove magic number

        src_scores = self.patch_scores(src_img, source_class, patch_centers, patch_size)
        target_scores = self.patch_scores(target_img, target_class, patch_centers, patch_size)

        return self.patch_loss(src_scores, target_scores)

    def patch_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        if self.patch_text_directions is None:
            src_part_classes = self.compose_text_with_templates(source_class, part_templates)
            target_part_classes = self.compose_text_with_templates(target_class, part_templates)

            parts_classes = list(zip(src_part_classes, target_part_classes))

            self.patch_text_directions = torch.cat(
                [self.compute_text_direction(pair[0], pair[1]) for pair in parts_classes], dim=0)

        patch_size = 510  # TODO remove magic numbers

        patch_centers = self.random_patch_centers(src_img.shape, 1, patch_size)

        patches = self.generate_patches(src_img, patch_centers, patch_size)
        src_features = self.get_image_features(patches)

        patches = self.generate_patches(target_img, patch_centers, patch_size)
        target_features = self.get_image_features(patches)

        edit_direction = (target_features - src_features)
        edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)

        cosine_dists = 1. - self.patch_direction_loss(edit_direction.unsqueeze(1), self.patch_text_directions.unsqueeze(0))

        patch_class_scores = cosine_dists * (edit_direction @ self.patch_text_directions.T).softmax(dim=-1)

        return patch_class_scores.mean()
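
    # Shape sketch for the patch utilities above (illustrative, not in the
    # original file): for a batch of N images and P patches per image,
    #
    #   centers = self.random_patch_centers(img.shape, P, size)  # (N * P, 2) of (x, y)
    #   patches = self.generate_patches(img, centers, size)      # (N * P, C, 2*(size//2), 2*(size//2))
    #   scores  = self.patch_scores(img, "dog", centers, size)   # (N * P, len(part_templates))
    #
    # Each slice spans 2 * (size // 2) pixels, so an odd `size` yields
    # (size - 1)-pixel patches.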

    def cnn_feature_loss(self, src_img: torch.Tensor, target_img: torch.Tensor) -> torch.Tensor:
        src_features = self.encode_images_with_cnn(src_img)
        target_features = self.encode_images_with_cnn(target_img)

        return self.texture_loss(src_features, target_features)

    def forward(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str, texture_image: torch.Tensor = None):
        clip_loss = 0.0

        if self.lambda_global:
            clip_loss += self.lambda_global * self.global_clip_loss(target_img, [f"a {target_class}"])

        if self.lambda_patch:
            # Same directional loss, but computed on random patches.
            clip_loss += self.lambda_patch * self.patch_directional_loss(src_img, source_class, target_img, target_class)

        if self.lambda_direction:
            # The directional loss used in the paper.
            clip_loss += self.lambda_direction * self.clip_directional_loss(src_img, source_class, target_img, target_class)

        if self.lambda_manifold:
            # Compute angles of text and image directions and take an L1 loss between them.
            clip_loss += self.lambda_manifold * self.clip_angle_loss(src_img, source_class, target_img, target_class)

        if self.lambda_texture and (texture_image is not None):
            # L2 on features extracted by the CNN (RN50) CLIP encoder.
            clip_loss += self.lambda_texture * self.cnn_feature_loss(texture_image, target_img)

        return clip_loss


class DirectionLoss(torch.nn.Module):
    def __init__(self, loss_type='mse'):
        super(DirectionLoss, self).__init__()

        self.loss_type = loss_type

        self.loss_func = {
            'mse': torch.nn.MSELoss,
            'cosine': torch.nn.CosineSimilarity,
            'mae': torch.nn.L1Loss
        }[loss_type]()

    def forward(self, x, y):
        if self.loss_type == "cosine":
            return 1. - self.loss_func(x, y)

        return self.loss_func(x, y)
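

# A minimal, hedged usage sketch, not part of the original file. It assumes the
# `clip` package is installed (weights for ViT-B/32 and RN50 download on first
# run); the random tensors are stand-ins for generator outputs in [-1, 1].
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_loss = CLIPLoss(device, lambda_direction=1.0)

    # Stand-in "frozen" and "trainable" generator batches.
    src = torch.rand(2, 3, 256, 256, device=device) * 2 - 1
    tgt = torch.rand(2, 3, 256, 256, device=device) * 2 - 1

    # One (source, target) prompt pair per batch element, as
    # clip_directional_loss expects.
    source_classes = np.array(["photo", "photo"])
    target_classes = np.array(["sketch", "sketch"])

    loss = clip_loss(src, source_classes, tgt, target_classes)
    print(f"directional CLIP loss: {loss.item():.4f}")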