Tonic committed on
Commit
2551487
β€’
1 Parent(s): 70b1b69

Create inversion.py

Files changed (1)
  1. inversion.py +125 -0
inversion.py ADDED
@@ -0,0 +1,125 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations
from typing import Callable
from diffusers import StableDiffusionXLPipeline
import torch
from tqdm import tqdm
import numpy as np


T = torch.Tensor
TN = T | None
InversionCallback = Callable[[StableDiffusionXLPipeline, int, T, dict[str, T]], dict[str, T]]


def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
    # Tokenize text and get embeddings
    text_inputs = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')
    text_input_ids = text_inputs.input_ids

    with torch.no_grad():
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            output_hidden_states=True,
        )

    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    # An empty prompt yields zeroed-out (unconditional) embeddings.
    if prompt == '':
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        return negative_prompt_embeds, negative_pooled_prompt_embeds
    return prompt_embeds, pooled_prompt_embeds


def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
    device = model._execution_device
    prompt_embeds, pooled_prompt_embeds = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device)
    prompt_embeds_2, pooled_prompt_embeds_2 = _get_text_embeddings(prompt, model.tokenizer_2, model.text_encoder_2, device)
    # SDXL concatenates the per-token embeddings of both text encoders and keeps
    # only the pooled embedding of the second encoder.
    prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
    text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
    add_time_ids = model._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), torch.float16,
                                           text_encoder_projection_dim).to(device)
    added_cond_kwargs = {"text_embeds": pooled_prompt_embeds_2, "time_ids": add_time_ids}
    return added_cond_kwargs, prompt_embeds


def _encode_text_sdxl_with_negative(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
    # Build the classifier-free-guidance batch: unconditional embeddings first, conditional second.
    added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
    added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(model, "")
    prompt_embeds = torch.cat((prompt_embeds_uncond, prompt_embeds))
    added_cond_kwargs = {"text_embeds": torch.cat((added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])),
                         "time_ids": torch.cat((added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"]))}
    return added_cond_kwargs, prompt_embeds


def _encode_image(model: StableDiffusionXLPipeline, image: np.ndarray) -> T:
    # Run the VAE encoder in float32 for numerical stability, then restore float16.
    model.vae.to(dtype=torch.float32)
    image = torch.from_numpy(image).float() / 255.
    image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
    latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
    model.vae.to(dtype=torch.float16)
    return latent


def _next_step(model: StableDiffusionXLPipeline, model_output: T, timestep: int, sample: T) -> T:
    # Reversed DDIM update: given the noise prediction at `timestep`, step the latent
    # forward to the next (noisier) timestep of the inversion trajectory.
    timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = model.scheduler.alphas_cumprod[int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
    alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def _get_noise_pred(model: StableDiffusionXLPipeline, latent: T, t: T, context: T, guidance_scale: float, added_cond_kwargs: dict[str, T]):
    # Classifier-free guidance: run the UNet on the unconditional and conditional batch
    # halves and combine the two noise predictions.
    latents_input = torch.cat([latent] * 2)
    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context, added_cond_kwargs=added_cond_kwargs)["sample"]
    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # latents = next_step(model, noise_pred, t, latent)
    return noise_pred


def _ddim_loop(model: StableDiffusionXLPipeline, z0, prompt, guidance_scale) -> T:
    # Walk the timesteps from the least to the most noisy, inverting z0 step by step.
    all_latent = [z0]
    added_cond_kwargs, text_embedding = _encode_text_sdxl_with_negative(model, prompt)
    latent = z0.clone().detach().half()
    for i in tqdm(range(model.scheduler.num_inference_steps)):
        t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
        noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale, added_cond_kwargs)
        latent = _next_step(model, noise_pred, t, latent)
        all_latent.append(latent)
    # Flip so that index 0 holds the noisiest latent (z_T) and the last entry holds z_0.
    return torch.cat(all_latent).flip(0)


def make_inversion_callback(zts, offset: int = 0) -> tuple[T, InversionCallback]:

    def callback_on_step_end(pipeline: StableDiffusionXLPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[str, T]:
        # Pin batch element 0 to the stored inversion latent for the next step.
        latents = callback_kwargs['latents']
        latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
        return {'latents': latents}
    return zts[offset], callback_on_step_end


@torch.no_grad()
def ddim_inversion(model: StableDiffusionXLPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, guidance_scale: float) -> T:
    z0 = _encode_image(model, x0)
    model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
    zs = _ddim_loop(model, z0, prompt, guidance_scale)
    return zs
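
Usage note: a minimal sketch (not part of the commit) of how ddim_inversion and make_inversion_callback are typically wired into an SDXL pipeline with a DDIM scheduler. The model id, input image, prompt, step count, offset, and guidance values below are illustrative assumptions, and the callback_on_step_end hooks assume a diffusers version that supports them.

import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

from inversion import ddim_inversion, make_inversion_callback

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# DDIM inversion assumes a DDIM scheduler on the pipeline.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# HxWx3 uint8 image in [0, 255]; 1024x1024 matches the hard-coded SDXL size ids above.
image = np.array(Image.open("input.jpg").convert("RGB").resize((1024, 1024)))

prompt = "a photo of a house"  # illustrative prompt
zts = ddim_inversion(pipe, image, prompt, num_inference_steps=50, guidance_scale=2.0)

# zts[offset] is the starting latent; the callback keeps batch element 0 pinned to
# the stored inversion trajectory at every denoising step.
zT, callback = make_inversion_callback(zts, offset=5)
latents = zT[None].to(device="cuda", dtype=pipe.unet.dtype)
out = pipe(prompt, latents=latents,
           callback_on_step_end=callback,
           callback_on_step_end_tensor_inputs=["latents"],
           num_inference_steps=50, guidance_scale=7.0)
out.images[0].save("reconstruction.jpg")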