''' Code adapted from Stitch it in Time by Tzaban et al. https://github.com/rotemtzaban/STIT '''
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
import os

import clip

# Prompt templates used for CLIP prompt engineering: each class name is embedded
# with every template and the normalized embeddings are averaged.
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

# Channel ranges of the convolution modulation layers inside the flattened
# 9088-dimensional StyleSpace code; the StyleCLIP direction covers only these channels.
CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584),
                     (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680),
                     (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896),
                     (8960, 9024), (9024, 9056)]

# Channel ranges of every FFHQ modulation layer (convolution + toRGB), listed in the
# same order as reference_generator.modulation_layers.
FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584),
                     (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936),
                     (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
                    [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704),
                     (8896, 8960), (9056, 9088)]


def zeroshot_classifier(model, classnames, templates, device):
    """Embed each class name with every template and average the normalized CLIP text
    embeddings into a single unit-norm vector per class (columns of the result)."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


def expand_to_full_dim(partial_tensor):
    """Scatter a direction defined over the convolution channels into the full
    9088-dimensional StyleSpace code; channels not covered (toRGB) stay zero."""
    full_dim_tensor = torch.zeros(size=(1, 9088))

    start_idx = 0
    for conv_start, conv_end in CONV_CODE_INDICES:
        length = conv_end - conv_start
        full_dim_tensor[:, conv_start:conv_end] = partial_tensor[start_idx:start_idx + length]
        start_idx += length

    return full_dim_tensor


def get_direction(neutral_class, target_class, beta, di, clip_model=None):
    """Compute a StyleCLIP global edit direction from neutral_class to target_class.

    `di` holds the per-channel relevance vectors in CLIP space; channels whose
    projection onto the text direction is at most `beta` in magnitude are zeroed out.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if clip_model is None:
        clip_model, _ = clip.load("ViT-B/32", device=device)

    class_names = [neutral_class, target_class]
    class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)

    # Normalized text direction in CLIP space.
    dt = class_weights[:, 1] - class_weights[:, 0]
    dt = dt / dt.norm()

    dt = dt.float()
    di = di.float()

    # Per-channel relevance; keep only channels above the beta threshold and
    # rescale so the largest surviving entry has magnitude 1.
    relevance = di @ dt
    mask = relevance.abs() > beta
    direction = relevance * mask
    direction_max = direction.abs().max()
    if direction_max > 0:
        direction = direction / direction_max
    else:
        raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
                         f' try setting it to a lower value')

    return direction


def style_tensor_to_style_dict(style_tensor, reference_generator):
    """Split a flat (1, 9088) StyleSpace tensor into a per-layer dictionary keyed by
    the generator's modulation layers."""
    style_layers = reference_generator.modulation_layers

    style_dict = {}
    for layer_idx, layer in enumerate(style_layers):
        style_dict[layer] = style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]]
    return style_dict


def style_dict_to_style_tensor(style_dict, reference_generator):
    """Concatenate a per-layer style dictionary back into a flat (1, 9088) tensor."""
    style_layers = reference_generator.modulation_layers
    style_tensor = torch.zeros(size=(1, 9088))

    for layer in style_dict:
        layer_idx = style_layers.index(layer)
        style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]

    return style_tensor


def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator,
                                di, clip_model=None):
    """Apply a StyleCLIP global-direction edit of strength `alpha` to a source style dictionary
    and return the edited flat StyleSpace tensor."""
    edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
    edit_full_dim = expand_to_full_dim(edit_direction)
    source_s = style_dict_to_style_tensor(source_latent, reference_generator)
    return source_s + alpha * edit_full_dim
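

# Minimal self-check sketch (not part of the original STIT code): it exercises the
# index bookkeeping above without needing CLIP, a StyleGAN checkpoint, or the StyleCLIP
# relevance matrix. The stand-in generator below is a hypothetical placeholder whose
# modulation_layers are just names, one per FFHQ_CODE_INDICES entry.
if __name__ == '__main__':
    from types import SimpleNamespace

    # The global direction covers only the convolution channels (6048 of the 9088).
    conv_dim = sum(end - start for start, end in CONV_CODE_INDICES)
    print(f'convolution channels: {conv_dim}')  # expected: 6048

    # Scatter a random partial direction into the full StyleSpace code.
    fake_direction = torch.randn(conv_dim)
    full = expand_to_full_dim(fake_direction)
    print(f'full StyleSpace code shape: {tuple(full.shape)}')  # expected: (1, 9088)

    # Demonstrate the dict <-> tensor round trip with placeholder layer names.
    fake_generator = SimpleNamespace(
        modulation_layers=[f'layer_{i}' for i in range(len(FFHQ_CODE_INDICES))])
    style_dict = style_tensor_to_style_dict(full, fake_generator)
    restored = style_dict_to_style_tensor(style_dict, fake_generator)
    print('round trip matches:', torch.equal(full, restored))  # expected: True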