Bahjat Kawar committed
Commit 3f7ead4 · 1 Parent(s): aae0ff3
first commit
- README.md +1 -1
- app.py +35 -0
- requirements.txt +5 -0
- time_main.py +138 -0
- time_utils.py +105 -0
- train_funcs.py +65 -0
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: Editing Implicit Assumptions in Text-to-Image Diffusion Models
 emoji: 🐢
 colorFrom: red
 colorTo: red
app.py
ADDED
@@ -0,0 +1,35 @@
+import gradio as gr
+from time_main import edit_model, generate_for_text
+
+with gr.Blocks() as demo:
+    gr.Markdown("<center><h2>TIME: Text-to-Image Model Editing</h2>Demo for the paper <a href=\"https://time-diffusion.github.io/\" style=\"color:black;\">\"Editing Implicit Assumptions in Text-to-Image Diffusion Models\"</a>. Implemented with Stable Diffusion v1.4.</center>")
+
+    with gr.Box():
+        gr.Markdown("1. Edit a concept in a text-to-image model by specifying an under-specified \"source\" prompt, and a similar \"destination\" prompt with an additional specification.")
+        with gr.Row():
+            src = gr.Textbox(label = "Source Prompt", placeholder="e.g., A pack of roses")
+            dst = gr.Textbox(label = "Destination Prompt", placeholder="e.g., A pack of blue roses")
+        with gr.Row():
+            lamb_val = gr.Slider(value = 0.1, minimum=0.01, maximum=10000, label = "Strength of regularization (lambda)", interactive = True)
+        with gr.Row():
+            edit_btn = gr.Button("Edit Model")
+        with gr.Row():
+            gr.HTML(value = "<br />")
+        with gr.Row():
+            edit_status = gr.HTML(value="<b>Current model status:</b> Unedited")
+        edit_btn.click(fn=edit_model, inputs=[src, dst, lamb_val], outputs=edit_status)
+
+    with gr.Box():
+        gr.Markdown("2. After editing, try any test prompt and see the effect on the generated images!")
+        with gr.Row():
+            tst = gr.Textbox(label = "Test Prompt", placeholder="e.g., A field of roses")
+        with gr.Row():
+            gen_btn = gr.Button("Generate Image")
+        with gr.Row():
+            gr.HTML(value = "<br />")
+        with gr.Row():
+            out_img = gr.Image(label="Generated Image")
+
+        gen_btn.click(fn=generate_for_text, inputs=tst, outputs=out_img)
+
+demo.launch()
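Note: the two callbacks wired up above live in time_main.py; edit_model applies the closed-form edit to the cross-attention projections and generate_for_text samples from the (possibly edited) pipeline. For reference, the same flow without the Gradio UI would look roughly like the sketch below. It is a minimal illustration, not part of the Space, and it assumes that importing time_main (which loads Stable Diffusion at import time) succeeds on the target machine; the prompts are just the placeholder examples from the UI.

    from time_main import edit_model, generate_for_text

    # Edit the model: under-specified source prompt -> destination prompt with the extra attribute.
    status = edit_model("A pack of roses", "A pack of blue roses", lamb=0.1)
    print(status)  # HTML status string shown in the demo

    # Test the edit with a related prompt; returns a PIL image (via time_utils.view_images).
    image = generate_for_text("A field of roses")
    image.save("roses_after_edit.png")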
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch==1.13.1
+diffusers
+numpy
+Pillow
time_main.py
ADDED
@@ -0,0 +1,138 @@
+import torch
+from diffusers import StableDiffusionPipeline
+import numpy as np
+import abc
+import time_utils
+import copy
+import os
+from train_funcs import TRAIN_FUNC_DICT
+
+## get arguments for our script
+with_to_k = True
+with_augs = True
+train_func = "train_closed_form"
+
+### load model
+LOW_RESOURCE = True
+NUM_DIFFUSION_STEPS = 50
+GUIDANCE_SCALE = 7.5
+MAX_NUM_WORDS = 77
+device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
+tokenizer = ldm_stable.tokenizer
+
+### get layers
+ca_layers = []
+def append_ca(net_):
+    if net_.__class__.__name__ == 'CrossAttention':
+        ca_layers.append(net_)
+    elif hasattr(net_, 'children'):
+        for net__ in net_.children():
+            append_ca(net__)
+
+sub_nets = ldm_stable.unet.named_children()
+for net in sub_nets:
+    if "down" in net[0]:
+        append_ca(net[1])
+    elif "up" in net[0]:
+        append_ca(net[1])
+    elif "mid" in net[0]:
+        append_ca(net[1])
+
+### get projection matrices
+ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768]
+projection_matrices = [l.to_v for l in ca_clip_layers]
+og_matrices = [copy.deepcopy(l.to_v) for l in ca_clip_layers]
+if with_to_k:
+    projection_matrices = projection_matrices + [l.to_k for l in ca_clip_layers]
+    og_matrices = og_matrices + [copy.deepcopy(l.to_k) for l in ca_clip_layers]
+
+def edit_model(old_text_, new_text_, lamb=0.1):
+    #### restart LDM parameters
+    num_ca_clip_layers = len(ca_clip_layers)
+    for idx_, l in enumerate(ca_clip_layers):
+        l.to_v = copy.deepcopy(og_matrices[idx_])
+        projection_matrices[idx_] = l.to_v
+        if with_to_k:
+            l.to_k = copy.deepcopy(og_matrices[num_ca_clip_layers + idx_])
+            projection_matrices[num_ca_clip_layers + idx_] = l.to_k
+
+    try:
+        #### set up sentences
+        old_texts = [old_text_]
+        new_texts = [new_text_]
+        if with_augs:
+            base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:]
+            old_texts.append("A photo of " + base)
+            old_texts.append("An image of " + base)
+            old_texts.append("A picture of " + base)
+            base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:]
+            new_texts.append("A photo of " + base)
+            new_texts.append("An image of " + base)
+            new_texts.append("A picture of " + base)
+
+        #### prepare input k* and v*
+        old_embs, new_embs = [], []
+        for old_text, new_text in zip(old_texts, new_texts):
+            text_input = ldm_stable.tokenizer(
+                [old_text, new_text],
+                padding="max_length",
+                max_length=ldm_stable.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_embeddings = ldm_stable.text_encoder(text_input.input_ids.to(ldm_stable.device))[0]
+            old_emb, new_emb = text_embeddings
+            old_embs.append(old_emb)
+            new_embs.append(new_emb)
+
+        #### identify corresponding destinations for each token in old_emb
+        idxs_replaces = []
+        for old_text, new_text in zip(old_texts, new_texts):
+            tokens_a = tokenizer(old_text).input_ids
+            tokens_b = tokenizer(new_text).input_ids
+            tokens_a = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_a]
+            tokens_b = [tokenizer.encode("a ")[1] if tokenizer.decode(t) == 'an' else t for t in tokens_b]
+            num_orig_tokens = len(tokens_a)
+            num_new_tokens = len(tokens_b)
+            idxs_replace = []
+            j = 0
+            for i in range(num_orig_tokens):
+                curr_token = tokens_a[i]
+                while tokens_b[j] != curr_token:
+                    j += 1
+                idxs_replace.append(j)
+                j += 1
+            while j < 77:
+                idxs_replace.append(j)
+                j += 1
+            while len(idxs_replace) < 77:
+                idxs_replace.append(76)
+            idxs_replaces.append(idxs_replace)
+
+        #### prepare batch: for each pair of sentences, old context and new values
+        contexts, valuess = [], []
+        for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
+            context = old_emb.detach()
+            values = []
+            with torch.no_grad():
+                for layer in projection_matrices:
+                    values.append(layer(new_emb[idxs_replace]).detach())
+            contexts.append(context)
+            valuess.append(values)
+
+        #### define training function
+        train = TRAIN_FUNC_DICT[train_func]
+
+        #### train the model
+        train(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, lamb=lamb)
+
+        return f"<b>Current model status:</b> Edited \"{old_text_}\" into \"{new_text_}\""
+    except:
+        return "<b>Current model status:</b> An error occurred"
+
+def generate_for_text(test_text):
+    g = torch.Generator(device='cpu')
+    g.seed()
+    images = time_utils.text2image_ldm_stable(ldm_stable, [test_text], latent=None, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=g, low_resource=LOW_RESOURCE)
+    return time_utils.view_images(images)
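Note on the index-matching loop in edit_model: it aligns every source-prompt token with its position in the destination prompt (skipping destination tokens that have no source counterpart, i.e. the added attribute), then pads the remaining positions, so the target values v* can be read off the destination embedding at the matching indices. A tiny self-contained sketch of that logic, using made-up token ID lists purely for illustration:

    # Hypothetical IDs for source "a pack of roses" vs. destination "a pack of blue roses".
    tokens_a = [101, 5, 71, 9, 42, 102]        # <start> a pack of roses <end>
    tokens_b = [101, 5, 71, 9, 33, 42, 102]    # <start> a pack of blue roses <end>

    idxs_replace, j = [], 0
    for curr_token in tokens_a:
        while tokens_b[j] != curr_token:   # skip "blue", which only exists in the destination
            j += 1
        idxs_replace.append(j)
        j += 1
    while j < 77:                          # pass the remaining (padded) positions through unchanged
        idxs_replace.append(j)
        j += 1
    while len(idxs_replace) < 77:
        idxs_replace.append(76)

    print(idxs_replace[:6])  # [0, 1, 2, 3, 5, 6]: "roses" now points at index 5, after "blue"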
time_utils.py
ADDED
@@ -0,0 +1,105 @@
+import numpy as np
+import torch
+from PIL import Image
+
+
+def view_images(images, num_rows=1, offset_ratio=0.02):
+    if type(images) is list:
+        num_empty = len(images) % num_rows
+    elif images.ndim == 4:
+        num_empty = images.shape[0] % num_rows
+    else:
+        images = [images]
+        num_empty = 0
+
+    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
+    num_items = len(images)
+
+    h, w, c = images[0].shape
+    offset = int(h * offset_ratio)
+    num_cols = num_items // num_rows
+    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
+                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
+    for i in range(num_rows):
+        for j in range(num_cols):
+            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
+                i * num_cols + j]
+
+    pil_img = Image.fromarray(image_)
+    return pil_img
+
+
+def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
+    if low_resource:
+        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
+        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
+    else:
+        latents_input = torch.cat([latents] * 2)
+        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
+    return latents
+
+
+def latent2image(vae, latents):
+    latents = 1 / 0.18215 * latents
+    image = vae.decode(latents)['sample']
+    image = (image / 2 + 0.5).clamp(0, 1)
+    image = image.cpu().permute(0, 2, 3, 1).numpy()
+    image = (image * 255).astype(np.uint8)
+    return image
+
+
+def init_latent(latent, model, height, width, generator, batch_size):
+    if latent is None:
+        latent = torch.randn(
+            (1, model.unet.in_channels, height // 8, width // 8),
+            generator=generator,
+        )
+    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
+    return latent, latents
+
+
+@torch.no_grad()
+def text2image_ldm_stable(
+    model,
+    prompt,
+    num_inference_steps = 50,
+    guidance_scale = 7.5,
+    generator = None,
+    latent = None,
+    low_resource = False,
+):
+    height = width = 512
+    batch_size = len(prompt)
+
+    text_input = model.tokenizer(
+        prompt,
+        padding="max_length",
+        max_length=model.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = model.tokenizer(
+        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+    )
+    uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
+
+    context = [uncond_embeddings, text_embeddings]
+    if not low_resource:
+        context = torch.cat(context)
+    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
+
+    model.scheduler.set_timesteps(num_inference_steps)
+    for t in model.scheduler.timesteps:
+        latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource)
+
+    image = latent2image(model.vae, latents)
+
+    image, _ = model.run_safety_checker(image=image, device=model.device, dtype=text_embeddings.dtype)
+
+    return image
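Note: diffusion_step is standard classifier-free guidance. The UNet predicts noise for the unconditional and the text-conditional embeddings (two separate passes when low_resource=True, one batched pass otherwise), and the two predictions are combined before the scheduler step. In LaTeX, with s the guidance_scale, c the text embedding and \varnothing the empty-prompt embedding:

    \hat{\epsilon}_t = \epsilon_\theta(z_t, \varnothing) + s \left( \epsilon_\theta(z_t, c) - \epsilon_\theta(z_t, \varnothing) \right)

which corresponds line-for-line to noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond); the scheduler then maps \hat{\epsilon}_t and z_t to z_{t-1}.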
train_funcs.py
ADDED
@@ -0,0 +1,65 @@
+import torch
+import numpy as np
+import ast
+
+"""
+TRAIN FUNCTION DEFINITION:
+train(model: StableDiffusionPipeline,
+      projection_matrices: list[size=L](nn.Module),
+      og_matrices: list[size=L](nn.Module),
+      contexts: list[size=N](torch.tensor[size=MAX_LEN,...]),
+      valuess: list[size=N](list[size=L](torch.tensor[size=MAX_LEN,...])),
+      old_texts: list[size=N](str),
+      new_texts: list[size=N](str),
+      **kwargs)
+where L is the number of matrices to edit, and N is the number of sentences to train on (batch size).
+
+PARAMS:
+model: the model to use.
+projection_matrices: list of projection matrices to edit from the model.
+og_matrices: list of original values for the projection matrices, detached from the model.
+contexts: list of context vectors (inputs to the matrices) to edit.
+valuess: list of results from all matrices for each context vector.
+old_texts: list of sentences to be edited.
+new_texts: list of target sentences to be aimed at.
+**kwargs: additional command line arguments.
+
+TRAIN_FUNC_DICT is defined at the bottom of the file.
+"""
+
+def baseline_train(model, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts):
+    return None
+
+def train_closed_form(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts,
+                      new_texts, layers_to_edit=None, lamb=0.1):
+    layers_to_edit = ast.literal_eval(layers_to_edit) if type(layers_to_edit) == str else layers_to_edit
+    lamb = ast.literal_eval(lamb) if type(lamb) == str else lamb
+
+    for layer_num in range(len(projection_matrices)):
+        if (layers_to_edit is not None) and (layer_num not in layers_to_edit):
+            continue
+
+        with torch.no_grad():
+            #mat1 = \lambda W + \sum{v k^T}
+            mat1 = lamb * projection_matrices[layer_num].weight
+
+            #mat2 = \lambda I + \sum{k k^T}
+            mat2 = lamb * torch.eye(projection_matrices[layer_num].weight.shape[1], device = projection_matrices[layer_num].weight.device)
+
+            #aggregate sums for mat1, mat2
+            for context, values in zip(contexts, valuess):
+                context_vector = context.reshape(context.shape[0], context.shape[1], 1)
+                context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
+                value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1)
+                for_mat1 = (value_vector @ context_vector_T).sum(dim=0)
+                for_mat2 = (context_vector @ context_vector_T).sum(dim=0)
+                mat1 += for_mat1
+                mat2 += for_mat2
+
+            #update projection matrix
+            projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))
+
+TRAIN_FUNC_DICT = {
+    "baseline": baseline_train,
+    "train_closed_form": train_closed_form,
+}
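Note: train_closed_form implements its #mat1 / #mat2 comments literally. Writing k_i for the source context vectors (the rows of each context), v_i for the corresponding target values, and W_old for the original projection weight, the update is the closed-form minimizer of the ridge-regularized least-squares objective implied by those comments. A sketch of the algebra:

    \min_W \; \sum_i \| W k_i - v_i \|_2^2 + \lambda \| W - W_{\mathrm{old}} \|_F^2

Setting the gradient with respect to W to zero gives

    W \Big( \sum_i k_i k_i^\top + \lambda I \Big) = \sum_i v_i k_i^\top + \lambda W_{\mathrm{old}}
    \;\Longrightarrow\;
    W = \Big( \sum_i v_i k_i^\top + \lambda W_{\mathrm{old}} \Big) \Big( \sum_i k_i k_i^\top + \lambda I \Big)^{-1}

which is exactly mat1 @ torch.inverse(mat2) in the loop above; lamb trades off fidelity to the new values against staying close to the original weights.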