Bahjat Kawar commited on
Commit
3f7ead4
1 Parent(s): aae0ff3

first commit

Browse files
Files changed (6) hide show
  1. README.md +1 -1
  2. app.py +35 -0
  3. requirements.txt +5 -0
  4. time_main.py +138 -0
  5. time_utils.py +105 -0
  6. train_funcs.py +65 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Time Diffusion
3
  emoji: 🐢
4
  colorFrom: red
5
  colorTo: red
1
  ---
2
+ title: Editing Implicit Assumptions in Text-to-Image Diffusion Models
3
  emoji: 🐢
4
  colorFrom: red
5
  colorTo: red
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from time_main import edit_model, generate_for_text
3
+
4
+ with gr.Blocks() as demo:
5
+ gr.Markdown("<center><h2>TIME: Text-to-Image Model Editing</h2>Demo for the paper <a href=\"https://time-diffusion.github.io/\" style=\"color:black;\">\"Editing Implicit Assumptions in Text-to-Image Diffusion Models\"</a>. Implemented with Stable Diffusion v1.4.</center>")
6
+
7
+ with gr.Box():
8
+ gr.Markdown("1. Edit a concept in a text-to-image model by specifying an under-specified \"source\" prompt, and a similar \"destination\" prompt with an additional specification.")
9
+ with gr.Row():
10
+ src = gr.Textbox(label = "Source Prompt", placeholder="e.g., A pack of roses")
11
+ dst = gr.Textbox(label = "Destination Prompt", placeholder="e.g., A pack of blue roses")
12
+ with gr.Row():
13
+ lamb_val = gr.Slider(value = 0.1, minimum=0.01, maximum=10000, label = "Strength of regularization (lambda)", interactive = True)
14
+ with gr.Row():
15
+ edit_btn = gr.Button("Edit Model")
16
+ with gr.Row():
17
+ gr.HTML(value = "<br />")
18
+ with gr.Row():
19
+ edit_status = gr.HTML(value="<b>Current model status:</b> Unedited")
20
+ edit_btn.click(fn=edit_model, inputs=[src, dst, lamb_val], outputs=edit_status)
21
+
22
+ with gr.Box():
23
+ gr.Markdown("2. After editing, try any test prompt and see the effect on the generated images!")
24
+ with gr.Row():
25
+ tst = gr.Textbox(label = "Test Prompt", placeholder="e.g., A field of roses")
26
+ with gr.Row():
27
+ gen_btn = gr.Button("Generate Image")
28
+ with gr.Row():
29
+ gr.HTML(value = "<br />")
30
+ with gr.Row():
31
+ out_img = gr.Image(label="Generated Image")
32
+
33
+ gen_btn.click(fn=generate_for_text, inputs=tst, outputs=out_img)
34
+
35
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch==1.13.1
3
+ diffusers
4
+ numpy
5
+ Pillow
time_main.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline
3
+ import numpy as np
4
+ import abc
5
+ import time_utils
6
+ import copy
7
+ import os
8
+ from train_funcs import TRAIN_FUNC_DICT
9
+
10
+ ## get arguments for our script
11
+ with_to_k = True
12
+ with_augs = True
13
+ train_func = "train_closed_form"
14
+
15
+ ### load model
16
+ LOW_RESOURCE = True
17
+ NUM_DIFFUSION_STEPS = 50
18
+ GUIDANCE_SCALE = 7.5
19
+ MAX_NUM_WORDS = 77
20
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
21
+ ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
22
+ tokenizer = ldm_stable.tokenizer
23
+
24
+ ### get layers
25
+ ca_layers = []
26
+ def append_ca(net_):
27
+ if net_.__class__.__name__ == 'CrossAttention':
28
+ ca_layers.append(net_)
29
+ elif hasattr(net_, 'children'):
30
+ for net__ in net_.children():
31
+ append_ca(net__)
32
+
33
+ sub_nets = ldm_stable.unet.named_children()
34
+ for net in sub_nets:
35
+ if "down" in net[0]:
36
+ append_ca(net[1])
37
+ elif "up" in net[0]:
38
+ append_ca(net[1])
39
+ elif "mid" in net[0]:
40
+ append_ca(net[1])
41
+
42
+ ### get projection matrices
43
+ ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768]
44
+ projection_matrices = [l.to_v for l in ca_clip_layers]
45
+ og_matrices = [copy.deepcopy(l.to_v) for l in ca_clip_layers]
46
+ if with_to_k:
47
+ projection_matrices = projection_matrices + [l.to_k for l in ca_clip_layers]
48
+ og_matrices = og_matrices + [copy.deepcopy(l.to_k) for l in ca_clip_layers]
49
+
50
+ def edit_model(old_text_, new_text_, lamb=0.1):
51
+ #### restart LDM parameters
52
+ num_ca_clip_layers = len(ca_clip_layers)
53
+ for idx_, l in enumerate(ca_clip_layers):
54
+ l.to_v = copy.deepcopy(og_matrices[idx_])
55
+ projection_matrices[idx_] = l.to_v
56
+ if with_to_k:
57
+ l.to_k = copy.deepcopy(og_matrices[num_ca_clip_layers + idx_])
58
+ projection_matrices[num_ca_clip_layers + idx_] = l.to_k
59
+
60
+ try:
61
+ #### set up sentences
62
+ old_texts = [old_text_]
63
+ new_texts = [new_text_]
64
+ if with_augs:
65
+ base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:]
66
+ old_texts.append("A photo of " + base)
67
+ old_texts.append("An image of " + base)
68
+ old_texts.append("A picture of " + base)
69
+ base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:]
70
+ new_texts.append("A photo of " + base)
71
+ new_texts.append("An image of " + base)
72
+ new_texts.append("A picture of " + base)
73
+
74
+ #### prepare input k* and v*
75
+ old_embs, new_embs = [], []
76
+ for old_text, new_text in zip(old_texts, new_texts):
77
+ text_input = ldm_stable.tokenizer(
78
+ [old_text, new_text],
79
+ padding="max_length",
80
+ max_length=ldm_stable.tokenizer.model_max_length,
81
+ truncation=True,
82
+ return_tensors="pt",
83
+ )
84
+ text_embeddings = ldm_stable.text_encoder(text_input.input_ids.to(ldm_stable.device))[0]
85
+ old_emb, new_emb = text_embeddings
86
+ old_embs.append(old_emb)
87
+ new_embs.append(new_emb)
88
+
89
+ #### indetify corresponding destinations for each token in old_emb
90
+ idxs_replaces = []
91
+ for old_text, new_text in zip(old_texts, new_texts):
92
+ tokens_a = tokenizer(old_text).input_ids
93
+ tokens_b = tokenizer(new_text).input_ids
94
+ tokens_a = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_a]
95
+ tokens_b = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_b]
96
+ num_orig_tokens = len(tokens_a)
97
+ num_new_tokens = len(tokens_b)
98
+ idxs_replace = []
99
+ j = 0
100
+ for i in range(num_orig_tokens):
101
+ curr_token = tokens_a[i]
102
+ while tokens_b[j] != curr_token:
103
+ j += 1
104
+ idxs_replace.append(j)
105
+ j += 1
106
+ while j < 77:
107
+ idxs_replace.append(j)
108
+ j += 1
109
+ while len(idxs_replace) < 77:
110
+ idxs_replace.append(76)
111
+ idxs_replaces.append(idxs_replace)
112
+
113
+ #### prepare batch: for each pair of setences, old context and new values
114
+ contexts, valuess = [], []
115
+ for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
116
+ context = old_emb.detach()
117
+ values = []
118
+ with torch.no_grad():
119
+ for layer in projection_matrices:
120
+ values.append(layer(new_emb[idxs_replace]).detach())
121
+ contexts.append(context)
122
+ valuess.append(values)
123
+
124
+ #### define training function
125
+ train = TRAIN_FUNC_DICT[train_func]
126
+
127
+ #### train the model
128
+ train(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, lamb=lamb)
129
+
130
+ return f"<b>Current model status:</b> Edited \"{old_text_}\" into \"{new_text_}\""
131
+ except:
132
+ return "<b>Current model status:</b> An error occured"
133
+
134
+ def generate_for_text(test_text):
135
+ g = torch.Generator(device='cpu')
136
+ g.seed()
137
+ images = time_utils.text2image_ldm_stable(ldm_stable, [test_text], latent=None, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=g, low_resource=LOW_RESOURCE)
138
+ return time_utils.view_images(images)
time_utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+
5
+
6
+ def view_images(images, num_rows=1, offset_ratio=0.02):
7
+ if type(images) is list:
8
+ num_empty = len(images) % num_rows
9
+ elif images.ndim == 4:
10
+ num_empty = images.shape[0] % num_rows
11
+ else:
12
+ images = [images]
13
+ num_empty = 0
14
+
15
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
16
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
17
+ num_items = len(images)
18
+
19
+ h, w, c = images[0].shape
20
+ offset = int(h * offset_ratio)
21
+ num_cols = num_items // num_rows
22
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
23
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
24
+ for i in range(num_rows):
25
+ for j in range(num_cols):
26
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
27
+ i * num_cols + j]
28
+
29
+ pil_img = Image.fromarray(image_)
30
+ return pil_img
31
+
32
+
33
+ def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
34
+ if low_resource:
35
+ noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
36
+ noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
37
+ else:
38
+ latents_input = torch.cat([latents] * 2)
39
+ noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
40
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
41
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
42
+ latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
43
+ return latents
44
+
45
+
46
+ def latent2image(vae, latents):
47
+ latents = 1 / 0.18215 * latents
48
+ image = vae.decode(latents)['sample']
49
+ image = (image / 2 + 0.5).clamp(0, 1)
50
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
51
+ image = (image * 255).astype(np.uint8)
52
+ return image
53
+
54
+
55
+ def init_latent(latent, model, height, width, generator, batch_size):
56
+ if latent is None:
57
+ latent = torch.randn(
58
+ (1, model.unet.in_channels, height // 8, width // 8),
59
+ generator=generator,
60
+ )
61
+ latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
62
+ return latent, latents
63
+
64
+
65
+ @torch.no_grad()
66
+ def text2image_ldm_stable(
67
+ model,
68
+ prompt,
69
+ num_inference_steps = 50,
70
+ guidance_scale = 7.5,
71
+ generator = None,
72
+ latent = None,
73
+ low_resource = False,
74
+ ):
75
+ height = width = 512
76
+ batch_size = len(prompt)
77
+
78
+ text_input = model.tokenizer(
79
+ prompt,
80
+ padding="max_length",
81
+ max_length=model.tokenizer.model_max_length,
82
+ truncation=True,
83
+ return_tensors="pt",
84
+ )
85
+ text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
86
+ max_length = text_input.input_ids.shape[-1]
87
+ uncond_input = model.tokenizer(
88
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
89
+ )
90
+ uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
91
+
92
+ context = [uncond_embeddings, text_embeddings]
93
+ if not low_resource:
94
+ context = torch.cat(context)
95
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
96
+
97
+ model.scheduler.set_timesteps(num_inference_steps)
98
+ for t in model.scheduler.timesteps:
99
+ latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource)
100
+
101
+ image = latent2image(model.vae, latents)
102
+
103
+ image, _ = model.run_safety_checker(image=image, device=model.device, dtype=text_embeddings.dtype)
104
+
105
+ return image
train_funcs.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import ast
4
+
5
+ """
6
+ TRAIN FUNCTION DEFINITION:
7
+ train(model: StableDiffusionPipeline,
8
+ projection_matrices: list[size=L](nn.Module),
9
+ og_matrices: list[size=L](nn.Module),
10
+ contexts: list[size=N](torch.tensor[size=MAX_LEN,...]),
11
+ valuess: list[size=N](list[size=L](torch.tensor[size=MAX_LEN,...])),
12
+ old_texts: list[size=N](str),
13
+ new_texts: list[size=N](str),
14
+ **kwargs)
15
+ where L is the number of matrices to edit, and N is the number of sentences to train on (batch size).
16
+
17
+ PARAMS:
18
+ model: the model to use.
19
+ projection_matrices: list of projection matrices to edit from the model.
20
+ og_matrices: list of original values for the projection matrices. detached from the model.
21
+ contexts: list of context vectors (inputs to the matrices) to edit.
22
+ valuess: list of results from all matrices for each context vector.
23
+ old_texts: list of sentences to be edited.
24
+ new_texts: list of target sentences to be aimed at.
25
+ **kwargs: additional command line arguments.
26
+
27
+ TRAIN_FUNC_DICT defined at the bottom of the file.
28
+ """
29
+
30
+ def baseline_train(model, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts):
31
+ return None
32
+
33
+ def train_closed_form(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts,
34
+ new_texts, layers_to_edit=None, lamb=0.1):
35
+ layers_to_edit = ast.literal_eval(layers_to_edit) if type(layers_to_edit) == str else layers_to_edit
36
+ lamb = ast.literal_eval(lamb) if type(lamb) == str else lamb
37
+
38
+ for layer_num in range(len(projection_matrices)):
39
+ if (layers_to_edit is not None) and (layer_num not in layers_to_edit):
40
+ continue
41
+
42
+ with torch.no_grad():
43
+ #mat1 = \lambda W + \sum{v k^T}
44
+ mat1 = lamb * projection_matrices[layer_num].weight
45
+
46
+ #mat2 = \lambda I + \sum{k k^T}
47
+ mat2 = lamb * torch.eye(projection_matrices[layer_num].weight.shape[1], device = projection_matrices[layer_num].weight.device)
48
+
49
+ #aggregate sums for mat1, mat2
50
+ for context, values in zip(contexts, valuess):
51
+ context_vector = context.reshape(context.shape[0], context.shape[1], 1)
52
+ context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
53
+ value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1)
54
+ for_mat1 = (value_vector @ context_vector_T).sum(dim=0)
55
+ for_mat2 = (context_vector @ context_vector_T).sum(dim=0)
56
+ mat1 += for_mat1
57
+ mat2 += for_mat2
58
+
59
+ #update projection matrix
60
+ projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))
61
+
62
+ TRAIN_FUNC_DICT = {
63
+ "baseline": baseline_train,
64
+ "train_closed_form": train_closed_form,
65
+ }