smoothieAI committed
Commit 941962f • 1 Parent(s): be3d287
Update pipeline.py

pipeline.py CHANGED (+23 -20)
```diff
@@ -199,8 +199,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
     ):
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+        # temp workaround to prevent ip adapter library from loading ip adapter on empty controlnet parameter
+        controlnet = controlnets
 
-        if
+        if controlnet is None:
             self.register_modules(
                 vae=vae,
                 text_encoder=text_encoder,
```
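Intent of this hunk, as a minimal standalone sketch: the `controlnets` argument is copied to a local name and the module set is registered with or without a `controlnet` entry, so loader mixins that walk the registered components never see an empty controlnet slot. The `TinyPipeline` class and `_register` helper below are invented for illustration; the real pipeline uses `DiffusionPipeline.register_modules`.

```python
# Hypothetical stand-in for DiffusionPipeline.register_modules: the real method
# stores each keyword argument as an attribute and records it in the pipeline config.
class TinyPipeline:
    def __init__(self, unet, controlnets=None):
        # mirror the diff: rename once, then branch on presence of the controlnet
        controlnet = controlnets
        self.components = {}
        if controlnet is None:
            self._register(unet=unet)                       # no empty controlnet entry
        else:
            self._register(unet=unet, controlnet=controlnet)

    def _register(self, **modules):
        for name, module in modules.items():
            setattr(self, name, module)
            self.components[name] = module

pipe_without = TinyPipeline(unet="unet-stub")
pipe_with = TinyPipeline(unet="unet-stub", controlnets="controlnet-stub")
print("controlnet" in pipe_without.components, "controlnet" in pipe_with.components)  # False True
```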
```diff
@@ -218,7 +221,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 tokenizer=tokenizer,
                 unet=unet,
                 motion_adapter=motion_adapter,
-                controlnet=
+                controlnet=controlnet,
                 scheduler=scheduler,
                 feature_extractor=feature_extractor,
                 image_encoder=image_encoder,
```
```diff
@@ -1117,8 +1120,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
 
-        if self.
-
+        if self.controlnet != None:
+            controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
         control_end = control_guidance_end
```
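The `_orig_mod` unwrap added here follows the usual torch.compile pattern: a compiled module hides the original behind `._orig_mod`, and the later `isinstance(controlnet, ...)` checks need the original class. A minimal sketch of that pattern, using a `hasattr` check in place of diffusers' `is_compiled_module` helper (assumes torch >= 2.0):

```python
import torch

def unwrap_if_compiled(module):
    # torch.compile wraps an nn.Module in an OptimizedModule whose original module
    # lives on ._orig_mod, so type checks must run against the unwrapped object
    return module._orig_mod if hasattr(module, "_orig_mod") else module

net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)                # requires torch >= 2.0
print(unwrap_if_compiled(compiled) is net)   # True
print(unwrap_if_compiled(net) is net)        # True
```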
```diff
@@ -1127,7 +1130,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
```
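The `mult` line completes the broadcasting of `control_guidance_start`/`control_guidance_end`: scalar values become one entry per ControlNet (`len(controlnet.nets)` for a `MultiControlNetModel`, otherwise 1). A standalone sketch of that rule; the `broadcast_guidance` helper and its signature are invented here, not part of the pipeline:

```python
def broadcast_guidance(start, end, num_controlnets):
    # scalar + list: repeat the scalar to match the list length
    if not isinstance(start, list) and isinstance(end, list):
        start = len(end) * [start]
    elif not isinstance(end, list) and isinstance(start, list):
        end = len(start) * [end]
    # both scalars: repeat once per ControlNet
    elif not isinstance(start, list) and not isinstance(end, list):
        mult = num_controlnets
        start, end = mult * [start], mult * [end]
    return start, end

print(broadcast_guidance(0.0, 0.8, 2))          # ([0.0, 0.0], [0.8, 0.8])
print(broadcast_guidance([0.0, 0.1], 1.0, 2))   # ([0.0, 0.1], [1.0, 1.0])
```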
```diff
@@ -1155,14 +1158,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
         device = self._execution_device
 
-        if self.
-        if isinstance(
-            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
+        if self.controlnet != None:
+            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
 
             global_pool_conditions = (
-
-            if isinstance(
-            else
+                controlnet.config.global_pool_conditions
+                if isinstance(controlnet, ControlNetModel)
+                else controlnet.nets[0].config.global_pool_conditions
             )
             guess_mode = guess_mode or global_pool_conditions
 
```
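Same broadcasting idea for `controlnet_conditioning_scale`: with a `MultiControlNetModel`, a single float is repeated once per wrapped net, and `global_pool_conditions` is read from the single net's config or from the first wrapped net. A simplified sketch with stand-in classes (these are not the diffusers types, and the `config` indirection is flattened away):

```python
from dataclasses import dataclass, field

@dataclass
class FakeControlNet:
    global_pool_conditions: bool = False

@dataclass
class FakeMultiControlNet:
    nets: list = field(default_factory=list)

def normalize(controlnet, scale):
    # float scale -> one scale per net (mirrors the added lines in this hunk)
    if isinstance(controlnet, FakeMultiControlNet) and isinstance(scale, float):
        scale = [scale] * len(controlnet.nets)
    # single net reads its own flag; the multi wrapper defers to its first net
    pooled = (
        controlnet.global_pool_conditions
        if isinstance(controlnet, FakeControlNet)
        else controlnet.nets[0].global_pool_conditions
    )
    return scale, pooled

multi = FakeMultiControlNet(nets=[FakeControlNet(), FakeControlNet(True)])
print(normalize(multi, 0.7))                 # ([0.7, 0.7], False)
print(normalize(FakeControlNet(True), 0.7))  # (0.7, True)
```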
```diff
@@ -1201,8 +1204,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         if do_classifier_free_guidance:
             image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
-        if self.
-        if isinstance(
+        if self.controlnet != None:
+            if isinstance(controlnet, ControlNetModel):
             # conditioning_frames = self.prepare_image(
             #     image=conditioning_frames,
             #     width=width,
```
```diff
@@ -1221,12 +1224,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 batch_size=batch_size * num_videos_per_prompt * num_frames,
                 num_images_per_prompt=num_videos_per_prompt,
                 device=device,
-                dtype=
+                dtype=controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
 
-        elif isinstance(
+            elif isinstance(controlnet, MultiControlNetModel):
             cond_prepared_frames = []
             for frame_ in conditioning_frames:
                 # prepared_frame = self.prepare_image(
```
```diff
@@ -1248,7 +1251,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     batch_size=batch_size * num_videos_per_prompt * num_frames,
                     num_images_per_prompt=num_videos_per_prompt,
                     device=device,
-                    dtype=
+                    dtype=controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
```
```diff
@@ -1367,14 +1370,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
         # 7.1 Create tensor stating which controlnets to keep
-        if self.
+        if self.controlnet != None:
             controlnet_keep = []
             for i in range(len(timesteps)):
                 keeps = [
                     1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                     for s, e in zip(control_guidance_start, control_guidance_end)
                 ]
-            controlnet_keep.append(keeps[0] if isinstance(
+                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
 
         # divide the initial latents into context groups
 
```
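The `controlnet_keep` list built here gives each denoising step a per-ControlNet weight: 1.0 while the step's relative position falls inside the `[start, end]` guidance window, 0.0 otherwise. A pure-Python rewrite with invented names:

```python
def build_keep_schedule(num_steps, starts, ends, single_controlnet=True):
    keep = []
    for i in range(num_steps):
        # 1.0 if step i lies inside every window boundary check, else 0.0
        keeps = [
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(starts, ends)
        ]
        keep.append(keeps[0] if single_controlnet else keeps)
    return keep

# one ControlNet active for the first 60% of 10 steps
print(build_keep_schedule(10, [0.0], [0.6]))
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```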
```diff
@@ -1431,7 +1434,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                if self.
+                if self.controlnet != None and i < int(control_end*num_inference_steps):
 
                     torch.cuda.synchronize() # Synchronize GPU
                     control_start = time.time()
```
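Besides the `None` guard, the new condition also cuts the ControlNet branch off entirely once the loop index reaches `control_end * num_inference_steps`. A quick sketch (the helper name and signature are hypothetical) of which steps still invoke it:

```python
def steps_running_controlnet(num_inference_steps, control_end, controlnet_present=True):
    # a step runs the ControlNet only while both parts of the added guard hold
    return [
        i for i in range(num_inference_steps)
        if controlnet_present and i < int(control_end * num_inference_steps)
    ]

print(steps_running_controlnet(20, 0.5))                             # steps 0..9 call the ControlNet
print(steps_running_controlnet(20, 0.5, controlnet_present=False))   # []
```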
```diff
@@ -1465,7 +1468,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     )
 
 
-                    down_block_res_samples, mid_block_res_sample = self.
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
```
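For context on the guarded call: a diffusers `ControlNetModel` forward pass yields per-down-block residuals plus one mid-block residual, which the UNet later adds to its own features; residuals scaled to zero (for example, a 0.0 entry in `controlnet_keep` at this step) are equivalent to running without control. The tensors below are stand-ins with illustrative shapes; no model is loaded:

```python
import torch

# invented shapes for a 2-sample batch; real residual counts/shapes depend on the UNet
down_block_res_samples = [torch.randn(2, 320, 64, 64), torch.randn(2, 320, 32, 32)]
mid_block_res_sample = torch.randn(2, 1280, 8, 8)

cond_scale = 0.0  # e.g. this step's controlnet_keep entry is 0.0
down_block_res_samples = [sample * cond_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * cond_scale
print(mid_block_res_sample.abs().sum().item())  # 0.0 -> no control signal reaches the UNet
```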