import functools
from typing import Dict

import PIL.Image
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor

from cross_frame_attention import CrossFrameAttnProcessor
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
NEGATIVE_PROMPT = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"


def attach_loaders_mixin(model):
    r"""
    Attach the [`UNet2DConditionLoadersMixin`] methods to a model by binding all
    callable attributes of the mixin to the model instance.

    This is a hacky way to make ControlNet work with LoRA; it may not be
    required in future versions of diffusers.
    """
    model.text_encoder_name = TEXT_ENCODER_NAME
    model.unet_name = UNET_NAME
    for attr_name, attr_value in vars(UNet2DConditionLoadersMixin).items():
        if callable(attr_value):
            # Bind each mixin method to this instance so it can be called as a bound method.
            setattr(model, attr_name, functools.partial(attr_value, model))
    return model
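

# Minimal usage sketch (assumption: this works for any `ControlNetModel`, since
# the mixin methods are bound generically). After attaching, loader methods such
# as `load_attn_procs` become available on the instance:
#
#     controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth")
#     controlnet = attach_loaders_mixin(controlnet)
#     controlnet.load_attn_procs("shariqfarooq/loose-control-3dbox")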


def set_attn_processor(module, processor, _remove_lora=False):
    """Recursively set `processor` (a single processor, or a dict keyed by layer name) on `module` and all of its children."""

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor, _remove_lora=_remove_lora)
            else:
                module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in module.named_children():
        fn_recursive_attn_processor(name, module, processor)
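

# Minimal usage sketch (assumption: `unet` is a loaded `UNet2DConditionModel`);
# this mirrors how `LooseControlNet.set_cf_attention` below swaps processors on
# a sub-module rather than on the whole UNet:
#
#     set_attn_processor(unet.up_blocks[-1], AttnProcessor())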


class ControlNetX(ControlNetModel, UNet2DConditionLoadersMixin):
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    # This may not be required in future versions of diffusers.
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
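

# Usage sketch: `ControlNetX.attn_processors` exposes the processor of every
# attention layer, keyed by weight name, e.g.
#
#     controlnet = ControlNetX.from_pretrained("lllyasviel/control_v11f1p_sd15_depth")
#     print(list(controlnet.attn_processors.keys())[:3])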


class ControlNetPipeline:
    """Thin wrapper around `StableDiffusionControlNetPipeline` that loads a depth ControlNet."""

    def __init__(self, checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
        controlnet = ControlNetX.from_pretrained(checkpoint)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_checkpoint, controlnet=controlnet, requires_safety_checker=False, safety_checker=None,
            torch_dtype=torch.float16)
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

    @torch.no_grad()
    def __call__(self,
                 prompt: str = "",
                 height=512,
                 width=512,
                 control_image=None,
                 controlnet_conditioning_scale=1.0,
                 num_inference_steps: int = 20,
                 **kwargs) -> PIL.Image.Image:
        out = self.pipe(prompt, control_image,
                        height=height, width=width,
                        num_inference_steps=num_inference_steps,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        **kwargs).images
        return out[0] if len(out) == 1 else out

    def to(self, *args, **kwargs):
        self.pipe.to(*args, **kwargs)
        return self
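

# Minimal usage sketch (assumptions: a CUDA device is available and `depth_map`
# is a depth image in the format the depth ControlNet expects):
#
#     pipeline = ControlNetPipeline().to("cuda")
#     image = pipeline("a cozy living room", control_image=depth_map)
#     image.save("living_room.png")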


class LooseControlNet(ControlNetPipeline):
    def __init__(self, loose_control_weights="shariqfarooq/loose-control-3dbox", cn_checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
        super().__init__(cn_checkpoint, sd_checkpoint)
        # Attach the loader mixin so the LooseControl LoRA weights can be loaded
        # into the ControlNet's attention processors.
        self.pipe.controlnet = attach_loaders_mixin(self.pipe.controlnet)
        self.pipe.controlnet.load_attn_procs(loose_control_weights)

    def set_normal_attention(self):
        self.pipe.unet.set_attn_processor(AttnProcessor())

    def set_cf_attention(self, _remove_lora=False):
        # Use cross-frame attention in the last two up-blocks so the images in a
        # batch share attention context and stay visually consistent.
        for up_block in self.pipe.unet.up_blocks[-2:]:
            set_attn_processor(up_block, CrossFrameAttnProcessor(), _remove_lora=_remove_lora)

    def edit(self, depth, depth_edit, prompt, prompt_edit=None, seed=42, seed_edit=None, negative_prompt=NEGATIVE_PROMPT, controlnet_conditioning_scale=1.0, num_inference_steps=20, **kwargs):
        # Generate the reference and the edited image as a batch of two; the
        # edit falls back to the reference prompt and seed when not given.
        if prompt_edit is None:
            prompt_edit = prompt
        if seed_edit is None:
            seed_edit = seed
        seed = int(seed)
        seed_edit = int(seed_edit)
        control_image = [depth, depth_edit]
        prompt = [prompt, prompt_edit]
        generator = [torch.Generator().manual_seed(seed), torch.Generator().manual_seed(seed_edit)]
        images = self.pipe(prompt, image=control_image, controlnet_conditioning_scale=controlnet_conditioning_scale,
                           generator=generator, num_inference_steps=num_inference_steps,
                           negative_prompt=negative_prompt, **kwargs).images
        return images
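

# Minimal end-to-end sketch (assumptions: a CUDA device is available, and
# "box_depth.png" / "box_depth_edit.png" are hypothetical depth renderings of a
# scene before and after moving an object):
if __name__ == "__main__":
    lcn = LooseControlNet().to("cuda")
    lcn.set_cf_attention()  # keep the edited image consistent with the reference

    depth = PIL.Image.open("box_depth.png")
    depth_edit = PIL.Image.open("box_depth_edit.png")
    reference, edited = lcn.edit(depth, depth_edit, prompt="a cozy living room", seed=42)
    reference.save("reference.png")
    edited.save("edited.png")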