import copy

import torch
from diffusers import StableDiffusionPipeline

import time_utils
from train_funcs import TRAIN_FUNC_DICT

## get arguments for our script
with_to_k = True
with_augs = True
train_func = "train_closed_form"

### load model
LOW_RESOURCE = True
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
tokenizer = ldm_stable.tokenizer

### get layers
ca_layers = []

def append_ca(net_):
    """Recursively collect all cross-attention modules of the UNet."""
    if net_.__class__.__name__ == 'CrossAttention':
        ca_layers.append(net_)
    elif hasattr(net_, 'children'):
        for net__ in net_.children():
            append_ca(net__)

for name, net in ldm_stable.unet.named_children():
    if "down" in name or "up" in name or "mid" in name:
        append_ca(net)

### get projection matrices
# keep only the cross-attention layers that attend to the CLIP text encoding (dim 768)
ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768]
projection_matrices = [l.to_v for l in ca_clip_layers]
og_matrices = [copy.deepcopy(l.to_v) for l in ca_clip_layers]
if with_to_k:
    projection_matrices = projection_matrices + [l.to_k for l in ca_clip_layers]
    og_matrices = og_matrices + [copy.deepcopy(l.to_k) for l in ca_clip_layers]

def edit_model(old_text_, new_text_, lamb=0.1):
    #### restore the original LDM parameters before applying a new edit
    num_ca_clip_layers = len(ca_clip_layers)
    for idx_, l in enumerate(ca_clip_layers):
        l.to_v = copy.deepcopy(og_matrices[idx_])
        projection_matrices[idx_] = l.to_v
        if with_to_k:
            l.to_k = copy.deepcopy(og_matrices[num_ca_clip_layers + idx_])
            projection_matrices[num_ca_clip_layers + idx_] = l.to_k

    try:
        #### set up sentences
        old_texts = [old_text_]
        new_texts = [new_text_]
        if with_augs:
            # lowercase a leading "A"/"An" so the prompt can be embedded in templates
            base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:]
            old_texts.append("A photo of " + base)
            old_texts.append("An image of " + base)
            old_texts.append("A picture of " + base)
            base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:]
            new_texts.append("A photo of " + base)
            new_texts.append("An image of " + base)
            new_texts.append("A picture of " + base)

        #### prepare input k* and v*
        old_embs, new_embs = [], []
        for old_text, new_text in zip(old_texts, new_texts):
            text_input = ldm_stable.tokenizer(
                [old_text, new_text],
                padding="max_length",
                max_length=ldm_stable.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = ldm_stable.text_encoder(text_input.input_ids.to(ldm_stable.device))[0]
            old_emb, new_emb = text_embeddings
            old_embs.append(old_emb)
            new_embs.append(new_emb)

        #### identify corresponding destinations for each token in old_emb
        idxs_replaces = []
        for old_text, new_text in zip(old_texts, new_texts):
            tokens_a = tokenizer(old_text).input_ids
            tokens_b = tokenizer(new_text).input_ids
            # treat "an" as "a" so the two token sequences align
            tokens_a = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_a]
            tokens_b = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_b]
            num_orig_tokens = len(tokens_a)
            idxs_replace = []
            j = 0
            for i in range(num_orig_tokens):
                curr_token = tokens_a[i]
                while tokens_b[j] != curr_token:
                    j += 1
                idxs_replace.append(j)
                j += 1
            while j < MAX_NUM_WORDS:
                idxs_replace.append(j)
                j += 1
            while len(idxs_replace) < MAX_NUM_WORDS:
                idxs_replace.append(MAX_NUM_WORDS - 1)
            idxs_replaces.append(idxs_replace)

        #### prepare batch: for each pair of sentences, old context and new values
        contexts, valuess = [], []
        for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
            context = old_emb.detach()
            values = []
            with torch.no_grad():
                for layer in projection_matrices:
                    values.append(layer(new_emb[idxs_replace]).detach())
            contexts.append(context)
            valuess.append(values)

        #### define training function
        train = TRAIN_FUNC_DICT[train_func]

        #### train the model
        train(ldm_stable, projection_matrices, og_matrices, contexts, valuess,
              old_texts, new_texts, lamb=lamb)

        return f"Current model status: Edited \"{old_text_}\" into \"{new_text_}\""
    except Exception:
        return "Current model status: An error occurred"

def generate_for_text(test_text):
    g = torch.Generator(device='cpu')
    g.seed()
    images = time_utils.text2image_ldm_stable(
        ldm_stable, [test_text], latent=None,
        num_inference_steps=NUM_DIFFUSION_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=g, low_resource=LOW_RESOURCE,
    )
    return time_utils.view_images(images)
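### for reference: a minimal sketch of what a function with the calling
### convention of TRAIN_FUNC_DICT["train_closed_form"] could look like. The real
### implementation lives in train_funcs (not shown here); everything below is an
### illustrative assumption, not the project's code. It solves, per projection W,
###     min_W  sum_i ||W k_i - v_i||^2 + lamb * ||W - W_og||_F^2,
### whose closed form is  W = (lamb * W_og + sum_i v_i k_i^T)(lamb * I + sum_i k_i k_i^T)^{-1}.
def train_closed_form_sketch(ldm_stable, projection_matrices, og_matrices,
                             contexts, valuess, old_texts, new_texts, lamb=0.1):
    with torch.no_grad():
        for layer_num, layer in enumerate(projection_matrices):
            # accumulate  lamb * W_og + sum v k^T  and  lamb * I + sum k k^T
            mat1 = lamb * og_matrices[layer_num].weight
            mat2 = lamb * torch.eye(layer.weight.shape[1], device=layer.weight.device)
            for context, values in zip(contexts, valuess):
                k = context            # (77, 768) old-prompt keys
                v = values[layer_num]  # (77, out_dim) target values from the new prompt
                mat1 += v.T @ k
                mat2 += k.T @ k
            layer.weight.copy_(mat1 @ torch.inverse(mat2))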
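### example usage (an illustrative sketch; the prompts below are placeholders,
### and a CUDA GPU is assumed for reasonable runtimes)
if __name__ == '__main__':
    print(edit_model("A rose", "A blue rose", lamb=0.1))
    # generations after the edit should reflect the new assumption
    generate_for_text("A rose in a vase")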