sayakpaul (HF Staff) committed · verified
Commit 52abadf · 1 Parent(s): fc56851

Upload before_denoise.py with huggingface_hub

Files changed (1)
  1. before_denoise.py +218 -0
before_denoise.py ADDED
@@ -0,0 +1,218 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ComponentSpec,
    PipelineState,
    ModularPipeline,
    OutputParam,
    InputParam,
)
from diffusers.modular_pipelines.wan.before_denoise import retrieve_timesteps
from typing import Optional, List, Union, Tuple
from diffusers.image_processor import PipelineImageInput
from diffusers.utils.torch_utils import randn_tensor
import torch
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler

# One needs Wan anyway to run ChronoEdit (`AutoencoderKLWan`).
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents


class ChronoEditSetTimestepsStep(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler)
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas")
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
        )

        self.set_block_state(state, block_state)
        return components, state


class ChronoEditPrepareLatentStep(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLWan)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("processed_image", type_hint=PipelineImageInput),
            InputParam("image_embeds", type_hint=torch.Tensor),
            InputParam("height", type_hint=int, default=480),
            InputParam("width", type_hint=int, default=832),
            InputParam("num_frames", type_hint=int, default=81),
            InputParam("batch_size"),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process.",
            ),
            OutputParam(
                "condition",
                type_hint=torch.Tensor,
                description="Conditioning latents for the denoising process.",
            ),
        ]

    @staticmethod
    def check_inputs(height, width):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    @staticmethod
    def prepare_latents(
        components,
        image: PipelineImageInput,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_latent_frames = (num_frames - 1) // components.vae_scale_factor_temporal + 1
        latent_height = height // components.vae_scale_factor_spatial
        latent_width = width // components.vae_scale_factor_spatial

        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        image = image.unsqueeze(2)
        video_condition = torch.cat(
            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
        )
        video_condition = video_condition.to(device=device, dtype=dtype)

        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)

        if isinstance(generator, list):
            latent_condition = [
                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
            ]
            latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

        latent_condition = (latent_condition - latents_mean) * latents_std

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)

        return latents, torch.concat([mask_lat_size, latent_condition], dim=1)

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        self.check_inputs(block_state.height, block_state.width)

        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents

        batch_size = block_state.batch_size * block_state.num_videos_per_prompt
        block_state.latents, block_state.condition = self.prepare_latents(
            components,
            block_state.processed_image,
            batch_size,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_frames,
            torch.bfloat16,
            block_state.device,
            block_state.generator,
            block_state.latents,
        )

        self.set_block_state(state, block_state)

        return components, state
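
For reference, these two blocks only cover the set-timesteps and prepare-latents stages of ChronoEdit; a complete modular pipeline still needs prompt encoding, image preprocessing/encoding, denoising, and decoding blocks from the rest of the repository. Below is a rough, untested sketch of how blocks like these are typically composed with diffusers' modular-pipelines API. The `SequentialPipelineBlocks.from_blocks_dict`, `init_pipeline`, and `load_default_components` calls follow the general modular-diffusers pattern rather than anything stated in this file, and the repository id is a placeholder.

import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks

# The two blocks defined in this uploaded file.
from before_denoise import ChronoEditSetTimestepsStep, ChronoEditPrepareLatentStep

# Assumption: the remaining stages (text encoder, image encoder, denoise, decode)
# come from other block files in the same repository; only the two blocks below
# are defined here.
blocks_dict = {
    "set_timesteps": ChronoEditSetTimestepsStep(),
    "prepare_latents": ChronoEditPrepareLatentStep(),
}

blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)

# "<chronoedit-modular-repo>" is a placeholder; point it at the repo holding the
# scheduler and Wan VAE configs these blocks expect.
pipe = blocks.init_pipeline("<chronoedit-modular-repo>")
pipe.load_default_components(torch_dtype=torch.bfloat16)

Assembled this way, `ChronoEditSetTimestepsStep` writes `timesteps` and `num_inference_steps` into the shared `PipelineState`, and `ChronoEditPrepareLatentStep` reads the processed image to produce the `latents` and `condition` tensors consumed by the downstream denoising block.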