import os
import random
import inspect
import pathlib
import re
import time
import glob
from random import choices

import yaml
import modules.scripts as scripts
import modules.images as images
import gradio as gr

from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
from modules import scripts, script_callbacks, shared
from modules.styles import StyleDatabase
import modules.textual_inversion.textual_inversion
from modules.sd_samplers import samplers, samplers_for_img2img

# Pseudo wildcard name: asking TagLoader for it merges every YAML file found
# below the wildcard directory instead of reading a single file.
ALL_KEY = 'all yaml files'


def get_index(items, item):
    """Return the position of *item* in *items*, or None when unavailable.

    The broad ``except`` is deliberate: *items* comes from
    ``TagLoader.load_tags`` and may be a dict (YAML wildcard), which has no
    ``index`` method at all.
    """
    try:
        return items.index(item)
    except Exception:
        return None


def parse_tag(tag):
    """Strip the wildcard delimiters (``__``, ``<``, ``>``) and whitespace."""
    return tag.replace("__", "").replace('<', '').replace('>', '').strip()


def read_file_lines(file):
    """Return the stripped lines of an open text file.

    Blank lines and lines starting with ``#`` are dropped.
    """
    stripped = (line.strip() for line in file.read().splitlines())
    return [line for line in stripped if line and not line.startswith('#')]


class TagLoader:
    """Load wildcard files (``.txt`` and ``.yaml``) from the wildcard folder.

    All state lives on the class, so every instance shares the same cache;
    ``Script.process`` relies on this when it clears ``TagLoader.files``.
    """

    # Names of the files read so far (reported as "File includes").
    files = []
    # "wildcards" directory next to the directory that contains this script.
    wildcard_location = os.path.join(
        pathlib.Path(inspect.getfile(lambda: None)).parent.parent, "wildcards")
    # Cache: lower-cased wildcard name (or ALL_KEY) -> loaded content.
    loaded_tags = {}
    # Wildcard names for which neither a .txt nor a .yaml file exists.
    missing_tags = set()

    @staticmethod
    def _read_yaml_entries(handle):
        """Parse an open YAML file into ``{entry: set of lower-cased tags}``.

        Each top-level entry is expected to carry a ``Tags`` list. An empty
        document yields an empty dict instead of failing on ``None``.

        Raises:
            yaml.YAMLError: if the document cannot be parsed.
        """
        data = yaml.safe_load(handle) or {}
        return {
            item: {x.lower().strip() for x in data[item]['Tags']}
            for item in data
        }

    def load_tags(self, file_path, verbose=False, cache_files=True):
        """Load the wildcard named *file_path* and return its content.

        Args:
            file_path: wildcard name without extension, or ``ALL_KEY`` to
                merge every YAML file below the wildcard directory.
            verbose: when true, report duplicate YAML entries while merging.
            cache_files: when true, reuse a previously loaded non-empty result.

        Returns:
            A list of lines for a ``.txt`` wildcard, a dict mapping entry to a
            set of tags for YAML wildcards, or ``[]`` when nothing was loaded.
        """
        load_all = file_path == ALL_KEY
        key = ALL_KEY if load_all else file_path.lower()

        # Results are stored under `key`, so the cache must be probed with the
        # same key; probing with the raw name missed every mixed-case wildcard.
        if cache_files and self.loaded_tags.get(key):
            return self.loaded_tags.get(key)

        txt_file_path = os.path.join(self.wildcard_location,
                                     f'{file_path}.txt')
        yaml_file_path = os.path.join(self.wildcard_location,
                                      f'{file_path}.yaml')
        has_txt = os.path.isfile(txt_file_path)
        has_yaml = os.path.isfile(yaml_file_path)

        if self.wildcard_location and has_txt:
            with open(txt_file_path, encoding="utf8") as handle:
                self.files.append(f"{file_path}.txt")
                self.loaded_tags[key] = read_file_lines(handle)

        if load_all and self.wildcard_location:
            pattern = os.path.join(self.wildcard_location, '**/*.yaml')
            merged = {}
            for yaml_path in glob.glob(pattern, recursive=True):
                with open(yaml_path, encoding="utf8") as handle:
                    # NOTE(review): this records "<ALL_KEY>.yaml" once per
                    # file rather than the file actually read; kept as-is so
                    # the "File includes" output does not change.
                    self.files.append(f"{file_path}.yaml")
                    try:
                        entries = self._read_yaml_entries(handle)
                    except yaml.YAMLError as exc:
                        print(exc)
                        continue
                for item, tag_set in entries.items():
                    if item in merged and verbose:
                        print(f"Duplicate key {item} in {yaml_path}")
                    merged[item] = tag_set
            self.loaded_tags[key] = merged

        if self.wildcard_location and has_yaml:
            with open(yaml_file_path, encoding="utf8") as handle:
                self.files.append(f"{file_path}.yaml")
                try:
                    self.loaded_tags[key] = self._read_yaml_entries(handle)
                except yaml.YAMLError as exc:
                    print(exc)

        if not has_yaml and not has_txt:
            self.missing_tags.add(file_path)

        return self.loaded_tags.get(key) or []


class TagSelector:
    """Pick one entry from a wildcard, optionally filtered by tag groups."""

    def __init__(self, tag_loader, options):
        """Store the loader and read the relevant settings from *options*.

        *options* may provide ``selected_options`` (wildcard name -> fixed
        index), ``verbose`` and ``cache_files``.
        """
        settings = dict(options)
        self.tag_loader = tag_loader
        # How many times each raw tag was expanded (recursion guard).
        self.previously_selected_tags = {}
        self.selected_options = settings.get('selected_options', {})
        self.verbose = settings.get('verbose', False)
        self.cache_files = settings.get('cache_files', True)

    def get_tag_choice(self, parsed_tag, tags):
        """Return the preset entry for *parsed_tag*, else a random one.

        Returns ``""`` when *tags* is empty.
        """
        # A YAML wildcard arrives as a dict; neither choices() nor integer
        # indexing works on it, so operate on its entries as a list.
        entries = list(tags)
        preset = self.selected_options.get(parsed_tag.lower())
        if preset is not None:
            return entries[preset]
        return choices(entries)[0] if entries else ""

    def get_tag_group_choice(self, parsed_tag, groups, tags):
        """Pick a random entry of *tags* whose tag set satisfies *groups*.

        Each group is one of:
          * ``--name``: the entry must not carry ``name``;
          * ``a|b``: the entry must carry at least one of the alternatives;
          * ``name``: the entry must carry ``name``.

        Returns ``""`` when no entry qualifies.
        """
        excluded = {
            x.strip().lower().replace('--', '')
            for x in groups if x.startswith('--')
        }
        alternatives = [
            {y.strip() for y in x.lower().split('|')}
            for x in groups if '|' in x
        ]
        required = {
            x.strip().lower()
            for x in groups if not x.startswith('--') and '|' not in x
        }

        candidates = [
            tag for tag in tags
            if required <= tags[tag]
            and excluded.isdisjoint(tags[tag])
            and all(group & tags[tag] for group in alternatives)
        ]

        if candidates:
            if self.verbose:
                print(
                    f'UmiAI: Found {len(candidates)} candidates for "{parsed_tag}" with tags: {groups}, first 10: {candidates[:10]}'
                )
            return choices(candidates)[0]
        print(
            f'UmiAI: No tag candidates found for: "{parsed_tag}" with tags: {groups}'
        )
        return ""

    def select(self, tag, groups=None):
        """Resolve the wildcard *tag* to one entry.

        Returns the selected text, or ``False`` when the tag is skipped
        (it contains exactly two ``:``), cannot be resolved, or has already
        been expanded 100 times.
        """
        self.previously_selected_tags.setdefault(tag, 0)
        if tag.count(':') == 2:
            return False

        # Bound up front: the limit message below needs it as well, and used
        # to fail with UnboundLocalError once the limit was reached.
        parsed_tag = parse_tag(tag)
        if self.previously_selected_tags.get(tag) >= 100:
            print(f'loaded tag more than 100 times {parsed_tag}')
            return False

        self.previously_selected_tags[tag] += 1
        tags = self.tag_loader.load_tags(parsed_tag, self.verbose,
                                         self.cache_files)
        if groups:
            return self.get_tag_group_choice(parsed_tag, groups, tags)
        if len(tags) > 0:
            return self.get_tag_choice(parsed_tag, tags)
        print(
            f'UmiAI: No tags found in wildcard file "{parsed_tag}" or file does not exist'
        )
        return False


class TagReplacer:
    """Expand ``__name__`` / ``<name>`` wildcards inside a prompt."""

    def __init__(self, tag_selector, options):
        self.tag_selector = tag_selector
        self.options = options
        self.wildcard_regex = re.compile('((__|<)(.*?)(__|>))')
        # Text between square brackets; raw string so `\[` is not treated as
        # an (invalid) string escape.
        self.opts_regexp = re.compile(r'(?<=\[)(.*?)(?=\])')

    def replace_wildcard(self, matches):
        """``re.sub`` callback: return the replacement for one wildcard.

        ``name:[a][b]`` selects from ``name`` filtered by the bracketed
        groups, a bare ``[a][b]`` selects across all YAML files, and anything
        else is a plain lookup. An unresolved wildcard is returned unchanged.
        """
        if matches is None or len(matches.groups()) == 0:
            return ""

        body = matches.groups()[2]
        name_and_opts = body.split(':')
        if len(name_and_opts) == 2:
            selected = self.tag_selector.select(
                name_and_opts[0], self.opts_regexp.findall(name_and_opts[1]))
        else:
            global_opts = self.opts_regexp.findall(body)
            if global_opts:
                selected = self.tag_selector.select(ALL_KEY, global_opts)
            else:
                selected = self.tag_selector.select(body)

        return selected if selected else matches[0]

    def replace_wildcard_recursive(self, prompt):
        """Substitute wildcards repeatedly until the prompt stops changing."""
        expanded = self.wildcard_regex.sub(self.replace_wildcard, prompt)
        while expanded != prompt:
            prompt = expanded
            expanded = self.wildcard_regex.sub(self.replace_wildcard, prompt)
        return expanded

    def replace(self, prompt):
        """Replacer entry point used by ``PromptGenerator``."""
        return self.replace_wildcard_recursive(prompt)


class DynamicPromptReplacer:
    """Handle ``{1$$this | that}`` notation: pick N of the listed variants.

    A variant may be prefixed with ``<int>%`` to give it a weight; variants
    without a weight share what is left of 100.
    """

    def __init__(self):
        self.re_combinations = re.compile(r"\{([^{}]*)}")

    def get_variant_weight(self, variant):
        """Return the integer before ``%`` in *variant*, or 0 if absent/bad."""
        parts = variant.split("%")
        if len(parts) == 2:
            num = parts[0]
            try:
                return int(num)
            except ValueError:
                print(f'{num} is not a number')
        return 0

    def get_variant(self, variant):
        """Return *variant* without its ``<weight>%`` prefix, if it has one."""
        parts = variant.split("%")
        return parts[1] if len(parts) == 2 else variant

    def parse_range(self, range_str, num_variants):
        """Parse ``"n"`` or ``"low-high"`` into an ordered ``(low, high)``.

        Both bounds are capped at *num_variants*. The lower bound used to be
        left uncapped, so a range such as ``5-2`` with three variants could
        ask for more picks than there are variants.

        Returns None when *range_str* is None.

        Raises:
            Exception: if *range_str* contains more than one ``-``.
            ValueError: if a bound is not an integer.
        """
        if range_str is None:
            return None

        parts = range_str.split("-")
        if len(parts) == 1:
            low = high = min(int(parts[0]), num_variants)
        elif len(parts) == 2:
            low = min(int(parts[0]), num_variants) if parts[0] else 0
            high = (min(int(parts[1]), num_variants)
                    if parts[1] else num_variants)
        else:
            raise Exception(f"Unexpected range {range_str}")

        return min(low, high), max(low, high)

    def replace_combinations(self, match):
        """``re.sub`` callback: expand one ``{...}`` group.

        Returns the picked variants joined by ``" , "``, or ``""`` when the
        weighted draw fails.
        """
        if match is None or len(match.groups()) == 0:
            return ""

        raw_variants = [s.strip() for s in match.groups()[0].split("|")]
        weights = [self.get_variant_weight(var) for var in raw_variants]
        variants = [self.get_variant(var) for var in raw_variants]

        # An optional "<quantity>$$" prefix sits on the first variant.
        splits = variants[0].split("$$")
        quantity = splits.pop(0) if len(splits) > 1 else str(1)
        variants[0] = splits[0]

        low_range, high_range = self.parse_range(quantity, len(variants))
        quantity = random.randint(low_range, high_range)

        # Unweighted variants share whatever is left of 100 equally.
        summed = sum(weights)
        zero_weights = weights.count(0)
        if zero_weights:
            fallback = (100 - summed) / zero_weights
            weights = [fallback if w == 0 else w for w in weights]

        try:
            picked = []
            for _ in range(quantity):
                choice = random.choices(variants, weights)[0]
                picked.append(choice)
                # Draw without replacement.
                index = variants.index(choice)
                variants.pop(index)
                weights.pop(index)
            return " , ".join(picked)
        except ValueError:
            return ""

    def replace(self, template):
        """Expand every ``{...}`` group in *template* (None passes through)."""
        if template is None:
            return None
        return self.re_combinations.sub(self.replace_combinations, template)


class OptionGenerator:
    """Expose the wildcards listed in the ``configuration`` wildcard."""

    def __init__(self, tag_loader):
        self.tag_loader = tag_loader

    def get_configurable_options(self):
        """Return the content of the ``configuration`` wildcard."""
        return self.tag_loader.load_tags('configuration')

    def get_option_choices(self, tag):
        """Return the entries of the wildcard referenced by *tag*."""
        return self.tag_loader.load_tags(parse_tag(tag))

    def parse_options(self, options):
        """Map each configurable wildcard to the index of its chosen value.

        *options* holds one value per configurable wildcard, in order. Values
        that are not entries of their wildcard (for example ``"RANDOM"``) are
        left out. ``zip`` is used so a shorter *options* no longer raises
        ``IndexError``.
        """
        tag_presets = {}
        for tag, chosen in zip(self.get_configurable_options(), options):
            parsed_tag = parse_tag(tag)
            location = get_index(self.tag_loader.load_tags(parsed_tag), chosen)
            if location is not None:
                tag_presets[parsed_tag.lower()] = location
        return tag_presets


class PromptGenerator:
    """Run every replacer over a prompt until it no longer changes."""

    def __init__(self, options):
        self.tag_loader = TagLoader()
        self.tag_selector = TagSelector(self.tag_loader, options)
        self.negative_tag_generator = NegativePromptGenerator()
        self.settings_generator = SettingsGenerator()
        # Applied in this order on every pass.
        self.replacers = [
            self.settings_generator,
            TagReplacer(self.tag_selector, options),
            DynamicPromptReplacer(),
            self.negative_tag_generator,
        ]
        self.verbose = dict(options).get('verbose', False)

    def use_replacers(self, prompt):
        """Apply each replacer once, in order."""
        for replacer in self.replacers:
            prompt = replacer.replace(prompt)
        return prompt

    def generate_single_prompt(self, original_prompt):
        """Expand *original_prompt* to a fixed point and return the result."""
        start = time.time()
        previous_prompt = original_prompt
        prompt = self.use_replacers(original_prompt)
        while previous_prompt != prompt:
            previous_prompt = prompt
            prompt = self.use_replacers(prompt)
        end = time.time()
        if self.verbose:
            print(f"Prompt generated in {end - start} seconds")
        return prompt

    def get_negative_tags(self):
        """Return the negative keywords collected so far as one string."""
        return self.negative_tag_generator.get_negative_tags()

    def get_setting_overrides(self):
        """Return the ``@@...@@`` setting overrides collected so far."""
        return self.settings_generator.get_setting_overrides()


class NegativePromptGenerator:
    """Collect ``**keyword**`` spans and remove them from the prompt."""

    def __init__(self):
        self.negative_tag = set()

    def strip_negative_tags(self, tags):
        """Remove every ``**...**`` span from *tags*, remembering its text."""
        for match in re.findall(r'\*\*.*?\*\*', tags):
            self.negative_tag.add(match.replace("**", ""))
            tags = tags.replace(match, "")
        return tags

    def replace(self, prompt):
        """Replacer entry point used by ``PromptGenerator``."""
        return self.strip_negative_tags(prompt)

    def get_negative_tags(self):
        """Return the collected keywords joined by single spaces."""
        return " ".join(self.negative_tag)


class SettingsGenerator:
    """Handle ``@@key=value|key=value@@`` notation."""

    def __init__(self):
        self.re_setting_tags = re.compile(r"@@(.*?)@@")
        self.setting_overrides = {}
        # Supported settings and the type their value is converted to.
        self.type_mapping = {
            'cfg_scale': float,
            'sampler': str,
            'steps': int,
        }

    def strip_setting_tags(self, prompt):
        """Record the overrides found in *prompt* and return it without them.

        A key may be abbreviated: the first supported setting whose name
        starts with it wins. Malformed assignments are reported and skipped.
        """
        for match in self.re_setting_tags.findall(prompt):
            for assignment in match.split("|"):
                # partition() tolerates a missing or repeated "=";
                # split("=") unpacking raised ValueError before the
                # "should assign a value" check below could run.
                key_raw, _, value = assignment.partition("=")
                if not value:
                    print(
                        f"Invalid setting {assignment}, settings should assign a value"
                    )
                    continue

                key = next(
                    (k for k in self.type_mapping if k.startswith(key_raw)),
                    None)
                if key is None:
                    print(
                        f"Unknown setting {key_raw}, setting should be the starting part of: {', '.join(self.type_mapping.keys())}"
                    )
                    continue

                try:
                    self.setting_overrides[key] = self.type_mapping[key](value)
                except ValueError:
                    # e.g. "steps=many": report it instead of aborting the
                    # whole prompt generation.
                    print(f"Invalid value {value} for setting {key}")
            prompt = prompt.replace('@@' + match + '@@', "")
        return prompt

    def replace(self, prompt):
        """Replacer entry point used by ``PromptGenerator``."""
        return self.strip_setting_tags(prompt)

    def get_setting_overrides(self):
        """Return the dict of overrides recorded so far."""
        return self.setting_overrides


class Script(scripts.Script):
    """Web UI script that rewrites prompts with the generators above."""

    is_txt2img = False

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        # Register and load the "embeddings" folder that sits next to this
        # script's parent directory.
        embedding_dir = os.path.join(
            pathlib.Path(inspect.getfile(lambda: None)).parent.parent,
            "embeddings")
        self.embedding_db.add_embedding_dir(embedding_dir)
        self.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    def title(self):
        return "Prompt generator"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the controls; their values are passed to ``process``."""
        self.is_txt2img = is_img2img == False

        with gr.Group():
            with gr.Row():
                enabled = gr.Checkbox(label="UmiAI enabled", value=True)
                verbose = gr.Checkbox(label="Verbose logging", value=False)
                cache_files = gr.Checkbox(label="Cache files", value=True)
                same_seed = gr.Checkbox(label='Same prompt in batch',
                                        value=False)
                negative_prompt = gr.Checkbox(label='**negative keywords**',
                                              value=True)
                shared_seed = gr.Checkbox(label="Static wildcards",
                                          value=False)

            option_generator = OptionGenerator(TagLoader())
            options = [
                gr.Dropdown(label=opt,
                            choices=["RANDOM"] +
                            option_generator.get_option_choices(opt),
                            value="RANDOM")
                for opt in option_generator.get_configurable_options()
            ]

        return [
            enabled, verbose, cache_files, same_seed, negative_prompt,
            shared_seed
        ] + options

    @staticmethod
    def _find_sampler_index(sampler_list, value):
        """Return the index of the sampler matching *value*, or None.

        An element matches when its item 0 equals *value* or its item 2
        contains it (presumably name and aliases; confirm against
        ``modules.sd_samplers``).
        """
        for index, elem in enumerate(sampler_list):
            if elem[0] == value or value in elem[2]:
                return index
        return None

    def process(self, p, enabled, verbose, cache_files, same_seed,
                negative_prompt, shared_seed, *args):
        """Expand the prompt of every image in *p* and apply overrides.

        Args:
            p: processing object; ``all_prompts``, ``all_seeds`` and, when
                present, ``all_negative_prompts`` are read and updated.
            enabled: do nothing when false.
            verbose: enable logging and the "File includes" parameter.
            cache_files: reuse already loaded wildcard files.
            same_seed: use one generated prompt for a whole batch.
            negative_prompt: append ``**keywords**`` to the negative prompt.
            shared_seed: seed the RNG from the image seed so the same seed
                picks the same wildcards.
            *args: values of the configurable-option dropdowns.
        """
        if not enabled:
            return

        debug = False

        if debug:
            print(f'\nModel: {p.sampler_name}, Seed: {int(p.seed)}, Batch Count: {p.n_iter}, Batch Size: {p.batch_size}, CFG: {p.cfg_scale}, Steps: {p.steps}\nOriginal Prompt: "{p.prompt}"\nOriginal Negatives: "{p.negative_prompt}"\n')

        original_prompt = p.all_prompts[0]
        # hasattr to fix crash on old webui versions
        has_negative = hasattr(p, "all_negative_prompts")
        original_negative = p.all_negative_prompts[0] if has_negative else ""
        apply_negative = negative_prompt and has_negative

        TagLoader.files.clear()

        option_generator = OptionGenerator(TagLoader())
        options = {
            'selected_options': option_generator.parse_options(args),
            'verbose': verbose,
            'cache_files': cache_files,
        }
        prompt_generator = PromptGenerator(options)

        for cur_count in range(p.n_iter):  # Batch count
            for cur_batch in range(p.batch_size):  # Batch size
                index = p.batch_size * cur_count + cur_batch

                # pick same wildcard for a given seed
                if shared_seed:
                    seed_index = (p.batch_size * cur_count
                                  if same_seed else index)
                    random.seed(p.all_seeds[seed_index])
                else:
                    random.seed(time.time())

                if debug:
                    print(f'{"Batch #"+str(cur_count) if same_seed else "Prompt #"+str(index):=^30}')

                prompt_generator.negative_tag_generator.negative_tag = set()
                prompt = prompt_generator.generate_single_prompt(
                    original_prompt)
                p.all_prompts[index] = prompt
                if debug:
                    print(f'Prompt: "{prompt}"')

                negative = original_negative
                if apply_negative:
                    negative_tags = prompt_generator.get_negative_tags()
                    # Keep the user's last negative keyword and the first
                    # extracted one from fusing into a single word.
                    if (negative and negative_tags
                            and not negative[-1].isspace()):
                        negative += " "
                    negative += negative_tags
                    p.all_negative_prompts[index] = negative
                if debug:
                    print(f'Negative: "{negative}\n"')

                # same prompt per batch
                if same_seed:
                    for batch_index in range(index, index + p.batch_size):
                        p.all_prompts[batch_index] = prompt
                        # The keywords were stripped from the shared prompt,
                        # so the shared negative must follow it too.
                        if apply_negative:
                            p.all_negative_prompts[batch_index] = negative
                    break

        att_override = prompt_generator.get_setting_overrides()
        for att, value in att_override.items():
            if att.startswith("__"):
                continue
            if att != 'sampler':
                setattr(p, att, value)
                continue

            # The sampler is applied by index, not by name.
            sampler_list = (samplers
                            if self.is_txt2img else samplers_for_img2img)
            sampler_index = self._find_sampler_index(sampler_list, value)
            if sampler_index is not None:
                setattr(p, 'sampler_index', sampler_index)
            else:
                print(
                    f"Sampler {value} not found in prompt {p.all_prompts[0]}"
                )

        if original_prompt != p.all_prompts[0]:
            p.extra_generation_params["Wildcard prompt"] = original_prompt
            if verbose:
                p.extra_generation_params["File includes"] = "|".join(
                    TagLoader.files)


# Runs at import time: also register this extension's embeddings folder with
# the web UI's own embedding database.
from modules import sd_hijack
path = os.path.join(scripts.basedir(), "embeddings")
sd_hijack.model_hijack.embedding_db.add_embedding_dir(path)