dbaranchuk commited on
Commit
1df36a0
·
verified ·
1 Parent(s): 6be27b5

Main update

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +81 -67
  3. generation_sdxl.py +474 -0
  4. requirements.txt +1 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: ICD Image Generation
3
  emoji: 🖼
4
  colorFrom: purple
5
  colorTo: red
 
1
  ---
2
+ title: Demo App
3
  emoji: 🖼
4
  colorFrom: purple
5
  colorTo: red
app.py CHANGED
@@ -1,46 +1,76 @@
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- from diffusers import DiffusionPipeline
 
 
5
  import torch
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- if torch.cuda.is_available():
10
- torch.cuda.max_memory_allocated(device=device)
11
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
12
- pipe.enable_xformers_memory_efficient_attention()
13
- pipe = pipe.to(device)
14
- else:
15
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
16
- pipe = pipe.to(device)
 
 
 
 
 
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 1024
20
 
21
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
 
 
22
 
23
  if randomize_seed:
24
  seed = random.randint(0, MAX_SEED)
25
 
26
  generator = torch.Generator().manual_seed(seed)
27
-
28
- image = pipe(
29
- prompt = prompt,
30
- negative_prompt = negative_prompt,
31
- guidance_scale = guidance_scale,
32
- num_inference_steps = num_inference_steps,
33
- width = width,
34
- height = height,
35
- generator = generator
36
- ).images[0]
37
-
38
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  examples = [
41
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
42
  "An astronaut riding a green horse",
43
- "A delicious ceviche cheesecake slice",
 
 
44
  ]
45
 
46
  css="""
@@ -58,11 +88,20 @@ else:
58
  with gr.Blocks(css=css) as demo:
59
 
60
  with gr.Column(elem_id="col-container"):
61
- gr.Markdown(f"""
62
- # Text-to-Image Gradio Template
 
 
 
 
 
63
  Currently running on {power_device}.
64
- """)
65
-
 
 
 
 
66
  with gr.Row():
67
 
68
  prompt = gr.Text(
@@ -79,13 +118,6 @@ with gr.Blocks(css=css) as demo:
79
 
80
  with gr.Accordion("Advanced Settings", open=False):
81
 
82
- negative_prompt = gr.Text(
83
- label="Negative prompt",
84
- max_lines=1,
85
- placeholder="Enter a negative prompt",
86
- visible=False,
87
- )
88
-
89
  seed = gr.Slider(
90
  label="Seed",
91
  minimum=0,
@@ -94,53 +126,35 @@ with gr.Blocks(css=css) as demo:
94
  value=0,
95
  )
96
 
97
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
-
99
- with gr.Row():
100
-
101
- width = gr.Slider(
102
- label="Width",
103
- minimum=256,
104
- maximum=MAX_IMAGE_SIZE,
105
- step=32,
106
- value=512,
107
- )
108
-
109
- height = gr.Slider(
110
- label="Height",
111
- minimum=256,
112
- maximum=MAX_IMAGE_SIZE,
113
- step=32,
114
- value=512,
115
- )
116
 
117
  with gr.Row():
118
 
119
  guidance_scale = gr.Slider(
120
  label="Guidance scale",
121
  minimum=0.0,
122
- maximum=10.0,
123
- step=0.1,
124
- value=0.0,
125
  )
126
 
127
- num_inference_steps = gr.Slider(
128
- label="Number of inference steps",
129
- minimum=1,
130
- maximum=12,
131
- step=1,
132
- value=2,
133
  )
134
 
135
  gr.Examples(
136
  examples = examples,
137
- inputs = [prompt]
 
138
  )
139
-
140
  run_button.click(
141
  fn = infer,
142
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
  outputs = [result]
144
  )
145
 
146
- demo.queue().launch()
 
1
+ import spaces
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
+ import generation_sdxl
6
+ import functools
7
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline, DDIMScheduler
8
  import torch
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ torch.cuda.max_memory_allocated(device=device)
13
+ model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
14
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id,
15
+ torch_dtype=torch.float16,
16
+ scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
17
+ variant="fp16").to(device)
18
+ pipe = pipe.to(device)
19
+ unet = UNet2DConditionModel.from_pretrained("dbaranchuk/sdxl-cfg-distill-unet").to(device)
20
+ pipe.unet = unet
21
+ pipe.load_lora_weights("dbaranchuk/icd-lora-sdxl",
22
+ weight_name='reverse-249-499-699-999.safetensors')
23
+ pipe.fuse_lora()
24
+ pipe.to(dtype=torch.float16, device=device)
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
28
 
29
+ @spaces.GPU(duration=30)
30
+ def infer(prompt, seed, randomize_seed, tau,
31
+ guidance_scale):
32
 
33
  if randomize_seed:
34
  seed = random.randint(0, MAX_SEED)
35
 
36
  generator = torch.Generator().manual_seed(seed)
37
+ prompt = [prompt]
38
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
39
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
40
+
41
+ compute_embeddings_fn = functools.partial(
42
+ generation_sdxl.compute_embeddings,
43
+ proportion_empty_prompts=0,
44
+ text_encoders=text_encoders,
45
+ tokenizers=tokenizers,
46
+ )
47
+
48
+ if tau < 1.0:
49
+ use_dynamic_guidance=True
50
+ else:
51
+ use_dynamic_guidance=False
52
+
53
+ images = generation_sdxl.sample_deterministic(
54
+ pipe,
55
+ prompt,
56
+ num_inference_steps=4,
57
+ generator=generator,
58
+ guidance_scale=guidance_scale,
59
+ is_sdxl=True,
60
+ timesteps=[249, 499, 699, 999],
61
+ use_dynamic_guidance=use_dynamic_guidance,
62
+ tau1=tau,
63
+ tau2=tau,
64
+ compute_embeddings_fn=compute_embeddings_fn
65
+ )[0]
66
+
67
+ return images
68
 
69
  examples = [
 
70
  "An astronaut riding a green horse",
71
+ 'Long-exposure night photography of a starry sky over a mountain range, with light trails.',
72
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
73
+ "A portrait of a girl with blonde, tousled hair, blue eyes",
74
  ]
75
 
76
  css="""
 
88
  with gr.Blocks(css=css) as demo:
89
 
90
  with gr.Column(elem_id="col-container"):
91
+ gr.Markdown(
92
+ f"""
93
+ # ⚡ Invertible Consistency Distillation ⚡
94
+ # ⚡ Image Generation with 4-step iCD-XL ⚡
95
+ This is a demo of [Invertible Consistency Distillation](https://yandex-research.github.io/invertible-cd/),
96
+ a diffusion distillation method proposed in [Invertible Consistency Distillation for Text-Guided Image Editing in Around 7 Steps](https://arxiv.org/abs/2406.14539)
97
+ by [Yandex Research](https://github.com/yandex-research).
98
  Currently running on {power_device}.
99
+ """
100
+ )
101
+ gr.Markdown(
102
+ "If you enjoy the space, feel free to give a ⭐ to the <a href='https://github.com/yandex-research/invertible-cd' target='_blank'>Github Repo</a>. [![GitHub Stars](https://img.shields.io/github/stars/yandex-research/invertible-cd?style=social)](https://github.com/yandex-research/invertible-cd)"
103
+ )
104
+
105
  with gr.Row():
106
 
107
  prompt = gr.Text(
 
118
 
119
  with gr.Accordion("Advanced Settings", open=False):
120
 
 
 
 
 
 
 
 
121
  seed = gr.Slider(
122
  label="Seed",
123
  minimum=0,
 
126
  value=0,
127
  )
128
 
129
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  with gr.Row():
132
 
133
  guidance_scale = gr.Slider(
134
  label="Guidance scale",
135
  minimum=0.0,
136
+ maximum=19.0,
137
+ step=1.0,
138
+ value=7.0,
139
  )
140
 
141
+ dynamic_guidance_tau = gr.Slider(
142
+ label="Dynamic guidance tau",
143
+ minimum=0,
144
+ maximum=1,
145
+ step=0.1,
146
+ value=1.0,
147
  )
148
 
149
  gr.Examples(
150
  examples = examples,
151
+ inputs = [prompt],
152
+ cache_examples=False
153
  )
 
154
  run_button.click(
155
  fn = infer,
156
+ inputs = [prompt, seed, randomize_seed, dynamic_guidance_tau, guidance_scale],
157
  outputs = [result]
158
  )
159
 
160
+ demo.queue().launch(share=False)
generation_sdxl.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import copy
3
+ import random
4
+ import numpy as np
5
+
6
+
7
+ # Diffusion util
8
+ # ------------------------------------------------------------------------
9
+ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
10
+ prompt_embeds_list = []
11
+
12
+ captions = []
13
+ for caption in prompt_batch:
14
+ if random.random() < proportion_empty_prompts:
15
+ captions.append("")
16
+ elif isinstance(caption, str):
17
+ captions.append(caption)
18
+ elif isinstance(caption, (list, np.ndarray)):
19
+ # take a random caption if there are multiple
20
+ captions.append(random.choice(caption) if is_train else caption[0])
21
+
22
+ with torch.no_grad():
23
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
24
+ text_inputs = tokenizer(
25
+ captions,
26
+ padding="max_length",
27
+ max_length=tokenizer.model_max_length,
28
+ truncation=True,
29
+ return_tensors="pt",
30
+ )
31
+ text_input_ids = text_inputs.input_ids
32
+ prompt_embeds = text_encoder(
33
+ text_input_ids.to(text_encoder.device),
34
+ output_hidden_states=True,
35
+ )
36
+
37
+ # We are only ALWAYS interested in the pooled output of the final text encoder
38
+ pooled_prompt_embeds = prompt_embeds[0]
39
+ prompt_embeds = prompt_embeds.hidden_states[-2]
40
+ bs_embed, seq_len, _ = prompt_embeds.shape
41
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
42
+ prompt_embeds_list.append(prompt_embeds)
43
+
44
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
45
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
46
+ return prompt_embeds, pooled_prompt_embeds
47
+
48
+
49
+ def compute_embeddings(
50
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True,
51
+ device='cuda'
52
+ ):
53
+ target_size = (1024, 1024)
54
+ original_sizes = original_sizes #list(map(list, zip(*original_sizes)))
55
+ crops_coords_top_left = crop_coords #list(map(list, zip(*crop_coords)))
56
+
57
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
58
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
59
+
60
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
61
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
62
+ )
63
+ add_text_embeds = pooled_prompt_embeds
64
+
65
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
66
+ add_time_ids = list(target_size)
67
+ add_time_ids = torch.tensor([add_time_ids])
68
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
69
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
70
+ add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype)
71
+
72
+ prompt_embeds = prompt_embeds.to(device)
73
+ add_text_embeds = add_text_embeds.to(device)
74
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
75
+
76
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
77
+
78
+ def extract_into_tensor(a, t, x_shape):
79
+ b, *_ = t.shape
80
+ out = a.gather(-1, t)
81
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
82
+
83
+
84
+ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
85
+ """
86
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
87
+
88
+ Args:
89
+ timesteps (`torch.Tensor`):
90
+ generate embedding vectors at these timesteps
91
+ embedding_dim (`int`, *optional*, defaults to 512):
92
+ dimension of the embeddings to generate
93
+ dtype:
94
+ data type of the generated embeddings
95
+
96
+ Returns:
97
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
98
+ """
99
+ assert len(w.shape) == 1
100
+ w = w * 1000.0
101
+
102
+ half_dim = embedding_dim // 2
103
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
104
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
105
+ emb = w.to(dtype)[:, None] * emb[None, :]
106
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
107
+ if embedding_dim % 2 == 1: # zero pad
108
+ emb = torch.nn.functional.pad(emb, (0, 1))
109
+ assert emb.shape == (w.shape[0], embedding_dim)
110
+ return emb
111
+
112
+ def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
113
+ sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
114
+ alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)
115
+
116
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
117
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
118
+
119
+ # Set hard boundaries to ensure equivalence with forward (direct) CD
120
+ alphas_s[boundary_timesteps == 0] = 1.0
121
+ sigmas_s[boundary_timesteps == 0] = 0.0
122
+
123
+ if prediction_type == "epsilon":
124
+ pred_x_0 = (sample - sigmas * model_output) / alphas # x0 prediction
125
+ pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output # Euler step to the boundary step
126
+ elif prediction_type == "v_prediction":
127
+ assert boundary_timesteps == 0, "v_prediction does not support multiple endpoints at the moment"
128
+ pred_x_0 = alphas * sample - sigmas * model_output
129
+ else:
130
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
131
+
132
+ return pred_x_0
133
+
134
+
135
+ class DDIMSolver:
136
+ def __init__(
137
+ self, alpha_cumprods, timesteps=1000, ddim_timesteps=50,
138
+ num_endpoints=1, num_inverse_endpoints=1,
139
+ max_inverse_timestep_index=49,
140
+ endpoints=None, inverse_endpoints=None
141
+ ):
142
+ # DDIM sampling parameters
143
+ step_ratio = timesteps // ddim_timesteps
144
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(
145
+ np.int64) - 1 # [19, ..., 999]
146
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
147
+ self.ddim_alpha_cumprods_prev = np.asarray(
148
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
149
+ )
150
+ self.ddim_alpha_cumprods_next = np.asarray(
151
+ alpha_cumprods[self.ddim_timesteps[1:]].tolist() + [0.0]
152
+ )
153
+ # convert to torch tensors
154
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
155
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
156
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
157
+ self.ddim_alpha_cumprods_next = torch.from_numpy(self.ddim_alpha_cumprods_next)
158
+
159
+ # Set endpoints for direct CTM
160
+ if endpoints is None:
161
+ timestep_interval = ddim_timesteps // num_endpoints + int(ddim_timesteps % num_endpoints > 0)
162
+ endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
163
+ self.endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
164
+ else:
165
+ self.endpoints = torch.tensor([int(endpoint) for endpoint in endpoints.split(',')])
166
+ assert len(self.endpoints) == num_endpoints
167
+
168
+ # Set endpoints for inverse CTM
169
+ if inverse_endpoints is None:
170
+ timestep_interval = ddim_timesteps // num_inverse_endpoints + int(
171
+ ddim_timesteps % num_inverse_endpoints > 0)
172
+ inverse_endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
173
+ inverse_endpoint_idxs = torch.tensor(inverse_endpoint_idxs.tolist() + [max_inverse_timestep_index])
174
+ self.inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]
175
+ else:
176
+ self.inverse_endpoints = torch.tensor([int(endpoint) for endpoint in inverse_endpoints.split(',')])
177
+ assert len(self.inverse_endpoints) == num_inverse_endpoints
178
+
179
+ def to(self, device):
180
+ self.endpoints = self.endpoints.to(device)
181
+ self.inverse_endpoints = self.inverse_endpoints.to(device)
182
+
183
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
184
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
185
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
186
+ self.ddim_alpha_cumprods_next = self.ddim_alpha_cumprods_next.to(device)
187
+ return self
188
+
189
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
190
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
191
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
192
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
193
+ return x_prev
194
+
195
+ def inverse_ddim_step(self, pred_x0, pred_noise, timestep_index):
196
+ alpha_cumprod_next = extract_into_tensor(self.ddim_alpha_cumprods_next, timestep_index, pred_x0.shape)
197
+ dir_xt = (1.0 - alpha_cumprod_next).sqrt() * pred_noise
198
+ x_next = alpha_cumprod_next.sqrt() * pred_x0 + dir_xt
199
+ return x_next
200
+ # ------------------------------------------------------------------------
201
+
202
+ # Distillation specific
203
+ # ------------------------------------------------------------------------
204
+ def inverse_sample_deterministic(
205
+ pipe,
206
+ images,
207
+ prompt,
208
+ generator=None,
209
+ num_scales=50,
210
+ num_inference_steps=1,
211
+ timesteps=None,
212
+ start_timestep=19,
213
+ max_inverse_timestep_index=49,
214
+ return_start_latent=False,
215
+ guidance_scale=None, # Used only if the student has w_embedding
216
+ compute_embeddings_fn=None,
217
+ is_sdxl=False,
218
+ inverse_endpoints=None,
219
+ seed=0,
220
+ ):
221
+ # assert isinstance(pipe, StableDiffusionImg2ImgPipeline), f"Does not support the pipeline {type(pipe)}"
222
+
223
+ if prompt is not None and isinstance(prompt, str):
224
+ batch_size = 1
225
+ elif prompt is not None and isinstance(prompt, list):
226
+ batch_size = len(prompt)
227
+
228
+ device = pipe._execution_device
229
+
230
+ # Prepare text embeddings
231
+ if compute_embeddings_fn is not None:
232
+ if is_sdxl:
233
+ orig_size = [(1024, 1024)] * len(prompt)
234
+ crop_coords = [(0, 0)] * len(prompt)
235
+ encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
236
+ prompt_embeds = encoded_text.pop("prompt_embeds")
237
+ else:
238
+ prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
239
+ encoded_text = {}
240
+ prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
241
+ else:
242
+ prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
243
+ encoded_text = {}
244
+ assert prompt_embeds.dtype == pipe.unet.dtype
245
+
246
+ # Prepare the DDIM solver
247
+ endpoints = ','.join(['0'] + inverse_endpoints.split(',')[:-1]) if inverse_endpoints is not None else None
248
+ solver = DDIMSolver(
249
+ pipe.scheduler.alphas_cumprod.cpu().numpy(),
250
+ timesteps=pipe.scheduler.num_train_timesteps,
251
+ ddim_timesteps=num_scales,
252
+ num_endpoints=num_inference_steps,
253
+ num_inverse_endpoints=num_inference_steps,
254
+ max_inverse_timestep_index=max_inverse_timestep_index,
255
+ endpoints=endpoints,
256
+ inverse_endpoints=inverse_endpoints
257
+ ).to(device)
258
+
259
+ if timesteps is None:
260
+ timesteps = solver.inverse_endpoints.flip(0)
261
+ boundary_timesteps = solver.endpoints.flip(0)
262
+ else:
263
+ timesteps, boundary_timesteps = timesteps, timesteps
264
+ boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
265
+ boundary_timesteps[-1] = 999
266
+ timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)
267
+
268
+ alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
269
+ sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)
270
+
271
+ # 5. Prepare latent variables
272
+ num_channels_latents = pipe.unet.config.in_channels
273
+ start_latents = pipe.prepare_latents(
274
+ images, timesteps[0], batch_size, 1, prompt_embeds.dtype, device,
275
+ generator=torch.Generator().manual_seed(seed),
276
+ )
277
+ latents = start_latents.clone()
278
+
279
+ if guidance_scale is not None:
280
+ w = torch.ones(batch_size) * guidance_scale
281
+ w_embedding = guidance_scale_embedding(w, embedding_dim=512)
282
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
283
+ else:
284
+ w_embedding = None
285
+
286
+ for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
287
+ # predict the noise residual
288
+ noise_pred = pipe.unet(
289
+ latents.to(prompt_embeds.dtype),
290
+ t,
291
+ encoder_hidden_states=prompt_embeds,
292
+ return_dict=False,
293
+ timestep_cond=w_embedding,
294
+ added_cond_kwargs=encoded_text,
295
+ )[0]
296
+
297
+ latents = predicted_origin(
298
+ noise_pred,
299
+ torch.tensor([t] * len(latents), device=device),
300
+ torch.tensor([s] * len(latents), device=device),
301
+ latents,
302
+ pipe.scheduler.config.prediction_type,
303
+ alpha_schedule,
304
+ sigma_schedule,
305
+ ).to(prompt_embeds.dtype)
306
+
307
+ if return_start_latent:
308
+ return latents, start_latents
309
+ else:
310
+ return latents
311
+
312
+
313
+ def linear_schedule_old(t, guidance_scale, tau1, tau2):
314
+ t = t / 1000
315
+ if t <= tau1:
316
+ gamma = 1.0
317
+ elif t >= tau2:
318
+ gamma = 0.0
319
+ else:
320
+ gamma = (tau2 - t) / (tau2 - tau1)
321
+ return gamma * guidance_scale
322
+
323
+
324
+ @torch.no_grad()
325
+ def sample_deterministic(
326
+ pipe,
327
+ prompt,
328
+ latents=None,
329
+ generator=None,
330
+ num_scales=50,
331
+ num_inference_steps=1,
332
+ timesteps=None,
333
+ start_timestep=19,
334
+ max_inverse_timestep_index=49,
335
+ return_latent=False,
336
+ guidance_scale=None, # Used only if the student has w_embedding
337
+ compute_embeddings_fn=None,
338
+ is_sdxl=False,
339
+ endpoints=None,
340
+ use_dynamic_guidance=False,
341
+ tau1=0.7,
342
+ tau2=0.7,
343
+ amplify_prompt=None,
344
+ ):
345
+ # assert isinstance(pipe, StableDiffusionPipeline), f"Does not support the pipeline {type(pipe)}"
346
+ height = pipe.unet.config.sample_size * pipe.vae_scale_factor
347
+ width = pipe.unet.config.sample_size * pipe.vae_scale_factor
348
+
349
+ # 1. Define call parameters
350
+ if prompt is not None and isinstance(prompt, str):
351
+ batch_size = 1
352
+ elif prompt is not None and isinstance(prompt, list):
353
+ batch_size = len(prompt)
354
+
355
+ device = pipe._execution_device
356
+
357
+ # Prepare text embeddings
358
+ if compute_embeddings_fn is not None:
359
+ if is_sdxl:
360
+ orig_size = [(1024, 1024)] * len(prompt)
361
+ crop_coords = [(0, 0)] * len(prompt)
362
+ encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
363
+ prompt_embeds = encoded_text.pop("prompt_embeds")
364
+ if amplify_prompt is not None:
365
+ orig_size = [(1024, 1024)] * len(amplify_prompt)
366
+ crop_coords = [(0, 0)] * len(amplify_prompt)
367
+ encoded_text_old = compute_embeddings_fn(amplify_prompt, orig_size, crop_coords)
368
+ amplify_prompt_embeds = encoded_text_old.pop("prompt_embeds")
369
+ else:
370
+ prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
371
+ encoded_text = {}
372
+ prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
373
+ else:
374
+ prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
375
+ encoded_text = {}
376
+ assert prompt_embeds.dtype == pipe.unet.dtype
377
+
378
+ # Prepare the DDIM solver
379
+ inverse_endpoints = ','.join(endpoints.split(',')[1:] + ['999']) if endpoints is not None else None
380
+ solver = DDIMSolver(
381
+ pipe.scheduler.alphas_cumprod.numpy(),
382
+ timesteps=pipe.scheduler.num_train_timesteps,
383
+ ddim_timesteps=num_scales,
384
+ num_endpoints=num_inference_steps,
385
+ num_inverse_endpoints=num_inference_steps,
386
+ max_inverse_timestep_index=max_inverse_timestep_index,
387
+ endpoints=endpoints,
388
+ inverse_endpoints=inverse_endpoints
389
+ ).to(device)
390
+
391
+ prompt_embeds_init = copy.deepcopy(prompt_embeds)
392
+
393
+ if timesteps is None:
394
+ timesteps = solver.inverse_endpoints.flip(0)
395
+ boundary_timesteps = solver.endpoints.flip(0)
396
+ else:
397
+ timesteps, boundary_timesteps = copy.deepcopy(timesteps), copy.deepcopy(timesteps)
398
+ timesteps.reverse()
399
+ boundary_timesteps.reverse()
400
+ boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
401
+ boundary_timesteps[-1] = 0
402
+ timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)
403
+
404
+ alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
405
+ sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)
406
+
407
+ # 5. Prepare latent variables
408
+ if latents is None:
409
+ num_channels_latents = pipe.unet.config.in_channels
410
+ latents = pipe.prepare_latents(
411
+ batch_size,
412
+ num_channels_latents,
413
+ height,
414
+ width,
415
+ prompt_embeds.dtype,
416
+ device,
417
+ generator,
418
+ None,
419
+ )
420
+ assert latents.dtype == pipe.unet.dtype
421
+ else:
422
+ latents = latents.to(prompt_embeds.dtype)
423
+
424
+ if guidance_scale is not None:
425
+ w = torch.ones(batch_size) * guidance_scale
426
+ w_embedding = guidance_scale_embedding(w, embedding_dim=512)
427
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
428
+ else:
429
+ w_embedding = None
430
+
431
+ for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
432
+ if use_dynamic_guidance:
433
+ if not isinstance(t, int):
434
+ t_item = t.item()
435
+ if t_item > tau1 * 1000 and amplify_prompt is not None:
436
+ prompt_embeds = amplify_prompt_embeds
437
+ else:
438
+ prompt_embeds = prompt_embeds_init
439
+ guidance_scale = linear_schedule_old(t_item, w, tau1=tau1, tau2=tau2)
440
+ guidance_scale_tensor = torch.tensor([guidance_scale] * len(latents))
441
+ w_embedding = guidance_scale_embedding(guidance_scale_tensor, embedding_dim=512)
442
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
443
+
444
+ # predict the noise residual
445
+ noise_pred = pipe.unet(
446
+ latents,
447
+ t,
448
+ encoder_hidden_states=prompt_embeds,
449
+ cross_attention_kwargs=None,
450
+ return_dict=False,
451
+ timestep_cond=w_embedding,
452
+ added_cond_kwargs=encoded_text,
453
+ )[0]
454
+
455
+ latents = predicted_origin(
456
+ noise_pred,
457
+ torch.tensor([t] * len(noise_pred)).to(device),
458
+ torch.tensor([s] * len(noise_pred)).to(device),
459
+ latents,
460
+ pipe.scheduler.config.prediction_type,
461
+ alpha_schedule,
462
+ sigma_schedule,
463
+ ).to(pipe.unet.dtype)
464
+
465
+ pipe.vae.to(torch.float32)
466
+ image = pipe.vae.decode(latents.to(torch.float32) / pipe.vae.config.scaling_factor, return_dict=False)[0]
467
+ do_denormalize = [True] * image.shape[0]
468
+ image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)
469
+
470
+ if return_latent:
471
+ return image, latents
472
+ else:
473
+ return image
474
+ # ------------------------------------------------------------------------
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  accelerate
2
  diffusers
3
  invisible_watermark
 
4
  torch
5
  transformers
6
  xformers
 
1
  accelerate
2
  diffusers
3
  invisible_watermark
4
+ peft
5
  torch
6
  transformers
7
  xformers