kiigii committed on
Commit
84c55d2
1 Parent(s): b460452

Rename pipeline_imagedream.py to pipeline_mvdiffusion.py

Browse files
pipeline_imagedream.py → pipeline_mvdiffusion.py RENAMED
@@ -41,7 +41,7 @@ from transformers import (
41
  )
42
 
43
 
44
- class ImageDreamPipeline(StableDiffusionPipeline):
45
  def __init__(
46
  self,
47
  vae: AutoencoderKL,
@@ -50,15 +50,15 @@ class ImageDreamPipeline(StableDiffusionPipeline):
50
  unet: UNet2DConditionModel,
51
  scheduler: KarrasDiffusionSchedulers,
52
  safety_checker: StableDiffusionSafetyChecker,
53
- feature_extractor: CLIPImageProcessor,
54
- image_encoder: CLIPVisionModel = None,
55
  requires_safety_checker: bool = False,
56
  ) -> None:
57
  super().__init__(
58
  vae=vae,
59
  text_encoder=text_encoder,
60
  tokenizer=tokenizer,
61
- unet=add_imagedream_attn_processor(unet),
62
  scheduler=scheduler,
63
  safety_checker=safety_checker,
64
  feature_extractor=feature_extractor,
@@ -88,7 +88,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
88
 
89
  if weight_name == "ip-adapter-plus_imagedream.bin":
90
  setattr(self.image_encoder, "visual_projection", nn.Identity())
91
- add_imagedream_attn_processor(self.unet)
92
  set_num_views(self.unet, self.num_views + 1)
93
 
94
  def unload_ip_adapter(self) -> None:
@@ -193,7 +193,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
193
  if cross_attention_kwargs is None:
194
  num_views = self.num_views
195
  else:
196
- num_views = cross_attention_kwargs.pop("num_views", self.num_views)
197
 
198
  # 0. Default height and width to unet
199
  height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -506,11 +506,11 @@ def get_camera(
506
  # fmt: on
507
 
508
 
509
- def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DConditionModel:
510
  attn_procs = {}
511
  for key, attn_processor in unet.attn_processors.items():
512
  if "attn1" in key:
513
- attn_procs[key] = ImageDreamAttnProcessor2_0()
514
  else:
515
  attn_procs[key] = attn_processor
516
  unet.set_attn_processor(attn_procs)
@@ -519,12 +519,12 @@ def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DCondition
519
 
520
  def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
521
  for key, attn_processor in unet.attn_processors.items():
522
- if isinstance(attn_processor, ImageDreamAttnProcessor2_0):
523
  attn_processor.num_views = num_views
524
  return unet
525
 
526
 
527
- class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
528
  def __init__(self, num_views: int = 4):
529
  super().__init__()
530
  self.num_views = num_views
 
41
  )
42
 
43
 
44
+ class MVDiffusionPipeline(StableDiffusionPipeline):
45
  def __init__(
46
  self,
47
  vae: AutoencoderKL,
 
50
  unet: UNet2DConditionModel,
51
  scheduler: KarrasDiffusionSchedulers,
52
  safety_checker: StableDiffusionSafetyChecker,
53
+ feature_extractor: Optional[CLIPImageProcessor] = None,
54
+ image_encoder: Optional[CLIPVisionModel] = None,
55
  requires_safety_checker: bool = False,
56
  ) -> None:
57
  super().__init__(
58
  vae=vae,
59
  text_encoder=text_encoder,
60
  tokenizer=tokenizer,
61
+ unet=add_mv_attn_processor(unet),
62
  scheduler=scheduler,
63
  safety_checker=safety_checker,
64
  feature_extractor=feature_extractor,
 
88
 
89
  if weight_name == "ip-adapter-plus_imagedream.bin":
90
  setattr(self.image_encoder, "visual_projection", nn.Identity())
91
+ add_mv_attn_processor(self.unet)
92
  set_num_views(self.unet, self.num_views + 1)
93
 
94
  def unload_ip_adapter(self) -> None:
 
193
  if cross_attention_kwargs is None:
194
  num_views = self.num_views
195
  else:
196
+ cross_attention_kwargs.pop("num_views", self.num_views)
197
 
198
  # 0. Default height and width to unet
199
  height = height or self.unet.config.sample_size * self.vae_scale_factor
 
506
  # fmt: on
507
 
508
 
509
+ def add_mv_attn_processor(unet: UNet2DConditionModel, num_views: int = 4) -> UNet2DConditionModel:
510
  attn_procs = {}
511
  for key, attn_processor in unet.attn_processors.items():
512
  if "attn1" in key:
513
+ attn_procs[key] = MVAttnProcessor2_0(num_views)
514
  else:
515
  attn_procs[key] = attn_processor
516
  unet.set_attn_processor(attn_procs)
 
519
 
520
def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
    """Propagate a new view count to every multi-view attention processor.

    Only processors of type ``MVAttnProcessor2_0`` (as installed by
    ``add_mv_attn_processor``) are updated; any other attention processors
    on the UNet are left untouched. The UNet is mutated in place and
    returned to allow call chaining.
    """
    for processor in unet.attn_processors.values():
        if isinstance(processor, MVAttnProcessor2_0):
            processor.num_views = num_views
    return unet
525
 
526
 
527
+ class MVAttnProcessor2_0(AttnProcessor2_0):
528
  def __init__(self, num_views: int = 4):
529
  super().__init__()
530
  self.num_views = num_views