# This code is released without a license.
import logging
from collections import OrderedDict

try:
    import yaml
except ImportError:
    # PyYAML is only needed by TagSwap.apply; keep the module importable so
    # that PromptMerge can still be registered when PyYAML is missing.
    yaml = None

logger = logging.getLogger(__name__)

_MODE_ERROR = "Must use either 'replace' or 'select'"


def _split_tags(tags):
    """Split a comma-delimited prompt into stripped, non-empty tags."""
    pieces = (piece.strip() for piece in (tags or "").split(","))
    return [piece for piece in pieces if piece]


def _join_tags(parts):
    """Join tags with ', ', skipping ``None`` and empty replacements."""
    return ", ".join(
        str(part) for part in parts if part is not None and part != ""
    )


def _normalize_rules(needles):
    """Return rules as a ``{tag: actions}`` dict.

    Accepts either a mapping of tag -> actions or a list of such mappings
    (the shape produced by the documented ``- tag_name:`` YAML form). A bare
    string list item, or a tag whose actions are ``None``, is treated as a
    tag with no actions.

    Raises:
        ValueError: if the rules or a tag's actions have an unsupported type.
    """
    if needles is None:
        return {}
    if isinstance(needles, dict):
        entries = [needles]
    elif isinstance(needles, (list, tuple)):
        entries = needles
    else:
        raise ValueError(
            "Rules must be a mapping or a list of mappings, got %s"
            % type(needles).__name__
        )

    normalized = {}
    for entry in entries:
        if isinstance(entry, str):
            entry = {entry: None}
        if not isinstance(entry, dict):
            raise ValueError(
                "Each rule must be a mapping, got %s" % type(entry).__name__
            )
        for tag, actions in entry.items():
            if actions is None:
                actions = {}
            if not isinstance(actions, dict):
                raise ValueError(
                    "Actions for tag %r must be a mapping with 'p' and/or "
                    "'a' keys" % (tag,)
                )
            normalized[tag] = actions
    return normalized


class TagSwap:
    """
    Rewrite a comma-delimited prompt according to YAML rules.

    Each rule names a tag and may give two actions: ``p`` (text to emit when
    the tag is present in the prompt) and ``a`` (text to emit when it is
    absent).

    * In ``replace`` mode the prompt is kept in order; a tag that has a rule
      is replaced by its ``p`` text (or dropped when the rule has no ``p``),
      tags without a rule pass through unchanged, and the ``a`` text of every
      rule whose tag is missing is appended at the end.
    * In ``select`` mode only the rules produce output, in rule order: the
      ``p`` text when the tag is present, otherwise the ``a`` text.

    The YAML format is::

        select|replace:
          - tag_name:
              p: present action
              a: absent action

    A plain mapping (without the leading ``-``) is accepted as well.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "tags": ("STRING", {"default": '', "multiline": True, "forceInput": True}),
                "rules": ("STRING", {"default": '', "multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("replacements",)
    FUNCTION = "apply"
    CATEGORY = "utils"

    def replace(self, needles, haystack):
        """Yield the prompt tags with the replace-mode rules applied.

        Args:
            needles: rules, as a mapping or a list of mappings.
            haystack: iterable of prompt tags.
        """
        rules = _normalize_rules(needles)
        tags = list(haystack)
        present = set(tags)  # O(1) membership for the "absent" pass below

        for tag in tags:
            actions = rules.get(tag)
            if actions is None:
                yield tag
            elif "p" in actions:
                yield actions["p"]
            # A rule without 'p' removes the tag from the prompt.

        for needle, actions in rules.items():
            if "a" in actions and needle not in present:
                yield actions["a"]

    def select(self, needles, haystack):
        """Yield one action per rule, chosen by whether its tag is present.

        Args:
            needles: rules, as a mapping or a list of mappings.
            haystack: iterable of prompt tags.
        """
        rules = _normalize_rules(needles)
        present = set(haystack)

        for tag, actions in rules.items():
            key = "p" if tag in present else "a"
            if key in actions:
                yield actions[key]

    def apply(self, tags, rules):
        """Parse ``rules`` as YAML and apply them to the ``tags`` prompt.

        Returns:
            A one-element tuple holding the rewritten, ', '-joined prompt.

        Raises:
            ImportError: if PyYAML is not installed.
            ValueError: if the rules are not a mapping containing a
                'replace' or 'select' key, or are otherwise malformed.
        """
        if yaml is None:
            raise ImportError("PyYAML is required to parse TagSwap rules")

        haystack = _split_tags(tags)
        config = yaml.safe_load(rules)

        # safe_load returns None for empty input and may return a scalar or
        # list; only a mapping can carry a mode key.
        if isinstance(config, dict):
            for mode, handler in (("replace", self.replace), ("select", self.select)):
                if mode in config:
                    return (_join_tags(handler(config[mode], haystack)),)

        raise ValueError(_MODE_ERROR)


class PromptMerge:
    """
    Takes a list of prompts. Merges identical, adjacent prompts, keeping the
    index of the first prompt of each run. Outputs a format compatible with
    BatchPromptSchedule.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompts": ("STRING", {"default": [], "forceInput": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prompt",)
    INPUT_IS_LIST = True
    FUNCTION = "apply"
    CATEGORY = "utils"

    def apply(self, prompts):
        """Collapse runs of identical prompts into '"index": "prompt"' pairs.

        Returns:
            A one-element tuple holding the pairs joined with ', '.
        """
        logger.debug("Called with %s", prompts)

        merged = OrderedDict()
        last = None
        for i, prompt in enumerate(prompts):
            if prompt != last:
                merged[i] = prompt
            last = prompt

        # NOTE(review): double quotes inside a prompt are not escaped here;
        # confirm whether the consumer of this string needs them escaped.
        travel = [
            ' "%d": "%s" ' % (index, prompt)
            for index, prompt in merged.items()
        ]
        logger.debug("%s", travel)
        return (', '.join(travel),)


NODE_CLASS_MAPPINGS = {
    "TagSwap": TagSwap,
    "PromptMerge": PromptMerge,
}

# A dictionary that contains the friendly, human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "TagSwap": "Tag Swap",
    "PromptMerge": "Prompt Merge",
}