import csv
import shutil
from pathlib import Path

import folder_paths
import torch

from ..log import log
from ..utils import here

Conditioning = list[tuple[torch.Tensor, dict[str, torch.Tensor]]]
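# A Conditioning is a list of (cond_tensor, extras) pairs; `extras` is expected
# to carry at least a "pooled_output" tensor and may also hold ControlNet data
# under "control" / "control_apply_to_uncond" (see check_condition below).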


def check_condition(conditioning: Conditioning):
    """Validate a conditioning input and report whether it carries a ControlNet.

    Raises ValueError when the required 'pooled_output' key is missing.
    """
    has_cn = False
    if len(conditioning) > 1:
        log.warn(
            "More than one conditioning was provided. Only the first one will be used."
        )

    first = conditioning[0]
    cond, kwargs = first
    log.debug("Conditioning Shape")
    log.debug(cond.shape)
    log.debug("Conditioning keys")
    log.debug([f"\t{k} - {type(kwargs[k])}" for k in kwargs])

    if "control" in kwargs:
        log.debug("Conditioning contains a controlnet")
        has_cn = True

    if "pooled_output" not in kwargs:
        raise ValueError(
            "Conditioning is not valid. Missing 'pooled_output' key."
        )

    return has_cn


class MTB_InterpolateCondition:
    """Interpolate between an arbitrary number of conditioning inputs using a single blend value."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "blend": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "mtb/conditioning"
    FUNCTION = "execute"
    def execute(
        self, blend: float, **kwargs: Conditioning
    ) -> tuple[Conditioning]:
        blend = max(0.0, min(1.0, blend))
        conditions: list[Conditioning] = list(kwargs.values())
        num_conditions = len(conditions)

        if num_conditions < 2:
            raise ValueError("At least two conditioning inputs are required.")

        segment_length = 1.0 / (num_conditions - 1)
        segment_index = min(int(blend // segment_length), num_conditions - 2)
        local_blend = (
            blend - (segment_index * segment_length)
        ) / segment_length

        cond_from = conditions[segment_index]
        cond_to = conditions[segment_index + 1]

        from_cn = check_condition(cond_from)
        to_cn = check_condition(cond_to)

        if from_cn and to_cn:
            raise ValueError(
                "Interpolating conditions cannot both contain ControlNets"
            )

        try:
            interpolated_condition = [
                (1.0 - local_blend) * c_from + local_blend * c_to
                for c_from, c_to in zip(
                    cond_from[0][0], cond_to[0][0], strict=False
                )
            ]
        except Exception as e:
            log.error(f"Error during interpolation: {e}")
            raise

        # pooled_output is guaranteed by check_condition, so the zero-tensor
        # fallback below is purely defensive.
        pooled_from = cond_from[0][1].get(
            "pooled_output",
            torch.zeros_like(
                next(iter(cond_from[0][1].values()), torch.tensor([]))
            ),
        )
        pooled_to = cond_to[0][1].get(
            "pooled_output",
            torch.zeros_like(
                next(iter(cond_to[0][1].values()), torch.tensor([]))
            ),
        )
        interpolated_pooled = (
            1.0 - local_blend
        ) * pooled_from + local_blend * pooled_to

        res = {"pooled_output": interpolated_pooled}

        if from_cn:
            res["control"] = cond_from[0][1]["control"]
            res["control_apply_to_uncond"] = cond_from[0][1][
                "control_apply_to_uncond"
            ]
        if to_cn:
            res["control"] = cond_to[0][1]["control"]
            res["control_apply_to_uncond"] = cond_to[0][1][
                "control_apply_to_uncond"
            ]

        return ([(torch.stack(interpolated_condition), res)],)
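
# Worked example of the blend mapping above (derived from the code, for
# illustration): with 3 conditioning inputs, segment_length = 0.5, so
# blend = 0.75 gives segment_index = 1 (second -> third input) and
# local_blend = 0.5, i.e. an equal mix of the second and third conditionings.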


class MTB_InterpolateClipSequential:
    """Interpolate CLIP encodings by progressively swapping a substring of the base prompt."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "base_text": ("STRING", {"multiline": True}),
                "text_to_replace": ("STRING", {"default": ""}),
                "clip": ("CLIP",),
                "interpolation_strength": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "interpolate_encodings_sequential"
    CATEGORY = "mtb/conditioning"
    def interpolate_encodings_sequential(
        self,
        base_text,
        text_to_replace,
        clip,
        interpolation_strength,
        **replacements,
    ):
        log.debug(f"Received interpolation_strength: {interpolation_strength}")

        # - Ensure interpolation strength is within [0, 1]
        interpolation_strength = max(0.0, min(1.0, interpolation_strength))

        # - Check if replacements were provided
        if not replacements:
            raise ValueError("At least one replacement should be provided.")

        num_replacements = len(replacements)
        log.debug(f"Number of replacements: {num_replacements}")

        segment_length = 1.0 / num_replacements
        log.debug(f"Calculated segment_length: {segment_length}")

        # - Find the segment that the interpolation_strength falls into
        segment_index = min(
            int(interpolation_strength // segment_length), num_replacements - 1
        )
        log.debug(f"Segment index: {segment_index}")

        # - Calculate the local strength within the segment
        local_strength = (
            interpolation_strength - (segment_index * segment_length)
        ) / segment_length
        log.debug(f"Local strength: {local_strength}")

        # - If it's the first segment, interpolate between base_text and the first replacement
        if segment_index == 0:
            replacement_text = list(replacements.values())[0]
            log.debug("Using the base text as the base blend")
            # - Start with the base_text condition
            tokens = clip.tokenize(base_text)
            cond_from, pooled_from = clip.encode_from_tokens(
                tokens, return_pooled=True
            )
        else:
            base_replace = list(replacements.values())[segment_index - 1]
            log.debug(f"Using {base_replace} as the base blend")
            # - Start with the base_text condition replaced by the closest replacement
            tokens = clip.tokenize(
                base_text.replace(text_to_replace, base_replace)
            )
            cond_from, pooled_from = clip.encode_from_tokens(
                tokens, return_pooled=True
            )
            replacement_text = list(replacements.values())[segment_index]

        interpolated_text = base_text.replace(
            text_to_replace, replacement_text
        )
        tokens = clip.tokenize(interpolated_text)
        cond_to, pooled_to = clip.encode_from_tokens(
            tokens, return_pooled=True
        )

        # - Linearly interpolate between the two conditions
        interpolated_condition = (
            1.0 - local_strength
        ) * cond_from + local_strength * cond_to
        interpolated_pooled = (
            1.0 - local_strength
        ) * pooled_from + local_strength * pooled_to

        return (
            [[interpolated_condition, {"pooled_output": interpolated_pooled}]],
        )
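
# Worked example of the strength mapping above (derived from the code, for
# illustration): with 2 replacements, segment_length = 0.5; strength = 0.25
# stays in segment 0 (base_text -> first replacement, local_strength = 0.5),
# while strength = 0.75 falls in segment 1 (first -> second replacement,
# local_strength = 0.5).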


class MTB_SmartStep:
    """Utility to compute the start/end steps of the KSampler Advanced node from percentages."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "step": (
                    "INT",
                    {"default": 20, "min": 1, "max": 10000, "step": 1},
                ),
                "start_percent": (
                    "INT",
                    {"default": 0, "min": 0, "max": 100, "step": 1},
                ),
                "end_percent": (
                    "INT",
                    {"default": 0, "min": 0, "max": 100, "step": 1},
                ),
            }
        }

    RETURN_TYPES = ("INT", "INT", "INT")
    RETURN_NAMES = ("step", "start", "end")
    FUNCTION = "do_step"
    CATEGORY = "mtb/conditioning"
    def do_step(self, step, start_percent, end_percent):
        start = int(step * start_percent / 100)
        end = int(step * end_percent / 100)

        return (step, start, end)
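
# Worked example: do_step(30, 25, 75) returns (30, 7, 22), since
# int(30 * 25 / 100) == 7 and int(30 * 75 / 100) == 22.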


def install_default_styles(force=False):
    styles_dir = Path(folder_paths.base_path) / "styles"
    styles_dir.mkdir(parents=True, exist_ok=True)
    default_style = here / "styles.csv"
    dest_style = styles_dir / "default.csv"

    if force or not dest_style.exists():
        log.debug(f"Copying default style to {dest_style}")
        shutil.copy2(default_style.as_posix(), dest_style.as_posix())

    return dest_style


class MTB_StylesLoader:
    """Load csv files and populate a dropdown from the rows (à la A1111)"""

    options = {}

    @classmethod
    def INPUT_TYPES(cls):
        if not cls.options:
            input_dir = Path(folder_paths.base_path) / "styles"
            if not input_dir.exists():
                install_default_styles()
            if not (
                files := [f for f in input_dir.iterdir() if f.suffix == ".csv"]
            ):
                log.warn(
                    "No styles found in the styles folder; place at least one csv file in the styles folder at the root of ComfyUI (for instance ComfyUI/styles/mystyle.csv)"
                )

            for file in files:
                with open(file, encoding="utf8") as f:
                    parsed = csv.reader(f)
                    for i, row in enumerate(parsed):
                        # log.debug(f"Adding style {row[0]}")
                        try:
                            name, positive, negative = (row + [None] * 3)[:3]
                            positive = positive or ""
                            negative = negative or ""
                            if name is not None:
                                cls.options[name] = (positive, negative)
                            else:
                                # Handle the case where 'name' is None
                                log.warning(f"Missing 'name' in row {i}.")
                        except Exception as e:
                            log.warning(
                                f"There was an error while parsing {file}, make sure it respects the A1111 format, i.e. 3 columns: name, positive, negative:\n{e}"
                            )
                            continue
        else:
            log.debug(f"Using cached styles (count: {len(cls.options)})")

        return {
            "required": {
                "style_name": (list(cls.options.keys()),),
            }
        }

    CATEGORY = "mtb/conditioning"
    RETURN_TYPES = ("STRING", "STRING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "load_style"

    def load_style(self, style_name):
        return (self.options[style_name][0], self.options[style_name][1])
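
# For reference, each row of a styles CSV loaded above is read as
# "name, positive, negative" (extra columns are ignored); an illustrative row:
#   my_style,"masterpiece, detailed","lowres, watermark"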

__nodes__ = [
    MTB_SmartStep,
    MTB_StylesLoader,
    MTB_InterpolateClipSequential,
    MTB_InterpolateCondition,
]