noamelata committed
Commit
82ad0f2
1 Parent(s): e78fed0

initial commit

Files changed (4)
  1. NestedPipeline.py +246 -0
  2. NestedScheduler.py +180 -0
  3. app.py +53 -0
  4. requirements.txt +211 -0
NestedPipeline.py ADDED
@@ -0,0 +1,246 @@
+
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import torch
+ from diffusers.utils import replace_example_docstring
+ from transformers import CLIPTokenizer
+
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, EXAMPLE_DOC_STRING
+
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+ from NestedScheduler import NestedScheduler
+
+
+
+
+ class NestedStableDiffusionPipeline(StableDiffusionPipeline):
+     r"""
+     Pipeline for text-to-image generation using Nested Stable Diffusion.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     In addition, the pipeline inherits the following loading methods:
+         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+         - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+     as well as the following saving methods:
+         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+         feature_extractor ([`CLIPImageProcessor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+     """
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 5,
+         num_inner_steps: int = 20,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         inner_eta: float = 0.85,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         coroutine_mode=True):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                 instead.
+             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 5):
+                 The number of outer denoising steps.
+             num_inner_steps (`int`, *optional*, defaults to 20):
+                 The number of inner denoising steps.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the outer diffusion process.
+             inner_eta (`float`, *optional*, defaults to 0.85):
+                 Corresponds to parameter eta (η) in the inner diffusion process.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from the `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+                 argument.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+         Examples:
+
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+         )
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps + 1, device=device)
+         timesteps = self.scheduler.timesteps[:-1]
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Prepare extra step kwargs.
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+         inner_extra_step_kwargs = self.prepare_extra_step_kwargs(generator, inner_eta)
+
+         # 7. Denoising loop
+         outer_latents = latents.clone()
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+         # running the outer diffusion process
+         anytime_latent = outer_latents.clone()
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # creating the inner diffusion process
+                 self.inner_scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012,
+                                                        beta_schedule="scaled_linear", clip_sample=False,
+                                                        set_alpha_to_one=False, thresholding=False)
+                 self.inner_scheduler.set_timesteps(num_inner_steps, max_timestep=t.item(), device=device)
+                 inner_timesteps = self.inner_scheduler.timesteps
+                 latents = outer_latents.clone()
+                 # running the inner diffusion process
+
+                 for j, t_tag in enumerate(inner_timesteps):
+                     yield (i, j, self.decode_latents(anytime_latent))
+                     # expand the latents if we are doing classifier free guidance
+                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                     latent_model_input = self.inner_scheduler.scale_model_input(latent_model_input, t_tag)
+
+                     # predict the noise residual
+                     noise_pred = self.unet(latent_model_input, t_tag, encoder_hidden_states=prompt_embeds).sample
+
+                     # perform guidance
+                     if do_classifier_free_guidance:
+                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                     latents = self.inner_scheduler.step(noise_pred, t_tag, latents, **inner_extra_step_kwargs).prev_sample
+
+                 anytime_latent = latents.clone()
+                 # compute the previous noisy sample x_t -> x_t-1
+                 outer_latents = self.scheduler.step(latents, t, outer_latents, **extra_step_kwargs).prev_sample
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+
+         yield (i+1, j+1, self.decode_latents(outer_latents))
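
A minimal usage sketch, not a definitive recipe: it assumes the two modules in this commit are on the Python path and a CUDA device is available; the model id and scheduler settings mirror app.py below. It shows how the generator-style `__call__` can be consumed, since it yields `(outer_step, inner_step, decoded_images)` after every inner denoising step and the final yield holds the finished image.

import torch
from NestedPipeline import NestedStableDiffusionPipeline
from NestedScheduler import NestedScheduler

# Sketch only: scheduler arguments copied from app.py in this commit.
scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                            prediction_type='sample', clip_sample=False, set_alpha_to_one=False)
pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                     torch_dtype=torch.float16, scheduler=scheduler)
pipe.to("cuda")

# Each yield carries the current (outer, inner) step indices and a preview image
# decoded with decode_latents; the last yield is the fully denoised result.
final = None
for outer_step, inner_step, images in pipe("a photograph of a nest with a blue egg inside",
                                           num_inference_steps=4, num_inner_steps=10):
    final = images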
NestedScheduler.py ADDED
@@ -0,0 +1,180 @@
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from diffusers import DDIMScheduler
+
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class NestedSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's step function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+         pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+             `pred_original_sample` can be used to preview progress or for guidance.
+     """
+
+     prev_sample: torch.FloatTensor
+     pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+
+ class NestedScheduler(DDIMScheduler):
+
+     def set_timesteps(self, num_inference_steps: int, max_timestep: int = 1000, device: Union[str, torch.device] = None):
+         """
+         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+         Args:
+             num_inference_steps (`int`):
+                 the number of diffusion steps used when generating samples with a pre-trained model.
+             max_timestep (`int`):
+                 the highest timestep to use when choosing the timesteps.
+         """
+
+         if num_inference_steps > self.config.num_train_timesteps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                 f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                 f" maximal {self.config.num_train_timesteps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+         max_timestep = min(self.config.num_train_timesteps - 1, max_timestep)
+         timesteps = np.linspace(1, max_timestep, min(num_inference_steps, max_timestep)).round()[::-1].copy().astype(np.int64)
+         self.timesteps = torch.from_numpy(timesteps).to(device)
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: int,
+         sample: torch.FloatTensor,
+         eta: float = 0.0,
+         use_clipped_model_output: bool = False,
+         generator=None,
+         variance_noise: Optional[torch.FloatTensor] = None,
+         return_dict: bool = True,
+         override_prediction_type="",
+     ) -> Union[NestedSchedulerOutput, Tuple]:
+         """
+         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+             timestep (`int`): current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 current instance of sample being created by diffusion process.
+             eta (`float`): weight of noise for added noise in diffusion step.
+             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+                 predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+                 `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+                 coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+             generator: random number generator.
+             variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+                 can directly provide the noise for the variance itself. This is useful for methods such as
+                 CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+             return_dict (`bool`): option for returning a tuple rather than a `NestedSchedulerOutput` class
+
+         Returns:
+             [`NestedSchedulerOutput`] or `tuple`:
+             [`NestedSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+             returning a tuple, the first element is the sample tensor.
+
+         """
+         if self.num_inference_steps is None:
+             raise ValueError(
+                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+             )
+
+         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+         # Ideally, read the DDIM paper for an in-detail understanding
+
+         # Notation (<variable name> -> <name in paper>)
+         # - pred_noise_t -> e_theta(x_t, t)
+         # - pred_original_sample -> f_theta(x_t, t) or x_0
+         # - std_dev_t -> sigma_t
+         # - eta -> η
+         # - pred_sample_direction -> "direction pointing to x_t"
+         # - pred_prev_sample -> "x_t-1"
+
+         # 1. get previous step value (=t-1)
+         # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+         cur_idx = (self.timesteps == timestep).nonzero().item()
+         prev_timestep = self.timesteps[cur_idx + 1] if cur_idx < len(self.timesteps) - 1 else 0
+
+         # 2. compute alphas, betas
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+
+         # 3. compute predicted original sample from predicted noise also called
+         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         prediction_type = override_prediction_type if override_prediction_type else self.config.prediction_type
+         if prediction_type == "epsilon":
+             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+             pred_epsilon = model_output
+         elif prediction_type == "sample":
+             pred_original_sample = model_output
+             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+         elif prediction_type == "v_prediction":
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+         else:
+             raise ValueError(
+                 f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
+                 " `v_prediction`"
+             )
+
+         # 4. Clip or threshold "predicted x_0"
+         if self.config.thresholding:
+             pred_original_sample = self._threshold_sample(pred_original_sample)
+         elif self.config.clip_sample:
+             pred_original_sample = pred_original_sample.clamp(
+                 -self.config.clip_sample_range, self.config.clip_sample_range
+             )
+
+         # 5. compute variance: "sigma_t(η)" -> see formula (16)
+         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+         variance = self._get_variance(timestep, prev_timestep)
+         std_dev_t = eta * variance ** (0.5)
+
+         if use_clipped_model_output:
+             # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+         if eta > 0:
+             if variance_noise is not None and generator is not None:
+                 raise ValueError(
+                     "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+                     " `variance_noise` stays `None`."
+                 )
+
+             if variance_noise is None:
+                 variance_noise = torch.randn(
+                     model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                 )
+             variance = std_dev_t * variance_noise
+
+             prev_sample = prev_sample + variance
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return NestedSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
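
The inner schedule above is built with `np.linspace(1, max_timestep, ...)`, so every inner timestep stays at or below the outer timestep that spawned it. A small numpy-only illustration of that one line, using assumed example values for `num_inference_steps` and `max_timestep`:

import numpy as np

# Mirrors the timestep selection in NestedScheduler.set_timesteps:
# 20 inner steps confined below an (assumed) outer timestep of 800.
num_inference_steps, max_timestep = 20, 800
timesteps = np.linspace(1, max_timestep, min(num_inference_steps, max_timestep)).round()[::-1].astype(np.int64)
print(timesteps)  # 20 integers descending from 800 to 1, e.g. [800 758 716 ... 85 43 1]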
app.py ADDED
@@ -0,0 +1,53 @@
+ from functools import partial
+ from random import randint
+
+ import gradio as gr
+ import torch
+ from tqdm import tqdm
+
+ from NestedPipeline import NestedStableDiffusionPipeline
+ from NestedScheduler import NestedScheduler
+
+
+ def run(prompt, outer, inner, random_seed, pipe):
+
+     seed = 24 if not random_seed else randint(0, 10000)
+     generator = torch.Generator(device).manual_seed(seed)  # `device` is the module-level global set in the __main__ block below
+     outer_diffusion = tqdm(range(outer), desc="Outer Diffusion")
+     inner_diffusion = tqdm(range(inner), desc="Inner Diffusion")
+
+     cur = [0, 0]
+     for i, j, im in pipe(prompt, num_inference_steps=outer, num_inner_steps=inner, generator=generator):
+         if cur[-1] != j:
+             inner_diffusion.update()
+             cur[-1] = j
+         if cur[0] != i and i != outer:
+             cur[0] = i
+             outer_diffusion.update()
+             cur[-1] = 0
+             inner_diffusion = tqdm(range(inner), desc="Inner Diffusion")
+         elif cur[0] != i:
+             outer_diffusion.update()
+         monospace_s, monospace_e = "<p style=\"font-family:'Lucida Console', monospace\">", "</p>"
+         yield f"{monospace_s}{outer_diffusion.__str__().replace(' ', '&nbsp;')}{monospace_e} \n {monospace_s}{inner_diffusion.__str__().replace(' ', '&nbsp;')}{monospace_e}", im[0]
+
+ if __name__ == "__main__":
+     scheduler = NestedScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                                 prediction_type='sample', clip_sample=False, set_alpha_to_one=False)
+     pipe = NestedStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16",
+                                                          torch_dtype=torch.float16, scheduler=scheduler)
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     pipe.to(device)
+     interface = partial(run, pipe=pipe)
+     demo = gr.Interface(
+         fn=interface,
+         inputs=[gr.Textbox(value="a photograph of a nest with a blue egg inside"),
+                 gr.Slider(minimum=1, maximum=10, value=4, step=1),
+                 gr.Slider(minimum=5, maximum=50, value=10, step=1),
+                 "checkbox"],
+         outputs=[gr.HTML(), gr.Image(shape=[512, 512], elem_id="output_image").style(width=512, height=512)],
+         # css=".output_image {height: 10% !important; width: 10% !important;}",
+         allow_flagging="never"
+     )
+     demo.queue()
+     demo.launch(share=True, server_name="132.68.39.164", server_port=7861)
requirements.txt ADDED
@@ -0,0 +1,211 @@
+ absl-py==1.4.0
+ accelerate==0.19.0
+ aiofiles==23.1.0
+ aiohttp==3.7.4.post0
+ altair==5.0.1
+ anyio==3.7.0
+ argon2-cffi==21.3.0
+ argon2-cffi-bindings==21.2.0
+ asttokens==2.2.1
+ astunparse==1.6.3
+ async-lru==2.0.2
+ async-timeout==3.0.1
+ attrs==23.1.0
+ Babel==2.12.1
+ backcall==0.2.0
+ backports.functools-lru-cache==1.6.4
+ beautifulsoup4==4.12.2
+ bleach==6.0.0
+ blinker==1.6.2
+ boltons==23.0.0
+ cached-property==1.5.2
+ cachetools==5.3.0
+ certifi==2023.5.7
+ cffi==1.15.1
+ chardet==4.0.0
+ charset-normalizer==3.1.0
+ click==8.1.3
+ colorama==0.4.6
+ comm==0.1.3
+ conda==23.3.1
+ conda-package-handling==2.0.2
+ conda_package_streaming==0.8.0
+ contourpy==1.0.7
+ cryptography==41.0.0
+ cycler==0.11.0
+ debugpy==1.6.7
+ decorator==5.1.1
+ defusedxml==0.7.1
+ diffusers==0.16.1
+ entrypoints==0.4
+ exceptiongroup==1.1.1
+ executing==1.2.0
+ fastapi==0.96.0
+ fastjsonschema==2.17.1
+ ffmpy==0.3.0
+ filelock==3.12.0
+ flatbuffers==23.5.26
+ flit_core==3.9.0
+ fonttools==4.39.4
+ fsspec==2023.5.0
+ gast==0.4.0
+ gdown==4.7.1
+ gmpy2==2.1.2
+ google-auth==2.17.3
+ google-auth-oauthlib==0.4.6
+ google-pasta==0.2.0
+ gradio==3.33.1
+ gradio_client==0.2.5
+ grpcio==1.51.1
+ h11==0.14.0
+ h5py==3.8.0
+ httpcore==0.17.2
+ httpx==0.24.1
+ huggingface-hub==0.14.1
+ idna==3.4
+ importlib-metadata==6.6.0
+ importlib-resources==5.12.0
+ ipykernel==6.23.1
+ ipython==8.14.0
+ jedi==0.18.2
+ Jinja2==3.1.2
+ json5==0.9.5
+ jsonpatch==1.32
+ jsonpointer==2.0
+ jsonschema==4.17.3
+ jupyter_client==8.2.0
+ jupyter_core==5.3.0
+ jupyter-events==0.6.3
+ jupyter-lsp==2.2.0
+ jupyter_server==2.6.0
+ jupyter_server_terminals==0.4.4
+ jupyterlab==4.0.1
+ jupyterlab-pygments==0.2.2
+ jupyterlab_server==2.22.1
+ keras==2.11.0
+ Keras-Preprocessing==1.1.2
+ kiwisolver==1.4.4
+ libmambapy==1.4.2
+ linkify-it-py==2.0.2
+ mamba==1.4.2
+ Markdown==3.4.3
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.1
+ matplotlib-inline==0.1.6
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mistune==2.0.5
+ mpmath==1.3.0
+ multidict==6.0.4
+ munkres==1.1.4
+ nbclient==0.8.0
+ nbconvert==7.4.0
+ nbformat==5.9.0
+ nest-asyncio==1.5.6
+ networkx==3.1
+ notebook_shim==0.2.3
+ numpy==1.24.3
+ oauthlib==3.2.2
+ opt-einsum==3.3.0
+ orjson==3.9.0
+ overrides==7.3.1
+ packaging==23.1
+ pandas==2.0.2
+ pandocfilters==1.5.0
+ parso==0.8.3
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==9.4.0
+ pip==23.1.2
+ pkgutil_resolve_name==1.3.10
+ platformdirs==3.5.1
+ pluggy==1.0.0
+ ply==3.11
+ pooch==1.7.0
+ prometheus-client==0.17.0
+ prompt-toolkit==3.0.38
+ protobuf==4.21.12
+ psutil==5.9.5
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyasn1==0.4.8
+ pyasn1-modules==0.2.7
+ pycosat==0.6.4
+ pycparser==2.21
+ pydantic==1.10.8
+ pydub==0.25.1
+ Pygments==2.15.1
+ PyJWT==2.7.0
+ pyOpenSSL==23.2.0
+ pyparsing==3.0.9
+ PyQt5==5.15.7
+ PyQt5-sip==12.11.0
+ pyrsistent==0.19.3
+ PySocks==1.7.1
+ python-dateutil==2.8.2
+ python-json-logger==2.0.7
+ python-multipart==0.0.6
+ pytz==2023.3
+ pyu2f==0.1.5
+ PyYAML==6.0
+ pyzmq==25.1.0
+ regex==2023.5.5
+ requests==2.31.0
+ requests-oauthlib==1.3.1
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rsa==4.9
+ ruamel.yaml==0.17.31
+ ruamel.yaml.clib==0.2.7
+ safetensors==0.3.1
+ scipy==1.10.1
+ semantic-version==2.10.0
+ Send2Trash==1.8.2
+ setuptools==67.7.2
+ sip==6.7.9
+ six==1.16.0
+ sniffio==1.3.0
+ soupsieve==2.3.2.post1
+ stack-data==0.6.2
+ starlette==0.27.0
+ sympy==1.12
+ tensorboard==2.11.2
+ tensorboard-data-server==0.6.1
+ tensorboard-plugin-wit==1.8.1
+ tensorboardX==2.5
+ tensorflow==2.11.0
+ tensorflow-estimator==2.11.0
+ termcolor==2.3.0
+ terminado==0.17.1
+ timm==0.9.2
+ tinycss2==1.2.1
+ tokenizers==0.13.3
+ toml==0.10.2
+ tomli==2.0.1
+ toolz==0.12.0
+ torch==2.0.1
+ torchaudio==2.0.2
+ torchvision==0.15.2
+ tornado==6.3.2
+ tqdm==4.65.0
+ traitlets==5.9.0
+ transformers==4.29.2
+ triton==2.0.0
+ typing_extensions==4.6.2
+ typing-utils==0.1.0
+ tzdata==2023.3
+ uc-micro-py==1.0.2
+ unicodedata2==15.0.0
+ urllib3==2.0.2
+ uvicorn==0.22.0
+ wcwidth==0.2.6
+ webencodings==0.5.1
+ websocket-client==1.5.2
+ websockets==11.0.3
+ Werkzeug==2.3.4
+ wheel==0.40.0
+ wrapt==1.15.0
+ yarl==1.9.2
+ zipp==3.15.0
+ zstandard==0.19.0