# Source: Stable-Diffusion / Umi-AI / scripts/wildcard_recursive.py
# Uploaded by RudiCahyan ("Upload 1119 files"), commit cb56a8d
import os
import random
import inspect
import pathlib
import re
import time
import glob
from random import choices
import yaml
import modules.scripts as scripts
import modules.images as images
import gradio as gr
from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
from modules import scripts, script_callbacks, shared
from modules.styles import StyleDatabase
import modules.textual_inversion.textual_inversion
from modules.sd_samplers import samplers, samplers_for_img2img
# Pseudo wildcard name meaning "every .yaml file under the wildcard folder";
# used for group-only queries such as <[red][--big]> that name no file.
ALL_KEY: str = 'all yaml files'
def get_index(items, item):
    """Return the position of *item* within *items*, or ``None``.

    Any failure of ``items.index`` yields ``None``. This covers a missing
    element, and also *items* values without a list-style ``index`` method
    (e.g. a dict of YAML tags or ``None``).
    """
    try:
        position = items.index(item)
    except Exception:
        position = None
    return position
def parse_tag(tag):
    """Strip wildcard delimiters and surrounding whitespace from *tag*.

    ``"__"``, ``"<"`` and ``">"`` are removed in that order, so ``"_<_"``
    becomes ``"__"``.
    """
    cleaned = tag
    for delimiter in ("__", "<", ">"):
        cleaned = cleaned.replace(delimiter, "")
    return cleaned.strip()
def read_file_lines(file):
    """Return the meaningful lines of the open text *file*.

    Every line is stripped. Blank lines and lines starting with ``#``
    (comments) are dropped.
    """
    stripped_lines = (raw.strip() for raw in file.read().splitlines())
    return [line for line in stripped_lines if line and not line.startswith('#')]
# Wildcards
class TagLoader:
    """Load wildcard entries from ``.txt`` and ``.yaml`` files.

    ``files``, ``loaded_tags`` and ``missing_tags`` are class attributes, so
    every instance shares them. ``loaded_tags`` therefore acts as a
    process-wide cache, and ``files`` is cleared and read directly through
    ``TagLoader.files``.
    """

    # Wildcard files read since the list was last cleared.
    files = []
    # "<extension root>/wildcards": two directory levels above this script file.
    wildcard_location = os.path.join(
        pathlib.Path(inspect.getfile(lambda: None)).parent.parent, "wildcards")
    # Cache keyed by lower-cased wildcard name (or ALL_KEY). Values are either
    # a list of lines (.txt) or a dict {entry name: set of lower-cased tags} (.yaml).
    loaded_tags = {}
    # Wildcard names for which neither a .txt nor a .yaml file exists.
    missing_tags = set()

    def load_tags(self, file_path, verbose=False, cache_files=True):
        """Return the entries of wildcard *file_path*.

        Args:
            file_path: wildcard name relative to ``wildcard_location``, without
                extension, or ``ALL_KEY`` to merge every ``*.yaml`` file.
            verbose: report duplicate YAML keys when merging.
            cache_files: return a previously loaded, non-empty result without
                touching the filesystem.

        Returns:
            A list of lines for ``.txt`` files, or a dict of entry name to tag
            set for YAML. If both files exist, the YAML result wins.
            Returns ``[]`` when nothing was loaded.
        """
        key = ALL_KEY if file_path == ALL_KEY else file_path.lower()
        # Look the cache up under the same normalized key entries are stored
        # under; using the raw name missed every mixed-case wildcard.
        if cache_files and self.loaded_tags.get(key):
            return self.loaded_tags.get(key)

        txt_file_path = os.path.join(self.wildcard_location, f'{file_path}.txt')
        yaml_file_path = os.path.join(self.wildcard_location, f'{file_path}.yaml')

        if self.wildcard_location and os.path.isfile(txt_file_path):
            with open(txt_file_path, encoding="utf8") as file:
                self.files.append(f"{file_path}.txt")
                self.loaded_tags[key] = read_file_lines(file)

        if key == ALL_KEY and self.wildcard_location:
            self.loaded_tags[key] = self._load_all_yaml_files(verbose)

        if self.wildcard_location and os.path.isfile(yaml_file_path):
            self.files.append(f"{file_path}.yaml")
            output = self._load_yaml_file(yaml_file_path, verbose)
            # On a YAML parse error, keep whatever is already cached for this key.
            if output is not None:
                self.loaded_tags[key] = output

        # ALL_KEY is not a real file, so it is never reported as missing.
        if (key != ALL_KEY and not os.path.isfile(yaml_file_path)
                and not os.path.isfile(txt_file_path)):
            self.missing_tags.add(file_path)

        return self.loaded_tags.get(key) or []

    def _load_all_yaml_files(self, verbose=False):
        """Merge every ``*.yaml`` file below ``wildcard_location`` into one dict.

        Files are visited in sorted order, so when two files define the same
        entry, the later one wins deterministically.
        """
        output = {}
        pattern = os.path.join(self.wildcard_location, '**/*.yaml')
        for path in sorted(glob.glob(pattern, recursive=True)):
            self.files.append(os.path.relpath(path, self.wildcard_location))
            self._load_yaml_file(path, verbose, output)
        return output

    def _load_yaml_file(self, path, verbose=False, output=None):
        """Parse one YAML wildcard file into *output* (a new dict by default).

        Each top-level key maps to the set of its lower-cased, stripped
        ``Tags`` values.

        Returns:
            *output*, or ``None`` if the file is not valid YAML (the parser
            error is printed).
        """
        if output is None:
            output = {}
        with open(path, encoding="utf8") as stream:
            try:
                data = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                return None
        if data is None:
            # An empty YAML document contributes no entries.
            return output
        if not isinstance(data, dict):
            print(f"UmiAI: Ignoring {path}: expected a mapping at the top level")
            return output
        for item in data:
            if verbose and item in output:
                print(f"Duplicate key {item} in {path}")
            output[item] = {x.lower().strip() for x in data[item]['Tags']}
        return output
# <yaml:[tag]> notation
class TagSelector:
    """Choose concrete values for wildcard tags such as ``__name__`` or ``<name:[group]>``."""

    # How many times a single tag may be expanded before selection is refused.
    # This guards against wildcard files that (indirectly) include themselves.
    max_selections_per_tag = 100

    def __init__(self, tag_loader, options):
        self.tag_loader = tag_loader
        # tag -> number of times this selector has expanded it.
        self.previously_selected_tags = {}
        options = dict(options)
        # lower-cased parsed tag -> index of the user-preselected entry.
        self.selected_options = options.get('selected_options', {})
        self.verbose = options.get('verbose', False)
        self.cache_files = options.get('cache_files', True)

    def get_tag_choice(self, parsed_tag, tags):
        """Return the preselected entry of *tags*, else a random one ("" if *tags* is empty)."""
        preset = self.selected_options.get(parsed_tag.lower())
        if preset is not None:
            return tags[preset]
        return choices(tags)[0] if len(tags) > 0 else ""

    def get_tag_group_choice(self, parsed_tag, groups, tags):
        """Randomly pick a YAML entry whose tag set satisfies every group.

        Each element of *groups* is the text of one ``[...]``:

        * ``name``   - the entry must carry this tag;
        * ``--name`` - the entry must not carry this tag;
        * ``a|b``    - the entry must carry at least one of the listed tags.

        *tags* maps entry names to tag sets. Anything else (for example the
        line list of a ``.txt`` wildcard) has no tags to filter on and yields
        no candidates.

        Returns:
            The chosen entry name, or "" when no entry matches.
        """
        # Strip before classifying, so " --red" counts as a negative group.
        normalized = [group.strip().lower() for group in groups]
        neg_groups_set = {g.replace('--', '') for g in normalized if g.startswith('--')}
        any_groups = [{y.strip() for y in g.split('|')} for g in normalized if '|' in g]
        pos_groups_set = {g for g in normalized if not g.startswith('--') and '|' not in g}

        if not isinstance(tags, dict):
            tags = {}

        candidates = []
        for name, tag_set in tags.items():
            if not pos_groups_set.issubset(tag_set):
                continue
            if not neg_groups_set.isdisjoint(tag_set):
                continue
            if any(group.isdisjoint(tag_set) for group in any_groups):
                continue
            candidates.append(name)

        if candidates:
            if self.verbose:
                print(
                    f'UmiAI: Found {len(candidates)} candidates for "{parsed_tag}" with tags: {groups}, first 10: {candidates[:10]}'
                )
            return choices(candidates)[0]
        print(
            f'UmiAI: No tag candidates found for: "{parsed_tag}" with tags: {groups}'
        )
        return ""

    def select(self, tag, groups=None):
        """Expand *tag* into a concrete value, or return False if it cannot be expanded.

        False is returned in three cases: for tags containing exactly two
        ``:``; once the tag has been expanded ``max_selections_per_tag``
        times; or when its wildcard file is missing or empty.
        """
        self.previously_selected_tags.setdefault(tag, 0)
        if tag.count(':') == 2:
            return False
        if self.previously_selected_tags[tag] >= self.max_selections_per_tag:
            # Report the raw tag; the parsed name is only computed below.
            print(f'loaded tag more than {self.max_selections_per_tag} times {tag}')
            return False
        self.previously_selected_tags[tag] += 1
        parsed_tag = parse_tag(tag)
        tags = self.tag_loader.load_tags(parsed_tag, self.verbose, self.cache_files)
        if groups:
            return self.get_tag_group_choice(parsed_tag, groups, tags)
        if len(tags) > 0:
            return self.get_tag_choice(parsed_tag, tags)
        print(
            f'UmiAI: No tags found in wildcard file "{parsed_tag}" or file does not exist'
        )
        return False
class TagReplacer:
    """Replace ``__name__``, ``<name>`` and ``<name:[group]...>`` wildcards in a prompt."""

    def __init__(self, tag_selector, options):
        self.tag_selector = tag_selector
        self.options = options
        # Group 3 is the wildcard body between "__"/"<" and "__"/">".
        self.wildcard_regex = re.compile(r'((__|<)(.*?)(__|>))')
        # Captures the contents of every "[...]". The raw string avoids the
        # invalid "\[" / "\]" string escapes (SyntaxWarning on Python 3.12+).
        self.opts_regexp = re.compile(r'(?<=\[)(.*?)(?=\])')

    def replace_wildcard(self, matches):
        """``re.sub`` callback: return the selected value, or the match unchanged."""
        if matches is None or len(matches.groups()) == 0:
            return ""
        match = matches.groups()[2]
        match_and_opts = match.split(':')
        if len(match_and_opts) == 2:
            # "<name:[g1][g2]>": filter the entries of wildcard "name" by groups.
            selected_tags = self.tag_selector.select(
                match_and_opts[0], self.opts_regexp.findall(match_and_opts[1]))
        else:
            global_opts = self.opts_regexp.findall(match)
            if len(global_opts) > 0:
                # "<[g1][g2]>": search the entries of every YAML file.
                selected_tags = self.tag_selector.select(ALL_KEY, global_opts)
            else:
                selected_tags = self.tag_selector.select(match)
        if selected_tags:
            return selected_tags
        return matches[0]

    def replace_wildcard_recursive(self, prompt):
        """Substitute wildcards repeatedly until the prompt stops changing."""
        result = self.wildcard_regex.sub(self.replace_wildcard, prompt)
        while result != prompt:
            prompt = result
            result = self.wildcard_regex.sub(self.replace_wildcard, prompt)
        return result

    def replace(self, prompt):
        return self.replace_wildcard_recursive(prompt)
# handle {1$$this | that} notation
class DynamicPromptReplacer:
    """Expand ``{a|b|c}`` combinations, optionally weighted and with a pick count.

    Syntax inside the braces:

    * ``N%text`` gives a variant weight N. Unweighted variants share what
      remains of 100.
    * A ``N$$`` or ``lo-hi$$`` prefix on the first variant picks that many
      distinct variants, joined by " , ".
    """

    def __init__(self):
        # Innermost "{...}" group (no nested braces inside).
        self.re_combinations = re.compile(r"\{([^{}]*)}")

    def get_variant_weight(self, variant):
        """Return N for a ``"N%text"`` variant, else 0 (share the remainder)."""
        split_variant = variant.split("%")
        if len(split_variant) == 2:
            num = split_variant[0]
            try:
                return int(num)
            except ValueError:
                print(f'{num} is not a number')
        return 0

    def get_variant(self, variant):
        """Return the text of a ``"N%text"`` variant, or *variant* unchanged."""
        split_variant = variant.split("%")
        if len(split_variant) == 2:
            return split_variant[1]
        return variant

    def parse_range(self, range_str, num_variants):
        """Parse ``"n"``, ``"lo-hi"``, ``"lo-"`` or ``"-hi"`` into ``(low, high)``.

        Both bounds are clamped to *num_variants*, so callers never request
        more picks than there are variants.

        Returns:
            ``None`` if *range_str* is ``None``.

        Raises:
            ValueError: a bound is not an integer.
            Exception: the string contains more than one ``-``.
        """
        if range_str is None:
            return None
        parts = range_str.split("-")
        if len(parts) == 1:
            low = high = int(parts[0])
        elif len(parts) == 2:
            low = int(parts[0]) if parts[0] else 0
            high = int(parts[1]) if parts[1] else num_variants
        else:
            raise Exception(f"Unexpected range {range_str}")
        # The lower bound used to be left unclamped, so "{5-6$$a|b|c}" could
        # pop every variant and then crash in random.choices([]).
        low = min(low, num_variants)
        high = min(high, num_variants)
        return min(low, high), max(low, high)

    def replace_combinations(self, match):
        """``re.sub`` callback: pick variants from one ``{...}`` group."""
        if match is None or len(match.groups()) == 0:
            return ""
        combinations_str = match.groups()[0]
        raw_variants = [s.strip() for s in combinations_str.split("|")]
        weights = [self.get_variant_weight(var) for var in raw_variants]
        variants = [self.get_variant(var) for var in raw_variants]
        # An optional "N$$" / "lo-hi$$" prefix on the first variant sets the count.
        splits = variants[0].split("$$")
        quantity = splits.pop(0) if len(splits) > 1 else str(1)
        variants[0] = splits[0]
        low_range, high_range = self.parse_range(quantity, len(variants))
        quantity = random.randint(low_range, high_range)
        # Unweighted (0) variants share the remainder of 100 equally.
        summed = sum(weights)
        zero_weights = weights.count(0)
        weights = [(100 - summed) / zero_weights if w == 0 else w for w in weights]
        try:
            picked = []
            for _ in range(quantity):
                choice = random.choices(variants, weights)[0]
                picked.append(choice)
                # Remove the pick so each variant is chosen at most once.
                index = variants.index(choice)
                variants.pop(index)
                weights.pop(index)
            return " , ".join(picked)
        except ValueError:
            # e.g. the remaining weights sum to zero.
            return ""

    def replace(self, template):
        """Expand every innermost ``{...}`` group in *template*; ``None`` passes through."""
        if template is None:
            return None
        return self.re_combinations.sub(self.replace_combinations, template)
class OptionGenerator:
    """Expose the entries of the ``configuration`` wildcard as user presets."""

    def __init__(self, tag_loader):
        self.tag_loader = tag_loader

    def get_configurable_options(self):
        """Return the entries of the ``configuration`` wildcard file."""
        configuration = self.tag_loader.load_tags('configuration')
        return configuration

    def get_option_choices(self, tag):
        """Return the selectable values of the configurable wildcard *tag*."""
        return self.tag_loader.load_tags(parse_tag(tag))

    def parse_options(self, options):
        """Map each configurable tag (lower-cased) to the index of the chosen value.

        *options* holds one selected value per configurable tag, in the order
        returned by get_configurable_options(). Values not found in the tag's
        list (such as the "RANDOM" dropdown default) are left out.
        """
        presets = {}
        for position, raw_tag in enumerate(self.get_configurable_options()):
            name = parse_tag(raw_tag)
            chosen_index = get_index(self.tag_loader.load_tags(name), options[position])
            if chosen_index is None:
                continue
            presets[name.lower()] = chosen_index
        return presets
class PromptGenerator:
    """Run all prompt replacers repeatedly until the prompt stops changing."""

    def __init__(self, options):
        self.tag_loader = TagLoader()
        self.tag_selector = TagSelector(self.tag_loader, options)
        self.negative_tag_generator = NegativePromptGenerator()
        self.settings_generator = SettingsGenerator()
        # Applied in this order on every pass of use_replacers().
        self.replacers = [
            self.settings_generator,
            TagReplacer(self.tag_selector, options),
            DynamicPromptReplacer(),
            self.negative_tag_generator,
        ]
        self.verbose = dict(options).get('verbose', False)

    def use_replacers(self, prompt):
        """Pass *prompt* once through every replacer, in order."""
        result = prompt
        for stage in self.replacers:
            result = stage.replace(result)
        return result

    def generate_single_prompt(self, original_prompt):
        """Apply use_replacers() until a fixed point is reached, and return it."""
        started = time.time()
        previous = original_prompt
        current = self.use_replacers(original_prompt)
        while current != previous:
            previous, current = current, self.use_replacers(current)
        elapsed = time.time() - started
        if self.verbose:
            print(f"Prompt generated in {elapsed} seconds")
        return current

    def get_negative_tags(self):
        return self.negative_tag_generator.get_negative_tags()

    def get_setting_overrides(self):
        return self.settings_generator.get_setting_overrides()
class NegativePromptGenerator:
    """Move ``**text**`` spans out of a prompt and collect them as negative tags."""

    # Non-greedy, so "**a** x **b**" yields two tags. The raw string avoids
    # the invalid "\*" string escapes (SyntaxWarning on Python 3.12+).
    negative_regex = re.compile(r'\*\*.*?\*\*')

    def __init__(self):
        # Collected tag texts (a set, so repeats collapse).
        self.negative_tag = set()

    def strip_negative_tags(self, tags):
        """Remove every ``**...**`` span from *tags* and record its inner text."""
        for match in self.negative_regex.findall(tags):
            self.negative_tag.add(match.replace("**", ""))
            tags = tags.replace(match, "")
        return tags

    def replace(self, prompt):
        return self.strip_negative_tags(prompt)

    def get_negative_tags(self):
        """Return the collected tags joined by single spaces, in sorted order.

        Set iteration order of strings changes between interpreter runs (hash
        randomization). Sorting keeps the negative prompt identical for
        identical seeds.
        """
        return " ".join(sorted(self.negative_tag))
# @@settings@@ notation
class SettingsGenerator:
    """Extract ``@@key=value|key=value@@`` generation-setting overrides from a prompt."""

    def __init__(self):
        self.re_setting_tags = re.compile(r"@@(.*?)@@")
        # Accumulated overrides: full setting name -> converted value.
        self.setting_overrides = {}
        # Recognized settings and the converter applied to their values.
        self.type_mapping = {
            'cfg_scale': float,
            'sampler': str,
            'steps': int,
        }

    def _apply_assignment(self, assignment):
        """Parse one ``key=value`` pair into setting_overrides; print and skip it if invalid."""
        key_raw, separator, value = assignment.partition("=")
        key_raw = key_raw.strip()
        value = value.strip()
        # A missing "=", an empty key (which would match any setting), or an
        # empty value is rejected instead of crashing or guessing.
        if not separator or not key_raw or not value:
            print(
                f"Invalid setting {assignment}, settings should assign a value"
            )
            return
        for key, convert in self.type_mapping.items():
            # Any prefix of a setting name selects it, e.g. "cfg" -> "cfg_scale".
            if key.startswith(key_raw):
                try:
                    self.setting_overrides[key] = convert(value)
                except ValueError:
                    print(f"Invalid value {value} for setting {key}")
                return
        print(
            f"Unknown setting {key_raw}, setting should be the starting part of: {', '.join(self.type_mapping.keys())}"
        )

    def strip_setting_tags(self, prompt):
        """Record the settings of every ``@@...@@`` block and remove the block from *prompt*."""
        for match in self.re_setting_tags.findall(prompt):
            for assignment in match.split("|"):
                self._apply_assignment(assignment)
            prompt = prompt.replace('@@' + match + '@@', "")
        return prompt

    def replace(self, prompt):
        return self.strip_setting_tags(prompt)

    def get_setting_overrides(self):
        return self.setting_overrides
class Script(scripts.Script):
    """Webui script that expands UmiAI wildcard syntax in every prompt of a job."""

    # Set by ui(): True for the txt2img instance, False for img2img.
    is_txt2img = False
    # Class-level embedding database shared by all Script instances.
    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        # Register "<extension root>/embeddings" and force a reload so the
        # embeddings in that folder are loaded.
        embedding_dir = os.path.join(
            pathlib.Path(inspect.getfile(lambda: None)).parent.parent, "embeddings")
        self.embedding_db.add_embedding_dir(embedding_dir)
        self.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    def title(self):
        return "Prompt generator"

    def show(self, is_img2img):
        # Shown on both tabs, without being picked from the script dropdown.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the controls; the returned order matches process()'s parameters."""
        self.is_txt2img = not is_img2img
        with gr.Group():
            with gr.Row():
                enabled = gr.Checkbox(label="UmiAI enabled", value=True)
                verbose = gr.Checkbox(label="Verbose logging", value=False)
                cache_files = gr.Checkbox(label="Cache files", value=True)
                same_seed = gr.Checkbox(label='Same prompt in batch',
                                        value=False)
                negative_prompt = gr.Checkbox(label='**negative keywords**',
                                              value=True)
                shared_seed = gr.Checkbox(label="Static wildcards",
                                          value=False)
            # One dropdown per entry of the "configuration" wildcard file.
            option_generator = OptionGenerator(TagLoader())
            options = [
                gr.Dropdown(label=opt,
                            choices=["RANDOM"] +
                            option_generator.get_option_choices(opt),
                            value="RANDOM")
                for opt in option_generator.get_configurable_options()
            ]
        return [enabled, verbose, cache_files, same_seed, negative_prompt, shared_seed
                ] + options

    def process(self, p, enabled, verbose, cache_files, same_seed, negative_prompt,
                shared_seed, *args):
        """Rewrite ``p.all_prompts`` (and negatives) with wildcard-expanded prompts.

        Args:
            p: the webui processing object. Its ``all_prompts`` and, when
                present, ``all_negative_prompts`` are modified in place.
            enabled: do nothing when False.
            verbose: verbose logging; also records the wildcard files used.
            cache_files: reuse previously loaded wildcard files.
            same_seed: "Same prompt in batch", one expansion per batch.
            negative_prompt: append ``**...**`` tags to the negative prompt.
            shared_seed: "Static wildcards", seed the RNG from ``p.all_seeds``.
            *args: the configuration dropdown values created by ui().
        """
        if not enabled:
            return

        debug = False
        if debug: print(f'\nModel: {p.sampler_name}, Seed: {int(p.seed)}, Batch Count: {p.n_iter}, Batch Size: {p.batch_size}, CFG: {p.cfg_scale}, Steps: {p.steps}\nOriginal Prompt: "{p.prompt}"\nOriginal Negatives: "{p.negative_prompt}"\n')

        original_prompt = p.all_prompts[0]
        # hasattr to fix crash on old webui versions
        has_negatives = hasattr(p, "all_negative_prompts")
        original_negative = p.all_negative_prompts[0] if has_negatives else ""

        TagLoader.files.clear()
        option_generator = OptionGenerator(TagLoader())
        options = {
            'selected_options': option_generator.parse_options(args),
            'verbose': verbose,
            'cache_files': cache_files,
        }
        prompt_generator = PromptGenerator(options)

        for cur_count in range(p.n_iter):  # Batch count
            for cur_batch in range(p.batch_size):  # Batch size
                index = p.batch_size * cur_count + cur_batch
                # pick same wildcard for a given seed
                if shared_seed:
                    random.seed(p.all_seeds[p.batch_size * cur_count if same_seed else index])
                else:
                    random.seed(time.time())
                if debug: print(f'{"Batch #"+str(cur_count) if same_seed else "Prompt #"+str(index):=^30}')

                # Per-prompt state. The selector's recursion counter must be
                # reset too; otherwise a wildcard stops expanding after being
                # used in 100 prompts of one job.
                prompt_generator.negative_tag_generator.negative_tag = set()
                prompt_generator.tag_selector.previously_selected_tags = {}

                prompt = prompt_generator.generate_single_prompt(original_prompt)
                p.all_prompts[index] = prompt
                if debug: print(f'Prompt: "{prompt}"')

                negative = original_negative
                if negative_prompt and has_negatives:
                    negative = self._join_negative(
                        negative, prompt_generator.get_negative_tags())
                    p.all_negative_prompts[index] = negative
                if debug: print(f'Negative: "{negative}\n"')

                # same prompt per batch
                if same_seed:
                    for same_index in range(index, index + p.batch_size):
                        p.all_prompts[same_index] = prompt
                        # Keep the negatives in step with the copied prompt.
                        if negative_prompt and has_negatives:
                            p.all_negative_prompts[same_index] = negative
                    break

        self._apply_setting_overrides(p, prompt_generator.get_setting_overrides())

        if original_prompt != p.all_prompts[0]:
            p.extra_generation_params["Wildcard prompt"] = original_prompt
            if verbose:
                p.extra_generation_params["File includes"] = "|".join(
                    TagLoader.files)

    @staticmethod
    def _join_negative(negative, extra):
        """Append *extra* to *negative*, separated by ", " when both are non-empty.

        Plain concatenation used to fuse the last user word with the first tag.
        """
        if negative and extra:
            return f"{negative}, {extra}"
        return negative + extra

    @staticmethod
    def _find_sampler_index(sampler_list, value):
        """Return the index of the sampler named *value*, or None if there is none.

        A sampler matches if its ``elem[0]`` equals *value*, or if *value* is in
        ``elem[2]`` (presumably the alias list; confirm).
        """
        for index, elem in enumerate(sampler_list):
            if elem[0] == value or value in elem[2]:
                return index
        return None

    def _apply_setting_overrides(self, p, att_override):
        """Copy ``@@setting@@`` overrides onto *p*; the sampler is resolved by name."""
        sampler_list = samplers if self.is_txt2img else samplers_for_img2img
        for att, value in att_override.items():
            if att.startswith("__"):
                continue
            if att == 'sampler':
                sampler_index = self._find_sampler_index(sampler_list, value)
                if sampler_index is not None:
                    setattr(p, 'sampler_index', sampler_index)
                    # NOTE(review): this webui exposes p.sampler_name (see the
                    # debug print); newer versions pick the sampler by name,
                    # so set it as well.
                    p.sampler_name = sampler_list[sampler_index][0]
                else:
                    print(
                        f"Sampler {value} not found in prompt {p.all_prompts[0]}"
                    )
                continue
            setattr(p, att, value)
# At import time, also register the "embeddings" folder under
# scripts.basedir() with the model hijack's embedding database.
from modules import sd_hijack

path = os.path.join(
    scripts.basedir(),
    "embeddings",
)
sd_hijack.model_hijack.embedding_db.add_embedding_dir(path)