|
import csv |
|
import shutil |
|
from pathlib import Path |
|
|
|
import folder_paths |
|
import torch |
|
|
|
from ..log import log |
|
from ..utils import here |
|
|
|
Conditioning = list[tuple[torch.Tensor, dict[str, torch.Tensor]]] |
|
|
|
|
|
def check_condition(conditioning: Conditioning): |
|
has_cn = False |
|
if len(conditioning) > 1: |
|
log.warn( |
|
"More than one conditioning was provided. Only the first one will be used." |
|
) |
|
first = conditioning[0] |
|
cond, kwargs = first |
|
|
|
log.debug("Conditioning Shape") |
|
log.debug(cond.shape) |
|
log.debug("Conditioning keys") |
|
log.debug([f"\t{k} - {type(kwargs[k])}" for k in kwargs]) |
|
if "control" in kwargs: |
|
log.debug("Conditioning contains a controlnet") |
|
has_cn = True |
|
if "pooled_output" not in kwargs: |
|
raise ValueError( |
|
"Conditioning is not valid. Missing 'pooled_output' key." |
|
) |
|
return has_cn |
|
|
|
|
|
class MTB_InterpolateCondition: |
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"blend": ( |
|
"FLOAT", |
|
{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}, |
|
), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("CONDITIONING",) |
|
CATEGORY = "mtb/conditioning" |
|
FUNCTION = "execute" |
|
|
|
def execute( |
|
self, blend: float, **kwargs: Conditioning |
|
) -> tuple[Conditioning]: |
|
blend = max(0.0, min(1.0, blend)) |
|
|
|
conditions: list[Conditioning] = list(kwargs.values()) |
|
num_conditions = len(conditions) |
|
|
|
if num_conditions < 2: |
|
raise ValueError("At least two conditioning inputs are required.") |
|
|
|
segment_length = 1.0 / (num_conditions - 1) |
|
|
|
segment_index = min(int(blend // segment_length), num_conditions - 2) |
|
|
|
local_blend = ( |
|
blend - (segment_index * segment_length) |
|
) / segment_length |
|
|
|
cond_from = conditions[segment_index] |
|
cond_to = conditions[segment_index + 1] |
|
|
|
from_cn = check_condition(cond_from) |
|
to_cn = check_condition(cond_to) |
|
|
|
if from_cn and to_cn: |
|
raise ValueError( |
|
"Interpolating conditions cannot both contain ControlNets" |
|
) |
|
|
|
try: |
|
interpolated_condition = [ |
|
(1.0 - local_blend) * c_from + local_blend * c_to |
|
for c_from, c_to in zip( |
|
cond_from[0][0], cond_to[0][0], strict=False |
|
) |
|
] |
|
except Exception as e: |
|
print(f"Error during interpolation: {e}") |
|
raise |
|
|
|
pooled_from = cond_from[0][1].get( |
|
"pooled_output", |
|
torch.zeros_like( |
|
next(iter(cond_from[0][1].values()), torch.tensor([])) |
|
), |
|
) |
|
|
|
pooled_to = cond_to[0][1].get( |
|
"pooled_output", |
|
torch.zeros_like( |
|
next(iter(cond_from[0][1].values()), torch.tensor([])) |
|
), |
|
) |
|
|
|
interpolated_pooled = ( |
|
1.0 - local_blend |
|
) * pooled_from + local_blend * pooled_to |
|
|
|
res = {"pooled_output": interpolated_pooled} |
|
|
|
if from_cn: |
|
res["control"] = cond_from[0][1]["control"] |
|
res["control_apply_to_uncond"] = cond_from[0][1][ |
|
"control_apply_to_uncond" |
|
] |
|
if to_cn: |
|
res["control"] = cond_to[0][1]["control"] |
|
res["control_apply_to_uncond"] = cond_to[0][1][ |
|
"control_apply_to_uncond" |
|
] |
|
|
|
return ([(torch.stack(interpolated_condition), res)],) |
|
|
|
|
|
class MTB_InterpolateClipSequential: |
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"base_text": ("STRING", {"multiline": True}), |
|
"text_to_replace": ("STRING", {"default": ""}), |
|
"clip": ("CLIP",), |
|
"interpolation_strength": ( |
|
"FLOAT", |
|
{"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, |
|
), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("CONDITIONING",) |
|
FUNCTION = "interpolate_encodings_sequential" |
|
|
|
CATEGORY = "mtb/conditioning" |
|
|
|
def interpolate_encodings_sequential( |
|
self, |
|
base_text, |
|
text_to_replace, |
|
clip, |
|
interpolation_strength, |
|
**replacements, |
|
): |
|
log.debug(f"Received interpolation_strength: {interpolation_strength}") |
|
|
|
|
|
interpolation_strength = max(0.0, min(1.0, interpolation_strength)) |
|
|
|
|
|
if not replacements: |
|
raise ValueError("At least one replacement should be provided.") |
|
|
|
num_replacements = len(replacements) |
|
log.debug(f"Number of replacements: {num_replacements}") |
|
|
|
segment_length = 1.0 / num_replacements |
|
log.debug(f"Calculated segment_length: {segment_length}") |
|
|
|
|
|
segment_index = min( |
|
int(interpolation_strength // segment_length), num_replacements - 1 |
|
) |
|
log.debug(f"Segment index: {segment_index}") |
|
|
|
|
|
local_strength = ( |
|
interpolation_strength - (segment_index * segment_length) |
|
) / segment_length |
|
log.debug(f"Local strength: {local_strength}") |
|
|
|
|
|
if segment_index == 0: |
|
replacement_text = list(replacements.values())[0] |
|
log.debug("Using the base text a the base blend") |
|
|
|
tokens = clip.tokenize(base_text) |
|
cond_from, pooled_from = clip.encode_from_tokens( |
|
tokens, return_pooled=True |
|
) |
|
else: |
|
base_replace = list(replacements.values())[segment_index - 1] |
|
log.debug(f"Using {base_replace} a the base blend") |
|
|
|
|
|
tokens = clip.tokenize( |
|
base_text.replace(text_to_replace, base_replace) |
|
) |
|
cond_from, pooled_from = clip.encode_from_tokens( |
|
tokens, return_pooled=True |
|
) |
|
|
|
replacement_text = list(replacements.values())[segment_index] |
|
|
|
interpolated_text = base_text.replace( |
|
text_to_replace, replacement_text |
|
) |
|
tokens = clip.tokenize(interpolated_text) |
|
cond_to, pooled_to = clip.encode_from_tokens( |
|
tokens, return_pooled=True |
|
) |
|
|
|
|
|
interpolated_condition = ( |
|
1.0 - local_strength |
|
) * cond_from + local_strength * cond_to |
|
interpolated_pooled = ( |
|
1.0 - local_strength |
|
) * pooled_from + local_strength * pooled_to |
|
|
|
return ( |
|
[[interpolated_condition, {"pooled_output": interpolated_pooled}]], |
|
) |
|
|
|
|
|
class MTB_SmartStep: |
|
"""Utils to control the steps start/stop of the KAdvancedSampler in percentage""" |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
return { |
|
"required": { |
|
"step": ( |
|
"INT", |
|
{"default": 20, "min": 1, "max": 10000, "step": 1}, |
|
), |
|
"start_percent": ( |
|
"INT", |
|
{"default": 0, "min": 0, "max": 100, "step": 1}, |
|
), |
|
"end_percent": ( |
|
"INT", |
|
{"default": 0, "min": 0, "max": 100, "step": 1}, |
|
), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("INT", "INT", "INT") |
|
RETURN_NAMES = ("step", "start", "end") |
|
FUNCTION = "do_step" |
|
CATEGORY = "mtb/conditioning" |
|
|
|
def do_step(self, step, start_percent, end_percent): |
|
start = int(step * start_percent / 100) |
|
end = int(step * end_percent / 100) |
|
|
|
return (step, start, end) |
|
|
|
|
|
def install_default_styles(force=False): |
|
styles_dir = Path(folder_paths.base_path) / "styles" |
|
styles_dir.mkdir(parents=True, exist_ok=True) |
|
default_style = here / "styles.csv" |
|
dest_style = styles_dir / "default.csv" |
|
|
|
if force or not dest_style.exists(): |
|
log.debug(f"Copying default style to {dest_style}") |
|
shutil.copy2(default_style.as_posix(), dest_style.as_posix()) |
|
|
|
return dest_style |
|
|
|
|
|
class MTB_StylesLoader: |
|
"""Load csv files and populate a dropdown from the rows (à la A111)""" |
|
|
|
options = {} |
|
|
|
@classmethod |
|
def INPUT_TYPES(cls): |
|
if not cls.options: |
|
input_dir = Path(folder_paths.base_path) / "styles" |
|
if not input_dir.exists(): |
|
install_default_styles() |
|
|
|
if not ( |
|
files := [f for f in input_dir.iterdir() if f.suffix == ".csv"] |
|
): |
|
log.warn( |
|
"No styles found in the styles folder, place at least one csv file in the styles folder at the root of ComfyUI (for instance ComfyUI/styles/mystyle.csv)" |
|
) |
|
|
|
for file in files: |
|
with open(file, encoding="utf8") as f: |
|
parsed = csv.reader(f) |
|
for i, row in enumerate(parsed): |
|
|
|
try: |
|
name, positive, negative = (row + [None] * 3)[:3] |
|
positive = positive or "" |
|
negative = negative or "" |
|
if name is not None: |
|
cls.options[name] = (positive, negative) |
|
else: |
|
|
|
log.warning(f"Missing 'name' in row {i}.") |
|
|
|
except Exception as e: |
|
log.warning( |
|
f"There was an error while parsing {file}, make sure it respects A1111 format, i.e 3 columns name, positive, negative:\n{e}" |
|
) |
|
continue |
|
|
|
else: |
|
log.debug(f"Using cached styles (count: {len(cls.options)})") |
|
|
|
return { |
|
"required": { |
|
"style_name": (list(cls.options.keys()),), |
|
} |
|
} |
|
|
|
CATEGORY = "mtb/conditioning" |
|
|
|
RETURN_TYPES = ("STRING", "STRING") |
|
RETURN_NAMES = ("positive", "negative") |
|
FUNCTION = "load_style" |
|
|
|
def load_style(self, style_name): |
|
return (self.options[style_name][0], self.options[style_name][1]) |
|
|
|
|
|
__nodes__ = [ |
|
MTB_SmartStep, |
|
MTB_StylesLoader, |
|
MTB_InterpolateClipSequential, |
|
MTB_InterpolateCondition, |
|
] |
|
|