smoothieAI committed
Commit 941962f
1 Parent(s): be3d287

Update pipeline.py

Files changed (1):
  1. pipeline.py +23 -20
pipeline.py CHANGED
@@ -199,8 +199,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
     ):
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+        # temp workaround to prevent ip adapter library from loading ip adapter on empty controlnet parameter
+        controlnet = controlnets
 
-        if controlnets is None:
+        if controlnet is None:
             self.register_modules(
                 vae=vae,
                 text_encoder=text_encoder,
@@ -218,7 +221,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 tokenizer=tokenizer,
                 unet=unet,
                 motion_adapter=motion_adapter,
-                controlnet=controlnets,
+                controlnet=controlnet,
                 scheduler=scheduler,
                 feature_extractor=feature_extractor,
                 image_encoder=image_encoder,
@@ -1117,8 +1120,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
 
-        if self.controlnets != None:
-            controlnets = self.controlnets._orig_mod if is_compiled_module(self.controlnets) else self.controlnets
+        if self.controlnet != None:
+            controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
         control_end = control_guidance_end
@@ -1127,7 +1130,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(controlnets.nets) if isinstance(controlnets, MultiControlNetModel) else 1
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
@@ -1155,14 +1158,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
         device = self._execution_device
 
-        if self.controlnets != None:
-            if isinstance(controlnets, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
-                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnets.nets)
+        if self.controlnet != None:
+            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
 
             global_pool_conditions = (
-                controlnets.config.global_pool_conditions
-                if isinstance(controlnets, ControlNetModel)
-                else controlnets.nets[0].config.global_pool_conditions
+                controlnet.config.global_pool_conditions
+                if isinstance(controlnet, ControlNetModel)
+                else controlnet.nets[0].config.global_pool_conditions
             )
             guess_mode = guess_mode or global_pool_conditions
 
@@ -1201,8 +1204,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         if do_classifier_free_guidance:
             image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
-        if self.controlnets != None:
-            if isinstance(controlnets, ControlNetModel):
+        if self.controlnet != None:
+            if isinstance(controlnet, ControlNetModel):
                 # conditioning_frames = self.prepare_image(
                 #     image=conditioning_frames,
                 #     width=width,
@@ -1221,12 +1224,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     batch_size=batch_size * num_videos_per_prompt * num_frames,
                     num_images_per_prompt=num_videos_per_prompt,
                     device=device,
-                    dtype=controlnets.dtype,
+                    dtype=controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
 
-            elif isinstance(controlnets, MultiControlNetModel):
+            elif isinstance(controlnet, MultiControlNetModel):
                 cond_prepared_frames = []
                 for frame_ in conditioning_frames:
                     # prepared_frame = self.prepare_image(
@@ -1248,7 +1251,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                         batch_size=batch_size * num_videos_per_prompt * num_frames,
                         num_images_per_prompt=num_videos_per_prompt,
                         device=device,
-                        dtype=controlnets.dtype,
+                        dtype=controlnet.dtype,
                         do_classifier_free_guidance=do_classifier_free_guidance,
                         guess_mode=guess_mode,
                     )
@@ -1367,14 +1370,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
         # 7.1 Create tensor stating which controlnets to keep
-        if self.controlnets != None:
+        if self.controlnet != None:
             controlnet_keep = []
             for i in range(len(timesteps)):
                 keeps = [
                     1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                     for s, e in zip(control_guidance_start, control_guidance_end)
                 ]
-                controlnet_keep.append(keeps[0] if isinstance(controlnets, ControlNetModel) else keeps)
+                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
 
         # divide the initial latents into context groups
 
@@ -1431,7 +1434,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                if self.controlnets != None and i < int(control_end*num_inference_steps):
+                if self.controlnet != None and i < int(control_end*num_inference_steps):
 
                     torch.cuda.synchronize() # Synchronize GPU
                     control_start = time.time()
@@ -1465,7 +1468,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     )
 
 
-                    down_block_res_samples, mid_block_res_sample = self.controlnets(
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
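
Why the rename matters: the added comment calls this a "temp workaround to prevent ip adapter library from loading ip adapter on empty controlnet parameter", and the later hunks all switch from self.controlnets to the singular self.controlnet. The toy sketch below reconstructs the constructor pattern the commit converges on; MiniPipeline and its stub register_modules are illustrative assumptions, not the real diffusers classes. (The hunks above are reproduced exactly as committed, so they keep the non-idiomatic `!= None` comparison.)

class MiniPipeline:
    """Toy stand-in for the pipeline's registration logic (a sketch, not the
    real AnimateDiffPipeline); register_modules simply sets attributes, which
    approximates what module registration exposes to the rest of the class."""

    def register_modules(self, **modules):
        for name, module in modules.items():
            setattr(self, name, module)

    def __init__(self, unet, controlnets=None):
        # Alias the plural argument to the singular name so every later access
        # site (including IP-adapter loading) finds `self.controlnet`.
        controlnet = controlnets
        if controlnet is None:
            # Leave the controlnet kwarg out of registration so nothing
            # downstream tries to load against an empty controlnet parameter.
            self.register_modules(unet=unet)
            self.controlnet = None
        else:
            self.register_modules(unet=unet, controlnet=controlnet)

pipe = MiniPipeline(unet="unet-stub")
assert pipe.controlnet is None  # controlnet-only branches are skipped

pipe = MiniPipeline(unet="unet-stub", controlnets="controlnet-stub")
assert pipe.controlnet == "controlnet-stub"  # registered under the singular name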
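The call-path hunks all hang on the same guidance window: controlnet_keep (section 7.1) decides per step whether the ControlNet residuals are applied, and the denoising loop additionally stops invoking the ControlNet once i >= int(control_end * num_inference_steps). A minimal, self-contained sketch of that keep-mask, with num_steps standing in for len(timesteps) (the function name and single_controlnet flag are illustrative, not from the file):

def controlnet_keep_schedule(num_steps, starts, ends, single_controlnet=True):
    # Mirrors the loop in section 7.1: keep weight is 1.0 while step i lies
    # inside the [start, end) guidance window, 0.0 outside it.
    keep = []
    for i in range(num_steps):
        keeps = [
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(starts, ends)
        ]
        keep.append(keeps[0] if single_controlnet else keeps)
    return keep

# With control_guidance_start=0.0 and control_guidance_end=0.5, only the first
# half of the steps keeps the ControlNet residuals:
print(controlnet_keep_schedule(10, [0.0], [0.5]))
# [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]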