# Spaces:
# Paused
# Paused
import copy | |
import os | |
import torch | |
from pathlib import Path | |
from modules import devices | |
from scripts.adapter import PlugableAdapter, Adapter, StyleAdapter, Adapter_light | |
from scripts.controlnet_lllite import PlugableControlLLLite | |
from scripts.cldm import PlugableControlModel | |
from scripts.controlmodel_ipadapter import PlugableIPAdapter | |
from scripts.logging import logger | |
from scripts.controlnet_diffusers import convert_from_diffuser_state_dict | |
from scripts.controlnet_lora import controlnet_lora_hijack, force_load_state_dict | |
# Base ControlNet configuration used by build_model_by_guess when the
# checkpoint has 'input_hint_block.0.weight' but no 'label_emb.0.0.bias'.
# The builder works on a deep copy and overrides 'global_average_pooling',
# 'hint_channels', 'context_dim' and 'use_fp16' from the checkpoint/runtime.
controlnet_default_config = {
    'adm_in_channels': None,
    'in_channels': 4,
    'model_channels': 320,
    'num_res_blocks': 2,
    'attention_resolutions': [1, 2, 4],
    'transformer_depth': [1, 1, 1, 0],
    'channel_mult': [1, 2, 4, 4],
    'transformer_depth_middle': 1,
    'use_linear_in_transformer': False,
    'context_dim': 768,
    "num_heads": 8,
    "global_average_pooling": False,
}
# Full-size SDXL ControlNet configuration. build_model_by_guess selects it
# when the checkpoint contains
# 'input_blocks.8.1.transformer_blocks.9.norm3.weight', and also uses it as
# the base config for 'lora_controlnet' checkpoints. The builder deep-copies
# it and sets 'global_average_pooling', 'hint_channels' and 'use_fp16'.
controlnet_sdxl_config = {
    'num_classes': 'sequential',
    'adm_in_channels': 2816,
    'in_channels': 4,
    'model_channels': 320,
    'num_res_blocks': 2,
    'attention_resolutions': [2, 4],
    'transformer_depth': [0, 2, 10],
    'channel_mult': [1, 2, 4],
    'transformer_depth_middle': 10,
    'use_linear_in_transformer': True,
    'context_dim': 2048,
    "num_head_channels": 64,
    "global_average_pooling": False,
}
# Mid-size SDXL ControlNet configuration. build_model_by_guess selects it
# when the checkpoint has
# 'input_blocks.8.1.transformer_blocks.0.norm3.weight' but not the
# '...transformer_blocks.9.norm3.weight' key of the full-size model.
controlnet_sdxl_mid_config = {
    'num_classes': 'sequential',
    'adm_in_channels': 2816,
    'in_channels': 4,
    'model_channels': 320,
    'num_res_blocks': 2,
    'attention_resolutions': [4],
    'transformer_depth': [0, 0, 1],
    'channel_mult': [1, 2, 4],
    'transformer_depth_middle': 1,
    'use_linear_in_transformer': True,
    'context_dim': 2048,
    "num_head_channels": 64,
    "global_average_pooling": False,
}
# Small SDXL ControlNet configuration. build_model_by_guess falls back to it
# for checkpoints with 'label_emb.0.0.bias' that have neither of the
# 'input_blocks.8.1.transformer_blocks.{0,9}.norm3.weight' keys.
# All transformer depths are 0 and no attention resolutions are listed.
controlnet_sdxl_small_config = {
    'num_classes': 'sequential',
    'adm_in_channels': 2816,
    'in_channels': 4,
    'model_channels': 320,
    'num_res_blocks': 2,
    'attention_resolutions': [],
    'transformer_depth': [0, 0, 0],
    'channel_mult': [1, 2, 4],
    'transformer_depth_middle': 0,
    'use_linear_in_transformer': True,
    "num_head_channels": 64,
    # NOTE(review): presumably a placeholder since no transformer blocks are
    # configured here — confirm against PlugableControlModel.
    'context_dim': 1,
    "global_average_pooling": False,
}
# Default keyword arguments for a T2I `Adapter`. Not referenced by
# build_model_by_guess in the visible part of this file, which instead
# derives the Adapter arguments from the checkpoint tensors.
t2i_adapter_config = {
    'channels': [320, 640, 1280, 1280],
    'nums_rb': 2,
    'ksize': 1,
    'sk': True,
    'cin': 192,
    'use_conv': False,
}
# Keyword arguments for `Adapter_light`. build_model_by_guess deep-copies
# this dict and replaces 'cin' with the input channel count read from
# 'body.0.in_conv.weight' before constructing the adapter.
t2i_adapter_light_config = {
    'channels': [320, 640, 1280, 1280],
    'nums_rb': 4,
    'cin': 192,
}
# Keyword arguments for `StyleAdapter`; build_model_by_guess passes a deep
# copy unchanged as `StyleAdapter(**config)`.
t2i_adapter_style_config = {
    'width': 1024,
    'context_dim': 768,
    'num_head': 8,
    # NOTE(review): the spelling 'n_layes' is forwarded as a keyword argument,
    # so it presumably matches StyleAdapter's parameter name — do not rename
    # without checking that constructor.
    'n_layes': 3,
    'num_token': 8,
}
# Adapted from https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/utils.py
def state_dict_key_replace(state_dict, keys_to_replace):
    """Rename exact keys of ``state_dict`` in place.

    Args:
        state_dict: Mutable mapping to modify.
        keys_to_replace: Mapping of ``old_key -> new_key``. Old keys that are
            not present in ``state_dict`` are ignored.

    Returns:
        The same ``state_dict`` object, after renaming.
    """
    for old_key, new_key in keys_to_replace.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
# Adapted from https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/utils.py
def state_dict_prefix_replace(state_dict, replace_prefix):
    """Rewrite key prefixes of ``state_dict`` in place.

    Prefixes are processed in the iteration order of ``replace_prefix``, so a
    more specific prefix must be listed before a shorter prefix that would
    also match the same keys.

    Args:
        state_dict: Mutable mapping to modify.
        replace_prefix: Mapping of ``old_prefix -> new_prefix``.

    Returns:
        The same ``state_dict`` object, after renaming.
    """
    for old_prefix, new_prefix in replace_prefix.items():
        # Snapshot the matching keys first: the dict is mutated while renaming.
        matching = [key for key in state_dict if key.startswith(old_prefix)]
        for key in matching:
            renamed = "{}{}".format(new_prefix, key[len(old_prefix):])
            state_dict[renamed] = state_dict.pop(key)
    return state_dict
def build_model_by_guess(state_dict, unet, model_path):
    """Build the matching control network by inspecting checkpoint keys.

    The checks run in a fixed order and the first match wins:

    1. ``'lora_controlnet'`` key -> control-LoRA ``PlugableControlModel``.
    2. ``'input_hint_block.0.weight'`` -> ``PlugableControlModel`` using the
       default config or one of the SDXL configs.
    3. ``'conv_in.weight'`` -> ``Adapter`` wrapped in ``PlugableAdapter``.
    4. ``'style_embedding'`` -> ``StyleAdapter`` wrapped in ``PlugableAdapter``.
    5. ``'body.0.in_conv.weight'`` -> ``Adapter_light`` wrapped in
       ``PlugableAdapter``.
    6. ``'ip_adapter'`` -> ``PlugableIPAdapter``.
    7. any key containing ``'lllite'`` -> ``PlugableControlLLLite``.

    Diffusers-format key names are converted before steps 2-7, and the
    ``'control_model.'`` / ``'adapter.'`` substrings are stripped from keys.

    Args:
        state_dict: Checkpoint mapping of key -> tensor. It may be mutated in
            place (the ``'lora_controlnet'`` marker is deleted and diffusers
            adapter prefixes are renamed).
        unet: Object exposing ``state_dict()``, or ``None``. Only consulted
            when the checkpoint contains a ``'difference'`` key, in which case
            matching UNet weights are added to the checkpoint weights.
        model_path: Path of the checkpoint file. Only its file-name stem is
            inspected: a stem containing ``'shuffle'`` (case-insensitive)
            enables ``global_average_pooling`` for the default config.

    Returns:
        The constructed network wrapper for the detected model type.

    Raises:
        ValueError: If no known model layout matches ``state_dict``.
        KeyError: If a recognised layout lacks a tensor the builder reads.
    """
    if "lora_controlnet" in state_dict:
        # Marker key only; it is removed before the weights are loaded.
        del state_dict['lora_controlnet']
        config = copy.deepcopy(controlnet_sdxl_config)
        logger.info('controlnet_sdxl_config (using lora)')
        config['global_average_pooling'] = False
        config['hint_channels'] = int(state_dict['input_hint_block.0.weight'].shape[1])
        config['use_fp16'] = devices.dtype_unet == torch.float16
        # NOTE(review): SOURCE lost its indentation; only model construction is
        # placed inside the hijack context here — confirm against upstream.
        with controlnet_lora_hijack():
            network = PlugableControlModel(config, state_dict=None)
        force_load_state_dict(network.control_model, state_dict)
        network.is_control_lora = True
        network.to(devices.dtype_unet)
        return network

    if "controlnet_cond_embedding.conv_in.weight" in state_dict:  # diffusers
        state_dict = convert_from_diffuser_state_dict(state_dict)

    if 'adapter.body.0.resnets.0.block1.weight' in state_dict:  # diffusers
        # Insertion order matters: the most specific prefixes are registered
        # first so the generic "adapter." rule only catches what is left.
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        state_dict = state_dict_prefix_replace(state_dict, prefix_replace)

    model_has_shuffle_in_filename = 'shuffle' in Path(os.path.abspath(model_path)).stem.lower()
    state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("adapter.", ""): v for k, v in state_dict.items()}

    if 'input_hint_block.0.weight' in state_dict:
        if 'label_emb.0.0.bias' not in state_dict:
            config = copy.deepcopy(controlnet_default_config)
            logger.info('controlnet_default_config')
            config['global_average_pooling'] = model_has_shuffle_in_filename
            config['hint_channels'] = int(state_dict['input_hint_block.0.weight'].shape[1])
            config['context_dim'] = int(state_dict['input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight'].shape[1])
            # 2-D proj_in/proj_out weights get two trailing singleton axes.
            # Only existing keys are reassigned, so iterating is safe.
            for key, p in state_dict.items():
                if ('proj_in.weight' in key or 'proj_out.weight' in key) and len(p.shape) == 2:
                    state_dict[key] = p[..., None, None]
        else:
            has_full_layers = 'input_blocks.8.1.transformer_blocks.9.norm3.weight' in state_dict
            has_mid_layers = 'input_blocks.8.1.transformer_blocks.0.norm3.weight' in state_dict
            if has_full_layers:
                config = copy.deepcopy(controlnet_sdxl_config)
                logger.info('controlnet_sdxl_config')
            elif has_mid_layers:
                config = copy.deepcopy(controlnet_sdxl_mid_config)
                logger.info('controlnet_sdxl_mid_config')
            else:
                config = copy.deepcopy(controlnet_sdxl_small_config)
                logger.info('controlnet_sdxl_small_config')
            config['global_average_pooling'] = False
            config['hint_channels'] = int(state_dict['input_hint_block.0.weight'].shape[1])

        if 'difference' in state_dict and unet is not None:
            # Checkpoint stores offsets: add the UNet weight for every key the
            # UNet also has, leave all other entries untouched.
            unet_state_dict = unet.state_dict()
            state_dict = {
                key: (p + unet_state_dict[key].clone().cpu()) if key in unet_state_dict else p
                for key, p in state_dict.items()
            }

        config['use_fp16'] = devices.dtype_unet == torch.float16
        network = PlugableControlModel(config, state_dict)
        network.to(devices.dtype_unet)
        return network

    if 'conv_in.weight' in state_dict:
        logger.info('t2i_adapter_config')
        conv_in_weight = state_dict['conv_in.weight']
        cin = int(conv_in_weight.shape[1])
        channel = int(conv_in_weight.shape[0])
        ksize = int(state_dict['body.0.block2.weight'].shape[2])
        use_conv = any(key.endswith("down_opt.op.weight") for key in state_dict)
        is_sdxl = cin == 256 or cin == 768
        adapter = Adapter(
            cin=cin,
            channels=[channel, channel * 2, channel * 4, channel * 4],
            nums_rb=2,
            ksize=ksize,
            sk=True,
            use_conv=use_conv,
            is_sdxl=is_sdxl,
        ).cpu()
        adapter.load_state_dict(state_dict, strict=False)
        network = PlugableAdapter(adapter)
        return network

    if 'style_embedding' in state_dict:
        config = copy.deepcopy(t2i_adapter_style_config)
        logger.info('t2i_adapter_style_config')
        adapter = StyleAdapter(**config).cpu()
        adapter.load_state_dict(state_dict, strict=False)
        network = PlugableAdapter(adapter)
        return network

    if 'body.0.in_conv.weight' in state_dict:
        config = copy.deepcopy(t2i_adapter_light_config)
        logger.info('t2i_adapter_light_config')
        config['cin'] = int(state_dict['body.0.in_conv.weight'].shape[1])
        adapter = Adapter_light(**config).cpu()
        adapter.load_state_dict(state_dict, strict=False)
        network = PlugableAdapter(adapter)
        return network

    if 'ip_adapter' in state_dict:
        # The "plus" variant is recognised by a 'latents' entry in image_proj
        # and uses a differently named projection weight.
        plus = "latents" in state_dict["image_proj"]
        if plus:
            channel = int(state_dict['image_proj']['proj_in.weight'].shape[1])
        else:
            channel = int(state_dict['image_proj']['proj.weight'].shape[1])
        network = PlugableIPAdapter(state_dict, channel, plus)
        network.to('cpu')
        return network

    if any('lllite' in k for k in state_dict.keys()):
        network = PlugableControlLLLite(state_dict)
        network.to('cpu')
        return network

    # Raising a bare string is rejected by Python 3 with an unrelated
    # TypeError; raise a real exception carrying the intended message.
    raise ValueError('[ControlNet Error] Cannot recognize the ControlModel!')