| import re |
|
|
def _sdxl_name_filter(name):
    """Return True when ``name`` does not contain ``up_blocks.2``."""
    return "up_blocks.2" not in name


def _sdxl_step_selector(step):
    """Return True when ``step`` is not a multiple of 2."""
    return step % 2 != 0


# Default SDXL configuration: a single rule pairing a module-name predicate
# with a step predicate.
# NOTE(review): how the consumer interprets a True result from either
# predicate is not visible in this file -- confirm against the caller.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _sdxl_name_filter,
        "select_cache_step_func": _sdxl_step_selector,
    },
]
|
|
def _pixart_name_filter(name):
    """Return True unless ``name`` contains ``transformer_blocks.<21-27>.``.

    The trailing ``\\.`` in the pattern means only names that continue past
    the block index (i.e. sub-paths of blocks 21..27) are rejected.
    """
    return re.search(r"transformer_blocks\.(2[1-7])\.", name) is None


def _pixart_step_selector(step):
    """Return True when ``step`` is not a multiple of 3."""
    return step % 3 != 0


# Default PixArt configuration: a single rule pairing a module-name predicate
# with a step predicate.
# NOTE(review): how the consumer interprets a True result from either
# predicate is not visible in this file -- confirm against the caller.
PIXART_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _pixart_name_filter,
        "select_cache_step_func": _pixart_step_selector,
    },
]
|
|
def _svd_name_filter(name):
    """Return True when ``name`` does not contain ``up_blocks.3``."""
    return "up_blocks.3" not in name


def _svd_step_selector(step):
    """Return True when ``step`` is not a multiple of 2."""
    return step % 2 != 0


# Default SVD configuration: a single rule pairing a module-name predicate
# with a step predicate.
# NOTE(review): how the consumer interprets a True result from either
# predicate is not visible in this file -- confirm against the caller.
SVD_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _svd_name_filter,
        "select_cache_step_func": _svd_step_selector,
    },
]
|
|
def _sd3_name_filter(name):
    """Return True unless ``name`` contains ``transformer_blocks.<16-23>``.

    The pattern is deliberately unanchored and has no trailing ``\\.``: a
    name is rejected as soon as ``transformer_blocks.`` followed by 16..23
    appears anywhere in it, including the bare block name itself.

    Always returns a real ``bool`` (not a ``re.Match``/``None``), matching
    the other default configs in this module.
    """
    # A direct search for the excluded indices replaces the earlier
    # "match every character that does not start the excluded text" regex;
    # the accepted/rejected names are the same, without a lookahead per char.
    return re.search(r"transformer_blocks\.(1[6-9]|2[0-3])", name) is None


def _sd3_step_selector(step):
    """Return True when ``step`` is not a multiple of 2."""
    return step % 2 != 0


# Default SD3 configuration: a single rule pairing a module-name predicate
# with a step predicate.
# NOTE(review): how the consumer interprets a True result from either
# predicate is not visible in this file -- confirm against the caller.
SD3_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _sd3_name_filter,
        "select_cache_step_func": _sd3_step_selector,
    },
]
|
|
|
|
def replace_module(parent, name_path, new_module):
    """Assign ``new_module`` to the attribute addressed by a dotted path.

    ``name_path`` is split on ``"."``. Every component except the last is
    resolved with ``getattr`` starting from ``parent``; the last component is
    then set on the object reached, via ``setattr``.

    Args:
        parent: Object from which the dotted path is resolved.
        name_path: Dotted attribute path, e.g. ``"a.b.c"``.
        new_module: Value stored at the final path component.

    Raises:
        AttributeError: If an intermediate component cannot be resolved.
    """
    *ancestors, leaf = name_path.split(".")

    # Descend to the object that directly owns the attribute being replaced.
    owner = parent
    for attr in ancestors:
        owner = getattr(owner, attr)

    setattr(owner, leaf, new_module)
|
|