|
from enum import Enum |
|
from typing import List, Any, Optional, Union, Tuple, Dict |
|
import numpy as np |
|
from modules import scripts, processing, shared |
|
from scripts import global_state |
|
from scripts.processor import preprocessor_sliders_config, model_free_preprocessors |
|
|
|
from modules.api import api |
|
|
|
|
|
def get_api_version() -> int:
    """Return the version number of this external-code API."""
    api_version = 2
    return api_version
|
|
|
|
|
class ControlMode(Enum):
    """
    The improved guess mode.

    Replaces the removed "Guess Mode" (see the message printed by
    `to_processing_unit` when a 'guess_mode' key is supplied).
    Member values are the strings accepted by `ControlMode(value)`.
    Member order matters: `control_mode_from_value` maps an integer to a
    member by its position in this class, so do not reorder members.
    """

    BALANCED = "Balanced"  # position 0
    PROMPT = "My prompt is more important"  # position 1
    CONTROL = "ControlNet is more important"  # position 2
|
|
|
|
|
class ResizeMode(Enum):
    """
    Resize modes for ControlNet input images.

    Member values are the strings accepted by `ResizeMode(value)`; alternative
    spellings that map onto these values live in `resize_mode_aliases`.
    Member order matters: `resize_mode_from_value` maps an integer to a member
    by its position in this class, so do not reorder members.
    """

    RESIZE = "Just Resize"  # position 0
    INNER_FIT = "Crop and Resize"  # position 1; target of the 'Inner Fit' aliases
    OUTER_FIT = "Resize and Fill"  # position 2; target of the 'Outer Fit' aliases
|
|
|
|
|
# Alternative resize-mode names mapped to the corresponding `ResizeMode` value
# strings. `resize_mode_from_value` looks a string up here first, and falls back
# to the string itself when it is not an alias.
resize_mode_aliases = {
    'Inner Fit (Scale to Fit)': 'Crop and Resize',
    'Outer Fit (Shrink to Fit)': 'Resize and Fill',
    'Scale to Fit (Inner Fit)': 'Crop and Resize',
    'Envelope (Outer Fit)': 'Resize and Fill',
}
|
|
|
|
|
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
    """
    Convert a user-supplied resize mode into a `ResizeMode` member.

    Accepted inputs:
    - str: a `ResizeMode` value (e.g. "Crop and Resize") or one of the
      alternative names listed in `resize_mode_aliases`;
    - int: the member's position in `ResizeMode` (0, 1, 2, ...);
    - anything else is returned unchanged (normally already a `ResizeMode`).

    Raises:
        ValueError: if a string matches neither a value nor an alias.
        IndexError: if an int is negative or past the last member.
    """
    if isinstance(value, str):
        return ResizeMode(resize_mode_aliases.get(value, value))

    if isinstance(value, int):
        members = list(ResizeMode)
        # Reject negative positions explicitly: plain list indexing would wrap
        # around and silently select a member from the end (-1 -> last mode).
        if not 0 <= value < len(members):
            raise IndexError(
                f"resize mode index {value} is out of range 0..{len(members) - 1}"
            )
        return members[value]

    return value
|
|
|
|
|
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
    """
    Convert a user-supplied control mode into a `ControlMode` member.

    Accepted inputs:
    - str: a `ControlMode` value (e.g. "Balanced");
    - int: the member's position in `ControlMode` (0, 1, 2, ...);
    - anything else is returned unchanged (normally already a `ControlMode`).

    Raises:
        ValueError: if a string matches no `ControlMode` value.
        IndexError: if an int is negative or past the last member.
    """
    if isinstance(value, str):
        return ControlMode(value)

    if isinstance(value, int):
        members = list(ControlMode)
        # Reject negative positions explicitly: plain list indexing would wrap
        # around and silently select a member from the end (-1 -> last mode).
        if not 0 <= value < len(members):
            raise IndexError(
                f"control mode index {value} is out of range 0..{len(members) - 1}"
            )
        return members[value]

    return value
|
|
|
|
|
# A single image: either a numpy array or a string (presumably a base64-encoded
# image — see `to_base64_nparray`; TODO confirm whether paths are also accepted).
InputImage = Union[np.ndarray, str]
# The full accepted form, rebinding the name above: a dict of images (e.g. the
# {'image': ..., 'mask': ...} dict built by `to_processing_unit`), a pair of
# images (presumably image and mask — TODO confirm), or a single image.
InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage]
|
|
|
|
|
class ControlNetUnit:
    """
    Represents an entire ControlNet processing unit.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name. In particular `resize_mode` and `control_mode` are NOT
    normalized here; use `resize_mode_from_value` / `control_mode_from_value`
    when an enum member is required. Unknown keyword arguments are accepted
    and discarded (`**_kwargs`).

    Two units compare equal when all of their instance attributes are equal;
    numpy arrays (e.g. in `image`) are compared element-wise. Because `__eq__`
    is defined, instances are unhashable.
    """

    def __init__(
        self,
        enabled: bool=True,
        module: Optional[str]=None,
        model: Optional[str]=None,
        weight: float=1.0,
        image: Optional[InputImage]=None,
        resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT,
        low_vram: bool=False,
        processor_res: int=512,
        threshold_a: float=64,
        threshold_b: float=64,
        guidance_start: float=0.0,
        guidance_end: float=1.0,
        pixel_perfect: bool=False,
        control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED,
        # Extra keys (e.g. the removed 'guess_mode') are accepted and ignored.
        **_kwargs,
    ):
        self.enabled = enabled
        self.module = module
        self.model = model
        self.weight = weight
        self.image = image
        self.resize_mode = resize_mode
        self.low_vram = low_vram
        self.processor_res = processor_res
        self.threshold_a = threshold_a
        self.threshold_b = threshold_b
        self.guidance_start = guidance_start
        self.guidance_end = guidance_end
        self.pixel_perfect = pixel_perfect
        self.control_mode = control_mode

    def __eq__(self, other):
        """
        Compare all instance attributes of two units.

        A plain `vars(self) == vars(other)` raises ValueError when `image` holds
        distinct numpy arrays (directly or inside an image dict/tuple), because
        the element-wise `ndarray == ndarray` result has no single truth value.
        Values are therefore compared structurally via `_values_equal`.
        """
        if not isinstance(other, ControlNetUnit):
            return False

        return self._values_equal(vars(self), vars(other))

    # Python already sets __hash__ to None when a class defines __eq__ without
    # __hash__; spelled out here so the unhashability is visible.
    __hash__ = None

    @staticmethod
    def _values_equal(left: Any, right: Any) -> bool:
        """Structural equality that tolerates numpy arrays nested in dicts, lists and tuples."""
        if left is right:
            return True

        if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
            # An array only equals another array with the same shape and elements.
            return bool(
                isinstance(left, np.ndarray)
                and isinstance(right, np.ndarray)
                and np.array_equal(left, right)
            )

        if isinstance(left, dict) and isinstance(right, dict):
            return left.keys() == right.keys() and all(
                ControlNetUnit._values_equal(left[key], right[key]) for key in left
            )

        if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
            # Match built-in semantics: a list never equals a tuple.
            if isinstance(left, tuple) != isinstance(right, tuple):
                return False
            return len(left) == len(right) and all(
                ControlNetUnit._values_equal(a, b) for a, b in zip(left, right)
            )

        return bool(left == right)
|
|
|
|
|
def to_base64_nparray(encoding: str):
    """
    Decode a base64-encoded image string into the array form the extension uses.

    The string is decoded with `api.decode_base64_to_image`, and the result is
    converted to a numpy array of dtype uint8.
    """
    decoded = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded)
    return pixels.astype(np.uint8)
|
|
|
|
|
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]:
    """
    Return the ControlNet processing units configured on a StableDiffusionProcessing.

    Reads `p.scripts` (the script runner) and `p.script_args` (all script
    arguments) and delegates to `get_all_units`.
    """
    runner = p.scripts
    all_script_args = p.script_args
    return get_all_units(runner, all_script_args)
|
|
|
|
|
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from an existing script runner.

    Use this function when `script_args` is the list of ALL scripts' arguments:
    only the ControlNet script's slice `[args_from:args_to]` is inspected.
    Returns an empty list when `script_runner` has no ControlNet script.
    """
    cn_script = find_cn_script(script_runner)
    if not cn_script:
        return []

    cn_args = script_args[cn_script.args_from:cn_script.args_to]
    return get_all_units_from(cn_args)
|
|
|
|
|
def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
    """
    Fetch ControlNet processing units from ControlNet script arguments.

    Use `external_code.get_all_units` to fetch units from the list of all scripts arguments.
    Each non-None entry is converted with `to_processing_unit`; None entries
    are skipped, so positions in the result need not match positions in
    `script_args`.
    """
    return [to_processing_unit(arg) for arg in script_args if arg is not None]
|
|
|
|
|
def get_single_unit_from(script_args: List[Any], index: int=0) -> Optional[ControlNetUnit]:
    """
    Fetch a single ControlNet processing unit from ControlNet script arguments.

    The list must not contain script positional arguments. It must only contain
    processing units. `index` is a raw position in `script_args` (None entries
    are NOT skipped, unlike `get_all_units_from`). Returns None when `index` is
    out of range (including negative) or when the entry at `index` is None.
    """
    for position, candidate in enumerate(script_args):
        if position != index:
            continue
        if candidate is None:
            return None
        return to_processing_unit(candidate)

    return None
|
|
|
def get_max_models_num():
    """
    Return the maximum number of allowed ControlNet models.

    Reads the "control_net_max_models_num" option from `shared.opts.data` and
    falls back to 1 when the option is not present.
    """
    return shared.opts.data.get("control_net_max_models_num", 1)
|
|
|
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
    """
    Convert different types to processing unit.

    If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in
    implementation for details. For dicts, the caller's dict is not modified
    (a renamed copy is built first). Then:
    - an optional 'mask' entry is removed; when the image is not already a dict,
      a non-None mask is combined with it into {'image': ..., 'mask': ...};
    - a falsy non-array image (e.g. '' or None) is normalized to None, while
      numpy arrays are kept as-is;
    - a 'guess_mode' key prints a deprecation message and is then passed to
      `ControlNetUnit`, whose `**_kwargs` discards it.
    Any non-dict value (e.g. an existing ControlNetUnit) is returned unchanged.
    """
    ext_compat_keys = {
        'guessmode': 'guess_mode',
        'guidance': 'guidance_end',
        'lowvram': 'low_vram',
        'input_image': 'image'
    }

    if isinstance(unit, dict):
        unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}

        mask = unit.pop('mask', None)

        if 'image' in unit and not isinstance(unit['image'], dict):
            image = unit['image']
            if mask is not None:
                unit['image'] = {'image': image, 'mask': mask}
            elif isinstance(image, np.ndarray):
                # Keep arrays untouched: `bool(array)` is ambiguous for arrays
                # with more than one element and would raise ValueError.
                unit['image'] = image
            elif not image:
                # Empty string / None / other falsy placeholder means "no image".
                unit['image'] = None

        if 'guess_mode' in unit:
            print('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')

        unit = ControlNetUnit(**unit)

    return unit
|
|
|
|
|
def update_cn_script_in_processing(
    p: processing.StableDiffusionProcessing,
    cn_units: List[ControlNetUnit],
    **_kwargs,
):
    """
    Update the arguments of the ControlNet script in `p.script_args` in place, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    Does not update `p.script_args` if any of the following is true:
    - ControlNet is not present in `p.scripts`
    - `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet

    `p.script_args` keeps its container type (a tuple stays a tuple, a list stays
    a list); any other sequence type is replaced by a list.
    Extra keyword arguments are accepted and ignored.
    """
    original_args = p.script_args
    # Preserve the container type of the script arguments themselves, not that of
    # `cn_units`; otherwise passing units as a list would turn a tuple of script
    # arguments into a list even when nothing was updated.
    script_args_type = type(original_args) if type(original_args) in (list, tuple) else list
    script_args = list(original_args)
    update_cn_script_in_place(p.scripts, script_args, cn_units)
    p.script_args = script_args_type(script_args)
|
|
|
|
|
def update_cn_script_in_place(
    script_runner: scripts.ScriptRunner,
    script_args: List[Any],
    cn_units: List[ControlNetUnit],
    **_kwargs,
):
    """
    Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
    `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.

    Does not update `script_args` if any of the following is true:
    - ControlNet is not present in `script_runner`
    - `script_args` is not filled with script arguments for scripts that are processed before ControlNet

    `cn_units` may be a list or a tuple. It is padded with disabled units up to
    `get_max_models_num()`. Side effect: the `args_from` / `args_to` windows of
    the ControlNet script, and of every always-on script after it, are adjusted
    to match the new number of ControlNet arguments.
    """
    cn_script = find_cn_script(script_runner)
    if cn_script is None or len(script_args) < cn_script.args_from:
        return

    # Work on a copy so the caller's sequence is never modified (and tuples work).
    cn_units = list(cn_units)
    padding_count = max(get_max_models_num() - len(cn_units), 0)
    # One fresh instance per slot: `[unit] * n` would place the same object in
    # every padding slot, so a later change to one slot would affect all of them.
    cn_units += [ControlNetUnit(enabled=False) for _ in range(padding_count)]

    cn_script_args_diff = 0
    for script in script_runner.alwayson_scripts:
        if script is cn_script:
            cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
            script_args[script.args_from:script.args_to] = cn_units
            script.args_to = script.args_from + len(cn_units)
        else:
            # Scripts seen before ControlNet get a shift of 0; scripts after it are
            # moved by the change in ControlNet's argument count. (Assumes
            # `alwayson_scripts` is ordered by argument position — TODO confirm.)
            script.args_from += cn_script_args_diff
            script.args_to += cn_script_args_diff
|
|
|
|
|
def get_models(update: bool=False) -> List[str]:
    """
    Return the names of the available ControlNet models.

    Each name is a valid value for `ControlNetUnit.model`.

    Keyword arguments:
    update -- when True, call `global_state.update_cn_models()` first to refresh
              the list from disk. (default False)
    """
    if update:
        global_state.update_cn_models()

    model_names = global_state.cn_models_names.values()
    return list(model_names)
|
|
|
|
|
def get_modules(alias_names: bool = False) -> List[str]:
    """
    Return the names of the available preprocessors.

    Each name is a valid value for `ControlNetUnit.module`.

    Keyword arguments:
    alias_names -- when True, map each internal key through
                   `global_state.preprocessor_aliases` (keys without an alias
                   are returned unchanged). (default False)
    """
    internal_keys = list(global_state.cn_preprocessor_modules.keys())
    if not alias_names:
        return internal_keys

    aliases = global_state.preprocessor_aliases
    return [aliases.get(key, key) for key in internal_keys]
|
|
|
|
|
def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]:
    """
    Get the detail of all preprocessors.

    Returns a dict keyed by preprocessor name, each value being:
    - "model_free": whether the preprocessor is listed in `model_free_preprocessors`;
    - "sliders": the slider config in the Auto1111 webUI
      (`preprocessor_sliders_config`), or [] when none is defined.

    Both lookups always use the internal preprocessor key, even when the
    returned dict is keyed by alias names.

    Keyword arguments:
    alias_names -- Whether to key the result by ui alias names instead of internal keys
    """
    internal_names = get_modules(False)
    alias_list = get_modules(True)
    output_names = alias_list if alias_names else internal_names

    module_detail = {}
    # get_modules(False) and get_modules(True) walk the same keys in the same
    # order, so zip pairs each output name with its internal key.
    for internal_name, output_name in zip(internal_names, output_names):
        module_detail[output_name] = {
            "model_free": internal_name in model_free_preprocessors,
            "sliders": preprocessor_sliders_config.get(internal_name, []),
        }

    return module_detail
|
|
|
|
|
def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]:
    """
    Return the first always-on script in `script_runner` recognized by `is_cn_script`.

    Returns None when `script_runner` is None or contains no ControlNet script.
    """
    if script_runner is None:
        return None

    return next(
        (candidate for candidate in script_runner.alwayson_scripts if is_cn_script(candidate)),
        None,
    )
|
|
|
|
|
def is_cn_script(script: scripts.Script) -> bool:
    """Return True when `script.title()`, lower-cased, equals 'controlnet'."""
    normalized_title = script.title().lower()
    return normalized_title == 'controlnet'
|
|