import inspect
import os

import numpy as np
import torch
import torch.nn.functional as nnf
from PIL import Image
from torch.optim.adam import Adam
from tqdm import tqdm

from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput


def retrieve_timesteps(
    scheduler,
    num_inference_steps=None,
    device=None,
    timesteps=None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
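
# Illustrative calls (comments only; `pipe` is a placeholder pipeline instance, not defined in this file).
# With an explicit step count the helper simply forwards to `set_timesteps`:
#   timesteps, num_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=50, device="cuda")
# With a custom schedule, which only works for schedulers whose `set_timesteps` accepts a `timesteps` argument:
#   timesteps, num_steps = retrieve_timesteps(pipe.scheduler, device="cuda", timesteps=[981, 721, 461, 201, 1])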


class NullTextPipeline(StableDiffusionPipeline):
    def get_noise_pred(self, latents, t, context):
        # One guided reverse DDIM step: predict noise for the unconditional and
        # text-conditioned branches, apply classifier-free guidance, then step x_t -> x_{t-1}.
        latents_input = torch.cat([latents] * 2)
        guidance_scale = 7.5
        noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        latents = self.prev_step(noise_pred, t, latents)
        return latents

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    @torch.no_grad()
    def image2latent(self, image_path):
        # Encode an RGB image into a VAE latent, scaled by the Stable Diffusion latent factor 0.18215.
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
        latents = self.vae.encode(image)["latent_dist"].mean
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def latent2image(self, latents):
        latents = 1 / 0.18215 * latents.detach()
        image = self.vae.decode(latents)["sample"].detach()
        image = self.image_processor.postprocess(image, output_type="pil")[0]
        return image

    def prev_step(self, model_output, timestep, sample):
        # Deterministic DDIM update x_t -> x_{t-1} (eta = 0).
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output, timestep, sample):
        # Inverse DDIM update x_t -> x_{t+1}, used by the DDIM inversion loop.
        timestep, next_timestep = (
            min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999),
            timestep,
        )
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
        return next_sample

    def null_optimization(self, latents, context, num_inner_steps, epsilon):
        # Null-text optimization: at each timestep, optimize the unconditional ("null")
        # embedding so that the guided reverse step reproduces the stored DDIM inversion
        # trajectory (Mokady et al., "Null-text Inversion").
        uncond_embeddings, cond_embeddings = context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * self.num_inference_steps)
        for i in range(self.num_inference_steps):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0))
            latent_prev = latents[len(latents) - i - 2]
            t = self.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            for j in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + 7.5 * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            for j in range(j + 1, num_inner_steps):
                bar.update()
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, context)
        bar.close()
        return uncond_embeddings_list

    def ddim_inversion_loop(self, latent, context):
        self.scheduler.set_timesteps(self.num_inference_steps)
        _, cond_embeddings = context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        with torch.no_grad():
            for i in range(0, self.num_inference_steps):
                t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1]
                noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"]
                latent = self.next_step(noise_pred, t, latent)
                all_latent.append(latent)
        return all_latent

    def get_context(self, prompt):
        uncond_input = self.tokenizer(
            [""], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_input = self.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        context = torch.cat([uncond_embeddings, text_embeddings])
        return context

    def invert(
        self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-6, num_inference_steps=50
    ):
        self.num_inference_steps = num_inference_steps
        context = self.get_context(prompt)
        latent = self.image2latent(image_path)
        ddim_latents = self.ddim_inversion_loop(latent, context)
        # Reuse previously optimized null-text embeddings if they were cached next to the image.
        if os.path.exists(image_path + ".pt"):
            uncond_embeddings = torch.load(image_path + ".pt")
        else:
            uncond_embeddings = self.null_optimization(ddim_latents, context, num_inner_steps, early_stop_epsilon)
            uncond_embeddings = torch.stack(uncond_embeddings, 0)
            torch.save(uncond_embeddings, image_path + ".pt")
        return ddim_latents[-1], uncond_embeddings

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        uncond_embeddings,
        inverted_latent,
        num_inference_steps: int = 50,
        timesteps=None,
        guidance_scale=7.5,
        negative_prompt=None,
        num_images_per_prompt=1,
        generator=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        output_type="pil",
    ):
        self._guidance_scale = guidance_scale
        # 0. Default height and width to unet
        height = self.unet.config.sample_size * self.vae_scale_factor
        width = self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks
        callback_steps = None
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )
        # 2. Define call parameters
        device = self._execution_device
        # 3. Encode input prompt
        prompt_embeds, _ = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        # 5. Denoise from the inverted latent, using the optimized null-text embedding at each step
        latents = inverted_latent
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=uncond_embeddings[i])["sample"]
                noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)["sample"]
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                progress_bar.update()
        # 6. Post-process the latents into an image
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
        else:
            image = latents
        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=[True] * image.shape[0]
        )
        # Offload all models
        self.maybe_free_model_hooks()
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)
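

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the pipeline itself): invert a real
# image, then resample from the inverted latent with the optimized null-text
# embeddings, which should closely reconstruct the input. The checkpoint id,
# DDIM scheduler settings, and file paths below are placeholders/assumptions,
# not values prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import DDIMScheduler

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # DDIM is assumed here because prev_step/next_step implement DDIM updates.
    scheduler = DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
    )
    pipeline = NullTextPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder: any Stable Diffusion v1.x checkpoint
        scheduler=scheduler,
        torch_dtype=torch.float32,
    ).to(device)
    # Invert the image into its final DDIM latent plus per-step null-text embeddings.
    inverted_latent, uncond_embeddings = pipeline.invert(
        "input_image.png", "A photo of a cat", num_inner_steps=10, num_inference_steps=50
    )
    # Sample with the same prompt and step count used for inversion.
    image = pipeline("A photo of a cat", uncond_embeddings, inverted_latent, num_inference_steps=50).images[0]
    image.save("output_image.png")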