Rename pipeline_imagedream.py to pipeline_mvdiffusion.py
Browse files
pipeline_imagedream.py → pipeline_mvdiffusion.py
RENAMED
@@ -41,7 +41,7 @@ from transformers import (
|
|
41 |
)
|
42 |
|
43 |
|
44 |
-
class
|
45 |
def __init__(
|
46 |
self,
|
47 |
vae: AutoencoderKL,
|
@@ -50,15 +50,15 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
50 |
unet: UNet2DConditionModel,
|
51 |
scheduler: KarrasDiffusionSchedulers,
|
52 |
safety_checker: StableDiffusionSafetyChecker,
|
53 |
-
feature_extractor: CLIPImageProcessor,
|
54 |
-
image_encoder: CLIPVisionModel = None,
|
55 |
requires_safety_checker: bool = False,
|
56 |
) -> None:
|
57 |
super().__init__(
|
58 |
vae=vae,
|
59 |
text_encoder=text_encoder,
|
60 |
tokenizer=tokenizer,
|
61 |
-
unet=
|
62 |
scheduler=scheduler,
|
63 |
safety_checker=safety_checker,
|
64 |
feature_extractor=feature_extractor,
|
@@ -88,7 +88,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
88 |
|
89 |
if weight_name == "ip-adapter-plus_imagedream.bin":
|
90 |
setattr(self.image_encoder, "visual_projection", nn.Identity())
|
91 |
-
|
92 |
set_num_views(self.unet, self.num_views + 1)
|
93 |
|
94 |
def unload_ip_adapter(self) -> None:
|
@@ -193,7 +193,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
193 |
if cross_attention_kwargs is None:
|
194 |
num_views = self.num_views
|
195 |
else:
|
196 |
-
|
197 |
|
198 |
# 0. Default height and width to unet
|
199 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
@@ -506,11 +506,11 @@ def get_camera(
|
|
506 |
# fmt: on
|
507 |
|
508 |
|
509 |
-
def
|
510 |
attn_procs = {}
|
511 |
for key, attn_processor in unet.attn_processors.items():
|
512 |
if "attn1" in key:
|
513 |
-
attn_procs[key] =
|
514 |
else:
|
515 |
attn_procs[key] = attn_processor
|
516 |
unet.set_attn_processor(attn_procs)
|
@@ -519,12 +519,12 @@ def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DCondition
|
|
519 |
|
520 |
def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
|
521 |
for key, attn_processor in unet.attn_processors.items():
|
522 |
-
if isinstance(attn_processor,
|
523 |
attn_processor.num_views = num_views
|
524 |
return unet
|
525 |
|
526 |
|
527 |
-
class
|
528 |
def __init__(self, num_views: int = 4):
|
529 |
super().__init__()
|
530 |
self.num_views = num_views
|
|
|
41 |
)
|
42 |
|
43 |
|
44 |
+
class MVDiffusionPipeline(StableDiffusionPipeline):
|
45 |
def __init__(
|
46 |
self,
|
47 |
vae: AutoencoderKL,
|
|
|
50 |
unet: UNet2DConditionModel,
|
51 |
scheduler: KarrasDiffusionSchedulers,
|
52 |
safety_checker: StableDiffusionSafetyChecker,
|
53 |
+
feature_extractor: Optional[CLIPImageProcessor] = None,
|
54 |
+
image_encoder: Optional[CLIPVisionModel] = None,
|
55 |
requires_safety_checker: bool = False,
|
56 |
) -> None:
|
57 |
super().__init__(
|
58 |
vae=vae,
|
59 |
text_encoder=text_encoder,
|
60 |
tokenizer=tokenizer,
|
61 |
+
unet=add_mv_attn_processor(unet),
|
62 |
scheduler=scheduler,
|
63 |
safety_checker=safety_checker,
|
64 |
feature_extractor=feature_extractor,
|
|
|
88 |
|
89 |
if weight_name == "ip-adapter-plus_imagedream.bin":
|
90 |
setattr(self.image_encoder, "visual_projection", nn.Identity())
|
91 |
+
add_mv_attn_processor(self.unet)
|
92 |
set_num_views(self.unet, self.num_views + 1)
|
93 |
|
94 |
def unload_ip_adapter(self) -> None:
|
|
|
193 |
if cross_attention_kwargs is None:
|
194 |
num_views = self.num_views
|
195 |
else:
|
196 |
+
cross_attention_kwargs.pop("num_views", self.num_views)
|
197 |
|
198 |
# 0. Default height and width to unet
|
199 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
|
|
506 |
# fmt: on
|
507 |
|
508 |
|
509 |
+
def add_mv_attn_processor(unet: UNet2DConditionModel, num_views: int = 4) -> UNet2DConditionModel:
|
510 |
attn_procs = {}
|
511 |
for key, attn_processor in unet.attn_processors.items():
|
512 |
if "attn1" in key:
|
513 |
+
attn_procs[key] = MVAttnProcessor2_0(num_views)
|
514 |
else:
|
515 |
attn_procs[key] = attn_processor
|
516 |
unet.set_attn_processor(attn_procs)
|
|
|
519 |
|
520 |
def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
    """Set the view count on every multi-view attention processor of ``unet``.

    Walks the UNet's attention processors and updates ``num_views`` on each
    processor that is an ``MVAttnProcessor2_0``; all other processors are left
    untouched.

    Args:
        unet: the UNet whose attention processors are updated in place.
        num_views: new number of views to attend over jointly.

    Returns:
        The same (mutated) ``unet``, for call chaining.
    """
    # Keys are unused here — iterate values() instead of items().
    for attn_processor in unet.attn_processors.values():
        if isinstance(attn_processor, MVAttnProcessor2_0):
            attn_processor.num_views = num_views
    return unet
|
525 |
|
526 |
|
527 |
+
class MVAttnProcessor2_0(AttnProcessor2_0):
|
528 |
def __init__(self, num_views: int = 4):
    """Attention processor that attends jointly across multiple views.

    Args:
        num_views: number of views processed together per sample
            (default 4). Mutable after construction via ``set_num_views``.
    """
    # Initialize the base AttnProcessor2_0 machinery first, then record
    # the view count this processor operates over.
    super().__init__()
    self.num_views = num_views
|