import torch
from diffusers import StableDiffusionPipeline
import numpy as np
import abc
import time_utils
import copy
import os
from train_funcs import TRAIN_FUNC_DICT
## script configuration
# with_to_k: also collect the key projections (to_k) next to the value ones (to_v).
# with_augs: extend each prompt with templated "A photo of ..." style variants.
# train_func: key used to pick the training routine from TRAIN_FUNC_DICT.
with_to_k = True
with_augs = True
train_func = "train_closed_form"

### sampling settings and model loading
# The first three values are forwarded to time_utils.text2image_ldm_stable.
LOW_RESOURCE = True
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
ldm_stable = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
).to(device)
tokenizer = ldm_stable.tokenizer
### get layers
# Filled in by append_ca, in depth-first (pre-order) traversal order.
ca_layers = []


def append_ca(net_):
    """Collect every ``CrossAttention`` module found under *net_* into ``ca_layers``.

    A module whose class is named ``CrossAttention`` is recorded and not
    descended into. Any other object exposing ``children()`` is searched
    recursively; objects without ``children`` are ignored.
    """
    if net_.__class__.__name__ == 'CrossAttention':
        ca_layers.append(net_)
        return
    if not hasattr(net_, 'children'):
        return
    for child in net_.children():
        append_ca(child)
# Visit the UNet's top-level children and collect attention layers from every
# child whose name contains "down", "up" or "mid".
sub_nets = ldm_stable.unet.named_children()
for net in sub_nets:
    # net is a (name, module) pair as yielded by named_children().
    if any(part in net[0] for part in ("down", "up", "mid")):
        append_ca(net[1])
### get projection matrices
# Keep only the attention layers whose value projection takes 768-dim input
# (presumably the text-conditioned ones, given the name -- confirm).
ca_clip_layers = [layer for layer in ca_layers if layer.to_v.in_features == 768]

# Live projection modules, plus deep-copied snapshots of them.
# Layout: all to_v entries first, then (optionally) all to_k entries.
projection_matrices = [layer.to_v for layer in ca_clip_layers]
og_matrices = [copy.deepcopy(proj) for proj in projection_matrices]
if with_to_k:
    projection_matrices.extend(layer.to_k for layer in ca_clip_layers)
    og_matrices.extend(copy.deepcopy(layer.to_k) for layer in ca_clip_layers)
def _with_prompt_augmentations(text):
    """Return ``[text]`` followed by its "photo/image/picture of" variants."""
    # Lower-case a leading "A" so the prompt reads naturally after the prefix.
    base = text if text[0:1] != "A" else "a" + text[1:]
    return [
        text,
        "A photo of " + base,
        "An image of " + base,
        "A picture of " + base,
    ]


def _normalize_articles(token_ids, a_token_id):
    """Replace every token that decodes to 'an' with the token id for 'a'."""
    return [a_token_id if tokenizer.decode(t) == 'an' else t for t in token_ids]


def _align_token_indices(tokens_a, tokens_b, seq_len):
    """Map each position of ``tokens_a`` to its matching position in ``tokens_b``.

    Tokens are matched left to right, each search starting just after the
    previous match. After the last match the remaining positions of
    ``tokens_b`` (up to ``seq_len``) are appended in order, and the list is
    padded with ``seq_len - 1`` until it holds ``seq_len`` entries.

    Raises:
        ValueError: if a token of ``tokens_a`` has no remaining match in
            ``tokens_b``.
    """
    idxs_replace = []
    j = 0
    for position, curr_token in enumerate(tokens_a):
        try:
            j = tokens_b.index(curr_token, j)
        except ValueError:
            raise ValueError(
                f"token {curr_token!r} at position {position} of the old prompt "
                "has no in-order match in the new prompt"
            ) from None
        idxs_replace.append(j)
        j += 1
    idxs_replace.extend(range(j, seq_len))
    idxs_replace.extend([seq_len - 1] * (seq_len - len(idxs_replace)))
    return idxs_replace


def edit_model(old_text_, new_text_, lamb=0.1):
    """Reset the tracked projections, then edit them from ``old_text_`` to ``new_text_``.

    The projection layers listed in ``ca_clip_layers`` are first restored from
    the snapshots in ``og_matrices``. The two prompts (optionally with
    templated augmentations when ``with_augs`` is set) are then encoded, the
    tokens of each old prompt are aligned to the new prompt, and the routine
    selected by ``train_func`` from ``TRAIN_FUNC_DICT`` is called.

    Args:
        old_text_: source prompt.
        new_text_: destination prompt; it must contain the tokens of
            ``old_text_`` in the same order (after 'an' -> 'a' normalization).
        lamb: forwarded unchanged to the training routine as ``lamb``.

    Returns:
        A human-readable status string. Failures are logged with their
        traceback and reported through the returned string, not raised.
    """
    import logging

    #### restart LDM parameters
    num_ca_clip_layers = len(ca_clip_layers)
    for idx_, l in enumerate(ca_clip_layers):
        l.to_v = copy.deepcopy(og_matrices[idx_])
        projection_matrices[idx_] = l.to_v
        if with_to_k:
            # to_k entries are stored after all to_v entries.
            l.to_k = copy.deepcopy(og_matrices[num_ca_clip_layers + idx_])
            projection_matrices[num_ca_clip_layers + idx_] = l.to_k

    try:
        #### set up sentences
        if with_augs:
            old_texts = _with_prompt_augmentations(old_text_)
            new_texts = _with_prompt_augmentations(new_text_)
        else:
            old_texts = [old_text_]
            new_texts = [new_text_]

        # Embeddings are padded to this length, so the index lists built below
        # must have exactly this many entries.
        seq_len = ldm_stable.tokenizer.model_max_length

        #### prepare input k* and v*
        old_embs, new_embs = [], []
        # Every embedding is detached before use, so no autograd graph is needed.
        with torch.no_grad():
            for old_text, new_text in zip(old_texts, new_texts):
                text_input = ldm_stable.tokenizer(
                    [old_text, new_text],
                    padding="max_length",
                    max_length=seq_len,
                    truncation=True,
                    return_tensors="pt",
                )
                text_embeddings = ldm_stable.text_encoder(
                    text_input.input_ids.to(ldm_stable.device)
                )[0]
                old_emb, new_emb = text_embeddings
                old_embs.append(old_emb)
                new_embs.append(new_emb)

        #### identify corresponding destinations for each token in old_emb
        # Index 1 skips the leading special token that encode() puts before "a".
        a_token_id = tokenizer.encode("a ")[1]
        idxs_replaces = []
        for old_text, new_text in zip(old_texts, new_texts):
            tokens_a = _normalize_articles(tokenizer(old_text).input_ids, a_token_id)
            tokens_b = _normalize_articles(tokenizer(new_text).input_ids, a_token_id)
            idxs_replaces.append(_align_token_indices(tokens_a, tokens_b, seq_len))

        #### prepare batch: for each pair of sentences, old context and new values
        contexts, valuess = [], []
        with torch.no_grad():
            for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
                contexts.append(old_emb.detach())
                valuess.append(
                    [layer(new_emb[idxs_replace]).detach() for layer in projection_matrices]
                )

        #### define training function
        train = TRAIN_FUNC_DICT[train_func]

        #### train the model
        train(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, lamb=lamb)
        return f"Current model status: Edited \"{old_text_}\" into \"{new_text_}\""
    except Exception:
        # Best-effort contract: report failure through the status string, but
        # keep the traceback and let KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).exception(
            "edit_model failed for %r -> %r", old_text_, new_text_
        )
        return "Current model status: An error occured"
def generate_for_text(test_text):
    """Generate images for *test_text* with the current ``ldm_stable`` pipeline.

    Returns whatever ``time_utils.view_images`` produces for the images
    returned by ``time_utils.text2image_ldm_stable``.
    """
    # A fresh CPU generator per call; seed() is called without an explicit
    # value, so this code does not fix the seed.
    rng = torch.Generator(device='cpu')
    rng.seed()
    sampled = time_utils.text2image_ldm_stable(
        ldm_stable,
        [test_text],
        latent=None,
        num_inference_steps=NUM_DIFFUSION_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=rng,
        low_resource=LOW_RESOURCE,
    )
    return time_utils.view_images(sampled)