Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import contextlib | |
| import importlib | |
| import itertools | |
| import logging | |
| import math | |
| import sys | |
| from functools import partial | |
| from typing import TYPE_CHECKING, Callable, NamedTuple | |
| from modules.Utilities import Latent, upscale | |
| import torch.nn.functional as torchf | |
| if TYPE_CHECKING: | |
| from collections.abc import Sequence | |
| from types import ModuleType | |
try:
    from enum import StrEnum
except ImportError:
    # enum.StrEnum only exists on Python 3.11+; older interpreters get this
    # minimal stand-in with the two behaviours the module relies on.
    from enum import Enum

    class StrEnum(str, Enum):
        """#### String-valued enum used when `enum.StrEnum` is unavailable."""

        def _generate_next_value_(name: str, *_ignored: list) -> str:
            # Hook consulted by the enum machinery for auto-generated values:
            # the value is the lowercased member name.
            return name.lower()

        def __str__(self) -> str:
            # Render as the plain value rather than "ClassName.MEMBER".
            return str(self.value)


logger = logging.getLogger(__name__)

# Interpolation mode names offered for latent scaling.
UPSCALE_METHODS = (
    "bicubic",
    "bislerp",
    "bilinear",
    "nearest-exact",
    "nearest",
    "area",
)


class TimeMode(StrEnum):
    """#### Units in which start/end times can be expressed."""

    PERCENT = "percent"
    TIMESTEP = "timestep"
    SIGMA = "sigma"


class ModelType(StrEnum):
    """#### Model families known to this module."""

    SD15 = "SD15"
    SDXL = "SDXL"
def parse_blocks(name: str, val: str | Sequence[int]) -> set[tuple[str, int]]:
    """#### Build the set of `(name, block_id)` pairs described by `val`.

    #### Args:
        - `name` (str): The block name paired with every id.
        - `val` (Union[str, Sequence[int]]): Either a comma separated string of
          ids, or a tuple/list of non-negative ints.

    #### Returns:
        - `set[tuple[str, int]]`: One `(name, id)` tuple per parsed id.

    #### Raises:
        - `ValueError`: If a tuple/list contains a non-int or negative item, or
          a string item cannot be converted by `int()`.
    """
    if not isinstance(val, (tuple, list)):
        # String form: split on commas and drop empty pieces ("1,,2" or a
        # trailing comma are tolerated).
        pieces = (chunk.strip() for chunk in val.split(","))
        return {(name, int(piece)) for piece in pieces if piece}

    # Sequence form, e.g. passed in via YAML parameters.
    for item in val:
        if not isinstance(item, int) or item < 0:
            raise ValueError(
                "Bad blocks definition, must be comma separated string or sequence of positive int",
            )
    return {(name, item) for item in val}
def convert_time(
    ms: object,
    time_mode: TimeMode,
    start_time: float,
    end_time: float,
) -> tuple[float, float]:
    """#### Convert a start/end time pair to sigmas based on the mode.

    #### Args:
        - `ms` (Any): Object providing `percent_to_sigma`; not used in sigma mode.
        - `time_mode` (TimeMode): How `start_time`/`end_time` are expressed.
        - `start_time` (float): The start time.
        - `end_time` (float): The end time.

    #### Returns:
        - `Tuple[float, float]`: The start and end sigmas. Sigma mode returns
          the inputs unchanged; the other modes return values rounded to 4
          decimal places.

    #### Raises:
        - `ValueError`: If a percent value lies outside [0, 1], or if
          `time_mode` is not one of the known modes.
    """
    if time_mode == TimeMode.SIGMA:
        return (start_time, end_time)
    if time_mode == TimeMode.TIMESTEP:
        # Map timesteps on a 0..999 scale to a percentage (timestep 999 -> 0.0).
        start_time = 1.0 - (start_time / 999.0)
        end_time = 1.0 - (end_time / 999.0)
    elif time_mode == TimeMode.PERCENT:
        if start_time > 1.0 or start_time < 0.0:
            raise ValueError(
                "invalid value for start percent",
            )
        if end_time > 1.0 or end_time < 0.0:
            raise ValueError(
                "invalid value for end percent",
            )
    else:
        # Previously unreachable: an unknown mode fell through and was
        # silently handled as a percentage.
        raise ValueError("invalid time mode")
    return (
        round(ms.percent_to_sigma(start_time), 4),
        round(ms.percent_to_sigma(end_time), 4),
    )
def get_sigma(options: dict, key: str = "sigmas") -> float | None:
    """#### Extract a single sigma value from an options dictionary.

    #### Args:
        - `options` (dict): The options dictionary. Anything that is not a dict
          yields `None`.
        - `key` (str, optional): The key to look for. Defaults to "sigmas".

    #### Returns:
        - `Optional[float]`: `None` when the entry is missing or `None`; the
          entry itself when it is already a float; otherwise the maximum of the
          entry, obtained via `.detach().cpu().max().item()`.
    """
    sigmas = options.get(key) if isinstance(options, dict) else None
    if sigmas is None or isinstance(sigmas, float):
        return sigmas
    # Anything else is expected to offer the tensor-style call chain below.
    return sigmas.detach().cpu().max().item()
def check_time(time_arg: dict | float, start_sigma: float, end_sigma: float) -> bool:
    """#### Tell whether the current sigma lies inside the active range.

    #### Args:
        - `time_arg` (Union[dict, float]): Either the sigma itself (a float) or
          an options object handed to `get_sigma`.
        - `start_sigma` (float): Upper bound of the range (inclusive).
        - `end_sigma` (float): Lower bound of the range (inclusive).

    #### Returns:
        - `bool`: `True` when `end_sigma <= sigma <= start_sigma`; `False` when
          no sigma could be determined or it falls outside the range.
    """
    # Floats are used as-is; everything else goes through get_sigma.
    sigma = time_arg if isinstance(time_arg, float) else get_sigma(time_arg)
    if sigma is None:
        return False
    return end_sigma <= sigma <= start_sigma
# Numeric id for each supported block type.
__block_to_num_map = {"input": 0, "middle": 1, "output": 2}


def block_to_num(block_type: str, block_id: int) -> tuple[int, int]:
    """#### Translate a block type name and id into a numeric pair.

    #### Args:
        - `block_type` (str): One of "input", "middle" or "output".
        - `block_id` (int): The block id, passed through unchanged.

    #### Returns:
        - `Tuple[int, int]`: `(type_number, block_id)`.

    #### Raises:
        - `ValueError`: If `block_type` is not a known block type.
    """
    if block_type not in __block_to_num_map:
        errstr = f"Got unexpected block type {block_type}!"
        raise ValueError(errstr)
    return (__block_to_num_map[block_type], block_id)
# Naive and totally inaccurate way to factorize target_res into rescaled integer width/height
def rescale_size(
    width: int,
    height: int,
    target_res: int,
    *,
    tolerance: int = 1,
) -> tuple[int, int]:
    """#### Rescale size to fit target resolution.

    Searches for an integer `(width, height)` pair whose product equals
    `target_res`, trying candidates next to the aspect-ratio-preserving
    scaled size.

    #### Args:
        - `width` (int): The original width.
        - `height` (int): The original height.
        - `target_res` (int): The target resolution (the desired
          `width * height` product).
        - `tolerance` (int, optional): How far (in whole units) candidates may
          deviate from the scaled size. Capped at `target_res`. Defaults to 1.

    #### Returns:
        - `Tuple[int, int]`: The rescaled width and height.

    #### Raises:
        - `ValueError`: If no candidate within the tolerance divides
          `target_res` evenly, or if any of the sizes is not positive.
    """
    msg = f"Can't rescale {width} and {height} to fit {target_res}"
    if width <= 0 or height <= 0 or target_res <= 0:
        # Would otherwise surface as ZeroDivisionError or a math domain error.
        raise ValueError(msg)
    tolerance = min(target_res, tolerance)

    def get_neighbors(num: float) -> tuple[int, ...]:
        # Integer candidates around `num`, closest first, never below 1.
        if num < 1:
            # No usable candidate. An empty tuple keeps the search loop
            # working so the caller gets the ValueError below instead of a
            # TypeError from iterating None.
            return ()
        numi = int(num)
        return tuple(
            numi + adj
            for adj in sorted(
                range(
                    -min(numi - 1, tolerance),
                    # One extra candidate upwards when `num` has a fraction.
                    tolerance + 1 + math.ceil(num - numi),
                ),
                key=abs,
            )
        )

    scale = math.sqrt(height * width / target_res)
    height_scaled, width_scaled = height / scale, width / scale
    height_rounded = get_neighbors(height_scaled)
    width_rounded = get_neighbors(width_scaled)
    for h, w in itertools.zip_longest(height_rounded, width_rounded):
        # Try the width candidate first: accept if it divides target_res.
        # 0.1 is a placeholder that can never pass the "% 1 == 0" test.
        h_adj = target_res / w if w is not None else 0.1
        if h_adj % 1 == 0:
            return (w, int(h_adj))
        if h is None:
            continue
        # Then the height candidate.
        w_adj = target_res / h
        if w_adj % 1 == 0:
            return (int(w_adj), h)
    raise ValueError(msg)
def guess_model_type(model: object) -> ModelType | None:
    """#### Guess the model type from the model's latent format.

    #### Args:
        - `model` (object): An object exposing `get_model_object`.

    #### Returns:
        - `Optional[ModelType]`: `ModelType.SD15` when the latent format is a
          `Latent.SD15` instance, otherwise `None`. No other model type is
          detected here.
    """
    latent_format = model.get_model_object("latent_format")
    return ModelType.SD15 if isinstance(latent_format, Latent.SD15) else None
def sigma_to_pct(ms, sigma):
    """#### Convert a sigma to a percentage in [0, 1].

    #### Args:
        - `ms` (Any): Object providing `timestep(sigma)`.
        - `sigma` (float): The sigma value.

    #### Returns:
        - `float`: `1 - timestep / 999`, clamped to the range [0.0, 1.0].
    """
    timestep = ms.timestep(sigma).detach().cpu()
    pct = 1.0 - (timestep / 999.0)
    return pct.clamp(0.0, 1.0).item()
def fade_scale(
    pct,
    start_pct=0.0,
    end_pct=1.0,
    fade_start=1.0,
    fade_cap=0.0,
):
    """#### Calculate the fade scale.

    #### Args:
        - `pct` (float): The current percentage.
        - `start_pct` (float, optional): The start percentage. Defaults to 0.0.
        - `end_pct` (float, optional): The end percentage. Defaults to 1.0.
        - `fade_start` (float, optional): Percentage at which fading begins.
          Defaults to 1.0.
        - `fade_cap` (float, optional): Lower bound for the faded scale.
          Defaults to 0.0.

    #### Returns:
        - `float`: 0.0 outside `[start_pct, end_pct]` (or when the range is
          inverted), 1.0 before `fade_start`, and afterwards a value falling
          linearly from 1.0 at `fade_start` to 0.0 at `end_pct`, never below
          `fade_cap`.
    """
    if start_pct > end_pct or not (start_pct <= pct <= end_pct):
        return 0.0
    if pct < fade_start:
        return 1.0
    fade_span = end_pct - fade_start
    if fade_span <= 0.0:
        # pct == fade_start == end_pct: the fade window has zero length, so
        # the linear formula would divide by zero (this happens with the
        # defaults at pct == 1.0). Use the value the formula takes at the
        # start of a fade.
        return max(fade_cap, 1.0)
    scaling_pct = 1.0 - ((pct - fade_start) / fade_span)
    return max(fade_cap, scaling_pct)
def scale_samples(
    samples,
    width,
    height,
    mode="bicubic",
    sigma=None,  # noqa: ARG001
):
    """#### Resize samples to the requested width and height.

    #### Args:
        - `samples` (torch.Tensor): The input samples.
        - `width` (int): The target width.
        - `height` (int): The target height.
        - `mode` (str, optional): The scaling mode. "bislerp" is routed to
          `upscale.bislerp`; any other value is passed to
          `torch.nn.functional.interpolate`. Defaults to "bicubic".
        - `sigma` (Optional[float], optional): Accepted but not used by this
          implementation. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The scaled samples.
    """
    if mode != "bislerp":
        # interpolate expects the size as (height, width).
        return torchf.interpolate(samples, size=(height, width), mode=mode)
    return upscale.bislerp(samples, width, height)
class Integrations:
    """#### Class for managing integrations.

    Integrations are registered up front with `register_integration`; the
    first call to `initialize` looks each one up and stores the modules that
    were found, then runs the registered init handlers. Stored modules are
    reachable via `obj[key]`, `key in obj` and `obj.key` (the latter yields
    `None` for anything that is not stored).
    """

    class Integration(NamedTuple):
        # Key under which the module is stored in `Integrations.modules`.
        key: str
        # Name of the custom node package to look for.
        module_name: str
        # Optional callable that receives the found module and returns the
        # object to store (or None to skip it).
        handler: Callable | None = None

    def __init__(self) -> None:
        """#### Initialize the Integrations class."""
        self.initialized = False
        self.modules = {}
        self.init_handlers = []
        self.handlers = []

    def __getitem__(self, key):
        """#### Get a module by key.

        #### Args:
            - `key` (str): The key.

        #### Returns:
            - `ModuleType`: The module.

        #### Raises:
            - `KeyError`: If no module is stored under `key`.
        """
        return self.modules[key]

    def __contains__(self, key) -> bool:
        """#### Check if a module is in the integrations.

        #### Args:
            - `key` (str): The key.

        #### Returns:
            - `bool`: Whether the module is in the integrations.
        """
        return key in self.modules

    def __getattr__(self, key):
        """#### Get a module by attribute.

        Only invoked when normal attribute lookup fails.

        #### Args:
            - `key` (str): The key.

        #### Returns:
            - `Optional[ModuleType]`: The module if found, otherwise None.

        #### Raises:
            - `AttributeError`: If `modules` itself is requested, i.e. the
              instance has not been set up by `__init__` yet.
        """
        if key == "modules":
            # Reading self.modules below would re-enter __getattr__ forever
            # when the instance dict has no "modules" entry (for example on an
            # object created without running __init__).
            raise AttributeError(key)
        return self.modules.get(key)

    @staticmethod
    def get_custom_node(name: str) -> ModuleType | None:
        """#### Get a custom node by name.

        Resolves `custom_nodes.<name>` to a module spec and returns the
        already-loaded module in `sys.modules` that has the same origin.

        #### Args:
            - `name` (str): The name of the custom node.

        #### Returns:
            - `Optional[ModuleType]`: The custom node if found, otherwise None.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        module_key = f"custom_nodes.{name}"
        try:
            spec = importlib.util.find_spec(module_key)
        except (ImportError, ValueError):
            # find_spec imports the parent package and raises
            # ModuleNotFoundError when `custom_nodes` is absent; ValueError
            # signals a module without a usable __spec__. Either way the
            # integration is simply unavailable.
            return None
        if spec is None:
            return None
        # Iterate over a copy: sys.modules may change while we look.
        for loaded in sys.modules.copy().values():
            loaded_spec = getattr(loaded, "__spec__", None)
            if loaded_spec is not None and loaded_spec.origin == spec.origin:
                return loaded
        return None

    def register_init_handler(self, handler) -> None:
        """#### Register an initialization handler.

        #### Args:
            - `handler` (Callable): Called with this object at the end of
              `initialize`.
        """
        self.init_handlers.append(handler)

    def register_integration(self, key: str, module_name: str, handler=None) -> None:
        """#### Register an integration.

        #### Args:
            - `key` (str): The key.
            - `module_name` (str): The module name.
            - `handler` (Optional[Callable], optional): The handler. Defaults to None.

        #### Raises:
            - `ValueError`: If called after `initialize`, or if `key` or
              `module_name` is already registered.
        """
        if self.initialized:
            raise ValueError(
                "Internal error: Cannot register integration after initialization",
            )
        if any(
            item.key == key or item.module_name == module_name
            for item in self.handlers
        ):
            errstr = (
                f"Module {module_name} ({key}) already in integration handlers list!"
            )
            raise ValueError(errstr)
        self.handlers.append(self.Integration(key, module_name, handler))

    def initialize(self) -> None:
        """#### Initialize the integrations.

        Does nothing after the first call.
        """
        if self.initialized:
            return
        # Set the flag first so later registrations are rejected.
        self.initialized = True
        for ih in self.handlers:
            module = self.get_custom_node(ih.module_name)
            if module is None:
                continue
            if ih.handler is not None:
                # A handler may veto the integration by returning None.
                module = ih.handler(module)
            if module is not None:
                self.modules[ih.key] = module
        for init_handler in self.init_handlers:
            init_handler(self)
class JHDIntegrations(Integrations):
    """#### Class for managing JHD integrations.

    Registers the "bleh" and "freeu_advanced" integrations on construction.
    """

    def __init__(self, *args: object, **kwargs: object) -> None:
        """#### Initialize the JHDIntegrations class."""
        super().__init__(*args, **kwargs)
        self.register_integration("bleh", "ComfyUI-bleh", self.bleh_integration)
        self.register_integration("freeu_advanced", "FreeU_Advanced")

    @classmethod
    def bleh_integration(cls, bleh: ModuleType) -> ModuleType | None:
        """#### Integrate with BLEH.

        #### Args:
            - `bleh` (ModuleType): The BLEH module.

        #### Returns:
            - `Optional[ModuleType]`: The BLEH module when it exposes a
              non-negative `BLEH_VERSION`, otherwise None.
        """
        # A missing BLEH_VERSION attribute is treated as -1 and rejected.
        bleh_version = getattr(bleh, "BLEH_VERSION", -1)
        if bleh_version < 0:
            return None
        return bleh


# Shared integrations registry used by the rest of this module.
MODULES = JHDIntegrations()
class IntegratedNode(type):
    """#### Metaclass that makes `INPUT_TYPES` initialize the integrations first."""

    def wrap_INPUT_TYPES(orig_method: Callable, *args: list, **kwargs: dict) -> dict:
        """#### Run `MODULES.initialize()` and then the original `INPUT_TYPES`.

        #### Args:
            - `orig_method` (Callable): The original method.
            - `args` (list): Positional arguments forwarded to it.
            - `kwargs` (dict): Keyword arguments forwarded to it.

        #### Returns:
            - `dict`: Whatever the original method returns.
        """
        MODULES.initialize()
        return orig_method(*args, **kwargs)

    def __new__(cls: type, name: str, bases: tuple, attrs: dict) -> object:
        """#### Create the class and wrap its `INPUT_TYPES`, if it has one.

        #### Args:
            - `name` (str): The name of the class.
            - `bases` (tuple): The base classes.
            - `attrs` (dict): The attributes.

        #### Returns:
            - `object`: The newly created class.
        """
        node_cls = type.__new__(cls, name, bases, attrs)
        if not hasattr(node_cls, "INPUT_TYPES"):
            return node_cls
        # Looked up on the finished class, so an inherited INPUT_TYPES is
        # wrapped as well.
        original = node_cls.INPUT_TYPES
        node_cls.INPUT_TYPES = partial(cls.wrap_INPUT_TYPES, original)
        return node_cls
def init_integrations(integrations) -> None:
    """#### Rebind the scaling helpers to the bleh implementation, if present.

    Replaces the module-level `UPSCALE_METHODS` and `scale_samples` with the
    ones from bleh's `latent_utils` when that integration is available;
    otherwise leaves them untouched.

    #### Args:
        - `integrations` (Integrations): The integrations object.
    """
    global scale_samples, UPSCALE_METHODS  # noqa: PLW0603

    bleh = integrations.bleh
    if bleh is None:
        return
    latent_utils = getattr(bleh.py, "latent_utils", None)
    if latent_utils is None:
        return

    UPSCALE_METHODS = latent_utils.UPSCALE_METHODS
    if getattr(bleh, "BLEH_VERSION", -1) >= 0:
        scale_samples = latent_utils.scale_samples
        return

    def scale_samples_wrapped(*args: list, sigma=None, **kwargs: dict):  # noqa: ARG001
        """#### Call bleh's `scale_samples` without the `sigma` argument.

        #### Args:
            - `args` (list): Positional arguments, forwarded unchanged.
            - `sigma` (Optional[float], optional): Accepted and dropped.
            - `kwargs` (dict): Remaining keyword arguments, forwarded unchanged.

        #### Returns:
            - `Any`: The result of bleh's `scale_samples`.
        """
        return latent_utils.scale_samples(*args, **kwargs)

    # Used when no BLEH_VERSION is reported: `sigma` is not forwarded.
    scale_samples = scale_samples_wrapped
# Register the bleh-aware setup so that it runs at the end of
# MODULES.initialize().
MODULES.register_init_handler(init_integrations)

# Names exported by a star-import of this module.
# NOTE(review): `scale_samples` and `UPSCALE_METHODS` can be rebound by
# init_integrations; copies obtained via star-import before initialization
# would keep the old objects — confirm callers read them through the module.
__all__ = (
    "UPSCALE_METHODS",
    "check_time",
    "convert_time",
    "get_sigma",
    "guess_model_type",
    "parse_blocks",
    "rescale_size",
    "scale_samples",
)