Spaces:
Runtime error
Runtime error
import io | |
import re | |
import os | |
import sys | |
import math | |
import json | |
import uuid | |
import queue | |
import string | |
import random | |
import hashlib | |
import datetime | |
import threading | |
from pathlib import Path | |
from collections import namedtuple | |
import numpy as np | |
import piexif | |
import piexif.helper | |
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ExifTags | |
from modules import sd_samplers, shared, script_callbacks, errors, paths | |
# Trace logging for path/filename handling is opt-in: `debug` forwards to errors.log.trace only when
# the SD_PATH_DEBUG environment variable is set (to any value), otherwise it is a no-op callable.
debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
# HEIF support is optional: register the opener when pi_heif imports and registers cleanly,
# and deliberately ignore any failure so the module still loads without it.
try:
    from pi_heif import register_heif_opener
    register_heif_opener()
except Exception:
    pass
def check_grid_size(imgs):
    """Return whether the combined pixel count of ``imgs`` fits the configured limit.

    The total is rounded to whole megapixels and compared against
    ``shared.opts.img_max_size_mp``; a warning is logged when it is exceeded.
    """
    total_pixels = sum(img.width * img.height for img in imgs)
    mp = round(total_pixels / 1000000)
    if mp <= shared.opts.img_max_size_mp:
        return True
    shared.log.warning(f'Maximum image size exceded: size={mp} maximum={shared.opts.img_max_size_mp} MPixels')
    return False
def image_grid(imgs, batch_size=1, rows=None):
    """Paste ``imgs`` into a single grid image, filling left-to-right, top-to-bottom.

    Args:
        imgs: non-empty sequence of images; the size of the first one is used as the cell size.
        batch_size: used as the row count when ``rows`` is None and ``shared.opts.n_rows`` is 0.
        rows: explicit row count; when None it is derived from ``shared.opts.n_rows``
            (positive: use it, zero: use ``batch_size``, negative: largest divisor of
            ``len(imgs)`` not exceeding its square root).

    Returns:
        A new RGB image holding the grid, using whatever layout the image-grid callback
        leaves in the loop params.
    """
    count = len(imgs)
    if rows is None:
        if shared.opts.n_rows > 0:
            rows = shared.opts.n_rows
        elif shared.opts.n_rows == 0:
            rows = batch_size
        else:
            rows = math.floor(math.sqrt(count))
            while rows > 1 and count % rows != 0:
                rows -= 1
    # keep rows within 1..count: a zero/negative value previously caused a division by zero below
    rows = max(1, min(rows, count))
    cols = math.ceil(count / rows)
    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params) # callbacks may change imgs/cols/rows
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=shared.opts.grid_background)
    for i, img in enumerate(params.imgs):
        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
    return grid
# tiles is a list of rows: [y, tile_h, [[x, tile_w, tile_image], ...]]
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut ``image`` into overlapping tiles of ``tile_w`` x ``tile_h``.

    Tiles are spread evenly so the first starts at 0 and the last ends at the image
    edge; neighbours overlap by at least ``overlap`` pixels. Returns a ``Grid`` whose
    ``tiles`` hold ``[y, tile_h, [[x, tile_w, tile], ...]]`` per row.
    """
    img_w, img_h = image.width, image.height
    step_w = tile_w - overlap
    step_h = tile_h - overlap
    n_cols = math.ceil((img_w - overlap) / step_w)
    n_rows = math.ceil((img_h - overlap) / step_h)
    stride_x = (img_w - tile_w) / (n_cols - 1) if n_cols > 1 else 0
    stride_y = (img_h - tile_h) / (n_rows - 1) if n_rows > 1 else 0

    def clamp(start, extent, limit):
        # snap a tile that reaches or passes the edge so it ends exactly on the edge
        return limit - extent if start + extent >= limit else start

    grid = Grid([], tile_w, tile_h, img_w, img_h, overlap)
    for r in range(n_rows):
        top = clamp(int(r * stride_y), tile_h, img_h)
        strip = []
        for c in range(n_cols):
            left = clamp(int(c * stride_x), tile_w, img_w)
            piece = image.crop((left, top, left + tile_w, top + tile_h))
            strip.append([left, tile_w, piece])
        grid.tiles.append([top, tile_h, strip])
    return grid
def combine_grid(grid):
    """Reassemble the tiles of a ``Grid`` into one RGB image, blending the overlaps.

    The first ``grid.overlap`` pixels of every non-first tile (horizontally) and of
    every non-first row (vertically) are pasted through a linear 0..255 ramp mask so
    they fade in over what is already on the canvas.
    """
    overlap = grid.overlap

    def ramp_mask(values):
        scaled = (values * 255 / overlap).astype(np.uint8)
        return Image.fromarray(scaled, 'L')

    ramp = np.arange(overlap, dtype=np.float32)
    mask_w = ramp_mask(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    mask_h = ramp_mask(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for top, row_h, row_tiles in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))
        for left, tile_w, tile in row_tiles:
            if left == 0:
                strip.paste(tile, (0, 0))
            else:
                strip.paste(tile.crop((0, 0, overlap, row_h)), (left, 0), mask=mask_w)
                strip.paste(tile.crop((overlap, 0, tile_w, row_h)), (left + overlap, 0))
        if top == 0:
            canvas.paste(strip, (0, 0))
        else:
            canvas.paste(strip.crop((0, 0, strip.width, overlap)), (0, top), mask=mask_h)
            canvas.paste(strip.crop((0, overlap, strip.width, row_h)), (0, top + overlap))
    return canvas
class GridAnnotation:
    """One line of label text for a grid, with an active/inactive flag."""

    def __init__(self, text='', is_active=True):
        # size starts unset; it is assigned a (width, height) tuple once the text has been measured
        self.text, self.is_active, self.size = text, is_active, None
def get_font(fontsize):
    """Load the configured TrueType font at ``fontsize``, falling back to the bundled one.

    The bundled font is used when ``shared.opts.font`` is empty or when loading the
    configured font fails for any reason.
    """
    bundled = "javascript/notosans-nerdfont-regular.ttf"
    try:
        font = ImageFont.truetype(shared.opts.font or bundled, fontsize)
    except Exception:
        font = ImageFont.truetype(bundled, fontsize)
    return font
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0, title=None):
    """Return a copy of grid image ``im`` with column, row and optional title labels drawn around it.

    Args:
        im: grid image made of cells of ``width`` x ``height`` pixels.
        width: width of a single cell.
        height: height of a single cell.
        hor_texts: one list of GridAnnotation per column (asserted to match the column count).
        ver_texts: one list of GridAnnotation per row (asserted to match the row count).
        margin: gap in pixels inserted between cells in the output.
        title: optional list of GridAnnotation drawn above the column labels.

    Returns:
        A new RGB image containing the padded grid and its labels.

    Note:
        The lists inside ``hor_texts``, ``ver_texts`` and ``title`` are modified in place:
        their entries are replaced by word-wrapped, measured annotations.
    """
    def wrap(drawing, text, font, line_length):
        # greedy word wrap: append words to the current line while it still fits line_length
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        # draws lines centered on draw_x, stacking downwards from draw_y
        for line in lines:
            font = initial_fnt
            fontsize = initial_fontsize
            # shrink the font one point at a time until the line fits its allowed width
            while drawing.multiline_textbbox((0,0), text=line.text, font=font)[2] > line.allowed_width and fontsize > 0:
                fontsize -= 1
                font = get_font(fontsize)
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=font, fill=shared.opts.font_color if line.is_active else color_inactive, anchor="mm", align="center")
            if not line.is_active:
                # strike through inactive annotations
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
            draw_y += line.size[1] + line_spacing

    # base font size scales with the cell dimensions
    fontsize = (width + height) // 25
    line_spacing = fontsize // 2
    font = get_font(fontsize)
    color_inactive = (127, 127, 127)
    # reserve a left column (3/4 of a cell width) only when some row label has text
    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
    cols = im.width // width
    rows = im.height // height
    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
    # 1x1 scratch image used only to measure text
    calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
    calc_d = ImageDraw.Draw(calc_img)
    title_texts = [title] if title else [[GridAnnotation()]]
    # wrap and measure every label group against the width available to it:
    # cell width for columns, pad_left for rows, full grid width for the title
    for texts, allowed_width in zip(hor_texts + ver_texts + title_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts) + [(width+margin)*cols]):
        items = [] + texts # copy, because texts is cleared and refilled in place
        texts.clear()
        for line in items:
            wrapped = wrap(calc_d, line.text, font, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]
        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=font)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
            line.allowed_width = allowed_width
    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    # NOTE(review): unlike the horizontal case this subtracts line_spacing once per line, not once per group - confirm intended
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
    title_pad = 0
    if title:
        title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
        title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
    # output canvas: original grid plus label padding plus margins between cells
    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
    # copy each cell to its new position, leaving `margin` pixels between neighbours
    for row in range(rows):
        for col in range(cols):
            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
            result.paste(cell, (pad_left + (width + margin) * col, pad_top + title_pad + (height + margin) * row))
    d = ImageDraw.Draw(result)
    if title:
        # title is centered over the cell area, vertically centered in its band
        x = pad_left + ((width+margin)*cols) / 2
        y = title_pad / 2 - title_text_heights[0] / 2
        draw_texts(d, x, y, title_texts[0], font, fontsize)
    for col in range(cols):
        # column labels: centered over each cell, below the title band
        x = pad_left + (width + margin) * col + width / 2
        y = (pad_top / 2 - hor_text_heights[col] / 2) + title_pad
        draw_texts(d, x, y, hor_texts[col], font, fontsize)
    for row in range(rows):
        # row labels: centered in the left padding, vertically centered on each cell
        x = pad_left / 2
        y = (pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2) + title_pad
        draw_texts(d, x, y, ver_texts[row], font, fontsize)
    return result
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
    """Annotate a prompt-matrix grid with the optional prompt parts active in each cell.

    The first entry of ``all_prompts`` is skipped; the remaining parts are split in
    half (the first half gets the extra item), labelling columns and rows respectively.
    Column/row index ``pos`` marks part ``i`` as active when bit ``i`` of ``pos`` is set.
    """
    def annotations(parts):
        return [
            [GridAnnotation(text, is_active=(pos & (1 << bit)) != 0) for bit, text in enumerate(parts)]
            for pos in range(1 << len(parts))
        ]

    optional_parts = all_prompts[1:]
    split_at = math.ceil(len(optional_parts) / 2)
    hor_texts = annotations(optional_parts[:split_at])
    ver_texts = annotations(optional_parts[split_at:])
    return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
def resize_image(resize_mode, im, width, height, upscaler_name=None, output_type='image'):
    """Resize ``im`` to ``width`` x ``height`` according to ``resize_mode``.

    Args:
        resize_mode: 0 = none (return a copy), 1 = fixed (resize to the exact target size),
            2 = crop (scale to cover the target, centered, overflow cut off),
            3 = fill (scale to fit inside the target, centered on a solid background),
            4 = edge (fill with color 0, then ``masking.outpaint`` on the result).
        im: source image.
        width: target width in pixels.
        height: target height in pixels.
        upscaler_name: upscaler used when enlarging; defaults to ``shared.opts.upscaler_for_img2img``.
        output_type: ``'np'`` returns a numpy array, anything else returns the image.

    Returns:
        The resized image, or ``np.array`` of it when ``output_type == 'np'``.
        An image already at the target size is returned as a copy regardless of mode.
    """
    if im.width == width and im.height == height:
        # NOTE(review): this logs only when the size already matches - confirm the condition is not inverted
        shared.log.debug(f'Image resize: input={im} target={width}x{height} mode={shared.resize_modes[resize_mode]} upscaler="{upscaler_name}" fn={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
    upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img

    def latent(im, w, h, upscaler):
        # resize in latent space: encode, interpolate to 1/8 of the target size, decode
        from modules.processing_vae import vae_encode, vae_decode
        import torch
        latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO enable full VAE mode
        latents = torch.nn.functional.interpolate(latents, size=(int(h // 8), int(w // 8)), mode=upscaler["mode"], antialias=upscaler["antialias"])
        im = vae_decode(latents, shared.sd_model, output_type='pil', full_quality=False)[0]
        return im

    def resize(im, w, h):
        w = int(w)
        h = int(h)
        if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
            return im.resize((w, h), resample=Image.Resampling.LANCZOS) # force for mask
        scale = max(w / im.width, h / im.height)
        if scale > 1.0:
            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
            if len(upscalers) > 0:
                upscaler = upscalers[0]
                im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
            else:
                upscaler = shared.latent_upscale_modes.get(upscaler_name, None)
                if upscaler is not None:
                    im = latent(im, w, h, upscaler)
                else:
                    # unknown upscaler name: `upscalers` is empty here, so there is nothing to index;
                    # warn and let the Lanczos resize below produce the target size
                    shared.log.warning(f"Resize upscaler: invalid={upscaler_name} fallback=Lanczos")
        if im.width != w or im.height != h: # probably downsample after upscaler created larger image
            im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
        return im

    def crop(im):
        # scale so the image covers the target, then paste centered; overflow falls outside the canvas
        ratio = width / height
        src_ratio = im.width / im.height
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
        resized = resize(im, src_w, src_h)
        res = Image.new(im.mode, (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        return res

    def fill(im, color=None):
        # scale so the image fits inside the target, then center it on a solid background;
        # only None selects the configured background, so an explicit color of 0 is honored
        color = shared.opts.image_background if color is None else color
        ratio = min(width / im.width, height / im.height)
        im = resize(im, int(im.width * ratio), int(im.height * ratio))
        res = Image.new(im.mode, (width, height), color=color)
        res.paste(im, box=((width - im.width)//2, (height - im.height)//2))
        return res

    if resize_mode == 0 or (im.width == width and im.height == height): # none
        res = im.copy()
    elif resize_mode == 1: # fixed
        res = resize(im, width, height)
    elif resize_mode == 2: # crop
        res = crop(im)
    elif resize_mode == 3: # fill
        res = fill(im)
    elif resize_mode == 4: # edge
        from modules import masking
        res = fill(im, color=0)
        res, _mask = masking.outpaint(res)
    if output_type == 'np':
        return np.array(res)
    return res
# one or more consecutive whitespace or ASCII punctuation characters
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
# literal text (group 1) followed by either a `[pattern]` token (group 2) or end of string
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
# splits one trailing `<arg>` off a pattern: group 1 is the remainder, group 2 the argument
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
# a word (group 1) wrapped in one bracket-class character on each side, with an optional `:number` weight
# NOTE(review): the trailing `|` also lets this match the empty string at every position - confirm intended
re_attention = re.compile(r'[\(*\[*](\w+)(:\d+(\.\d+))?[\)*\]*]|')
# `<type:name>` or `<type:name:number>` token; group 1 is the name (same trailing `|` as above)
re_network = re.compile(r'\<\w+:(\w+)(:\d+(\.\d+))?\>|')
# any single round, square or curly bracket
re_brackets = re.compile(r'[\([{})\]]')
# sentinel a replacement returns to drop its token, together with the literal text preceding it, from the output
NOTHING = object()
class FilenameGenerator:
    """Expand ``[pattern]`` tokens in filename templates and sanitize the resulting paths.

    ``replacements`` maps a lower-case pattern name to a callable taking the generator
    (plus any ``<arg>`` values from the template) and returning the replacement value.
    Returning ``NOTHING`` drops the token and the literal text before it; returning
    ``None`` (or raising) drops them as well, with raised errors logged.
    """

    replacements = {
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'batch_number': lambda self: self.batch_number,
        'iter_number': lambda self: self.iter_number,
        'num': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'generation_number': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
        'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
        'hash': lambda self: self.image_hash(),
        'image_hash': lambda self: self.image_hash(),
        'timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
        'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
        'model': lambda self: shared.sd_model.sd_checkpoint_info.title,
        'model_shortname': lambda self: shared.sd_model.sd_checkpoint_info.model_name,
        'model_name': lambda self: shared.sd_model.sd_checkpoint_info.model_name,
        'model_hash': lambda self: shared.sd_model.sd_checkpoint_info.shorthash,
        'prompt': lambda self: self.prompt_full(),
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_words': lambda self: self.prompt_words(),
        'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
        'sampler': lambda self: self.p and self.p.sampler_name,
        'seed': lambda self: self.seed and str(self.seed) or '',
        'steps': lambda self: self.p and self.p.steps,
        'styles': lambda self: self.p and ", ".join([style for style in self.p.styles if not style == "None"]) or "None",
        'uuid': lambda self: str(uuid.uuid4()),
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image, grid=False):
        """Capture the processing object, seed, prompt and image the patterns read from.

        ``seed`` is used when it is a positive number; otherwise the first entry of
        ``p.all_seeds`` (when present) or 0. For grids the batch/iteration tokens are
        always suppressed.
        """
        if p is None:
            debug('Filename generator init skip')
        else:
            debug(f'Filename generator init: {seed} {prompt}')
        self.p = p
        if seed is not None and int(seed) > 0:
            self.seed = seed
        elif hasattr(p, 'all_seeds'):
            self.seed = p.all_seeds[0]
        else:
            self.seed = 0
        self.prompt = prompt
        self.image = image
        if not grid:
            self.batch_number = NOTHING if self.p is None or getattr(self.p, 'batch_size', 1) == 1 else (self.p.batch_index + 1 if hasattr(self.p, 'batch_index') else NOTHING)
            self.iter_number = NOTHING if self.p is None or getattr(self.p, 'n_iter', 1) == 1 else (self.p.iteration + 1 if hasattr(self.p, 'iteration') else NOTHING)
        else:
            self.batch_number = NOTHING
            self.iter_number = NOTHING

    def hasprompt(self, *args):
        """Concatenate, for each ``expected|default`` argument, the expected text when the
        prompt contains it (case-insensitive) or the default otherwise.

        Returns None when there is no processing object or no prompt.
        """
        # guard first: the prompt must not be touched before we know it exists
        if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
            return None
        lower = self.prompt.lower()
        outres = ""
        for arg in args:
            if arg == "":
                continue
            division = arg.split("|")
            expected = division[0].lower()
            default = division[1] if len(division) > 1 else ""
            if lower.find(expected) >= 0:
                outres = f'{outres}{expected}'
            elif default != "":
                outres = f'{outres}{default}'
        return outres

    def image_hash(self):
        """Return the first 8 hex digits of the SHA-256 of the base64 JPEG encoding of the image, or None without an image."""
        if getattr(self, 'image', None) is None:
            return None
        import base64
        from io import BytesIO
        buffered = BytesIO()
        self.image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        shorthash = hashlib.sha256(img_str).hexdigest()[0:8]
        return shorthash

    def prompt_full(self):
        """Return the whole prompt with filename-unsafe characters replaced."""
        return self.prompt_sanitize(self.prompt)

    def prompt_words(self):
        """Return the first ``shared.opts.directories_max_prompt_words`` words of the prompt,
        with attention/network syntax and brackets stripped, sanitized for filenames."""
        if getattr(self, 'prompt', None) is None:
            return ''
        no_attention = re_attention.sub(r'\1', self.prompt)
        no_network = re_network.sub(r'\1', no_attention)
        no_brackets = re_brackets.sub('', no_network)
        words = [x for x in re_nonletters.split(no_brackets or "") if len(x) > 0]
        prompt = " ".join(words[0:shared.opts.directories_max_prompt_words])
        return self.prompt_sanitize(prompt)

    def prompt_no_style(self):
        """Return the sanitized prompt with the text contributed by the active styles removed,
        or None when there is no processing object or no prompt."""
        if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
            return None
        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if len(style) > 0:
                # remove the text on either side of the {prompt} placeholder, then the style itself
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",")
                prompt_no_style = prompt_no_style.replace(style, "")
        return self.prompt_sanitize(prompt_no_style)

    def datetime(self, *args):
        """Format the current time.

        ``args[0]`` is an optional strftime format (default ``default_time_format``) and
        ``args[1]`` an optional pytz time zone name; an unknown zone or an invalid format
        falls back to the local zone / default format respectively.
        """
        import pytz
        time_datetime = datetime.datetime.now()
        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError:
            time_zone = None
        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError):
            formatted_time = time_zone_time.strftime(self.default_time_format)
        return formatted_time

    def prompt_sanitize(self, prompt):
        """Replace characters that are unsafe in filenames with ``_`` and strip surrounding whitespace."""
        invalid_chars = '#<>:\'"\\|?*\n\t\r'
        sanitized = prompt.translate({ ord(x): '_' for x in invalid_chars }).strip()
        debug(f'Prompt sanitize: input="{prompt}" output={sanitized}')
        return sanitized

    def sanitize(self, filename):
        """Make every path component of ``filename`` filesystem-safe and cap the overall length.

        The extension is split off first and re-attached at the end; the stem is
        shortened until its absolute path fits the computed maximum.
        """
        invalid_chars = '\'"|?*\n\t\r' # <https://learn.microsoft.com/en-us/windows/win32/fileio/naming-a-file>
        invalid_folder = ':'
        invalid_files = ['CON', 'PRN', 'AUX', 'NUL', 'NULL', 'COM0', 'COM1', 'LPT0', 'LPT1']
        invalid_prefix = ', '
        invalid_suffix = '.,_ '
        # translation tables are loop-invariant, build them once
        chars_table = { ord(x): '_' for x in invalid_chars }
        folder_table = { ord(x): '_' for x in invalid_folder }
        fn, ext = os.path.splitext(filename)
        parts = Path(fn).parts
        newparts = []
        for i, part in enumerate(parts):
            part = part.translate(chars_table)
            if i > 0 or (len(part) >= 2 and part[1] != invalid_folder): # skip drive, otherwise remove
                part = part.translate(folder_table)
            part = part.lstrip(invalid_prefix).rstrip(invalid_suffix)
            if part in invalid_files: # reserved names
                for word in invalid_files:
                    part = part.replace(word, '_')
            newparts.append(part)
        fn = str(Path(*newparts))
        max_length = max(256 - len(ext), os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else 256 - len(ext))
        while len(os.path.abspath(fn)) > max_length:
            fn = fn[:-1]
        fn += ext
        debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
        return fn

    def sequence(self, x, dirname, basename):
        """Replace ``[seq]`` in ``x`` with the first sequence number whose filename does not exist yet.

        When numbering is enabled in options but ``x`` has no ``[seq]`` token, one is
        prepended to the file name. If none of the 9999 candidates is free, ``x`` is
        returned with the token left in place.
        """
        if shared.opts.save_images_add_number or '[seq]' in x:
            if '[seq]' not in x:
                x = os.path.join(os.path.dirname(x), f"[seq]-{os.path.basename(x)}")
            basecount = get_next_sequence_number(dirname, basename)
            for i in range(9999):
                seq = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                filename = x.replace('[seq]', seq)
                if not os.path.exists(filename):
                    debug(f'Prompt sequence: input="{x}" seq={seq} output="{filename}"')
                    x = filename
                    break
        return x

    def apply(self, x):
        """Expand all known ``[pattern]`` tokens in template ``x`` and return the result.

        Unknown patterns are re-inserted verbatim (without their ``<arg>`` parts) so later
        stages can handle them; path separators inside replacement values become ``-``.
        """
        res = ''
        for m in re_pattern.finditer(x):
            text, pattern = m.groups()
            if pattern is None:
                res += text
                continue
            # peel trailing <arg> groups off the pattern, preserving their order
            pattern_args = []
            while True:
                arg_match = re_pattern_arg.match(pattern)
                if arg_match is None:
                    break
                pattern, arg = arg_match.groups()
                pattern_args.insert(0, arg)
            fun = self.replacements.get(pattern.lower(), None)
            if fun is None:
                res += text + f'[{pattern}]' # reinsert unknown pattern
                continue
            try:
                debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}')
                replacement = fun(self, *pattern_args)
            except Exception as e:
                replacement = None
                shared.log.error(f'Filename apply pattern: {x} {e}')
            # NOTHING is a sentinel: compare by identity; both NOTHING and None drop the token and its text
            if replacement is NOTHING or replacement is None:
                continue
            res += text + str(replacement).replace('/', '-').replace('\\', '-')
        return res
def get_next_sequence_number(path, basename):
    """
    Return the next free sequence number for images saved in directory ``path``.

    Only entries starting with ``basename-`` (or every entry when ``basename`` is empty)
    are considered; the number is read from the text before the first ``-`` after that
    prefix. Returns 0 when the directory does not exist or holds no numbered entries.
    """
    prefix = f"{basename}-" if basename != '' else basename
    if not os.path.isdir(path):
        return 0
    highest = -1
    for entry in os.listdir(path):
        if not entry.startswith(prefix):
            continue
        stem = os.path.splitext(entry[len(prefix):])[0]
        try:
            highest = max(int(stem.split('-')[0]), highest)
        except ValueError:
            continue # not a numbered file
    return highest + 1
def _save_queued_image(image, filename, extension, params, exifinfo, filename_txt):
    """Write one queued image plus its optional text sidecar and JSON log entry to disk."""
    with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
        file.write(exifinfo)
    if extension and not extension.startswith('.'): # add dot if missing
        extension = '.' + extension
    fn = filename + extension # built after the dot is added so the target path always has a proper extension
    filename = filename.strip()
    try:
        image_format = Image.registered_extensions()[extension]
    except Exception:
        shared.log.warning(f'Saving: unknown image format: {extension}')
        image_format = 'JPEG'
    if shared.opts.image_watermark_enabled or (shared.opts.image_watermark_position != 'none' and shared.opts.image_watermark_image != ''):
        image = set_watermark(image, shared.opts.image_watermark)
    size = os.path.getsize(fn) if os.path.exists(fn) else 0 # size of any file already at the target path
    shared.log.info(f'Saving: image="{fn}" type={image_format} resolution={image.width}x{image.height} size={size}')
    # additional metadata saved in files
    if shared.opts.save_txt and len(exifinfo) > 0:
        try:
            with open(filename_txt, "w", encoding="utf8") as file:
                file.write(f"{exifinfo}\n")
            shared.log.info(f'Saving: text="{filename_txt}" len={len(exifinfo)}')
        except Exception as e:
            shared.log.warning(f'Saving failed: description={filename_txt} {e}')
    # actual save
    exifinfo = (exifinfo or "") if shared.opts.image_metadata else ""
    if image_format == 'PNG':
        pnginfo_data = PngImagePlugin.PngInfo()
        for k, v in params.pnginfo.items():
            pnginfo_data.add_text(k, str(v))
        save_args = { 'compress_level': 6, 'pnginfo': pnginfo_data if shared.opts.image_metadata else None }
    elif image_format == 'JPEG':
        if image.mode == 'RGBA':
            shared.log.warning('Saving: removing alpha channel')
            image = image.convert("RGB")
        elif image.mode == 'I;16':
            image = image.point(lambda p: p * 0.0038910505836576).convert("L") # 255/65535: scale 16-bit values to 8-bit
        exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } })
        save_args = { 'optimize': True, 'quality': shared.opts.jpeg_quality, 'exif': exif_bytes if shared.opts.image_metadata else None }
    elif image_format == 'WEBP':
        if image.mode == 'I;16':
            image = image.point(lambda p: p * 0.0038910505836576).convert("RGB") # 255/65535: scale 16-bit values to 8-bit
        exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode") } })
        save_args = { 'optimize': True, 'quality': shared.opts.jpeg_quality, 'exif': exif_bytes if shared.opts.image_metadata else None, 'lossless': shared.opts.webp_lossless }
    else:
        save_args = { 'quality': shared.opts.jpeg_quality }
    try:
        image.save(fn, format=image_format, **save_args)
    except Exception as e:
        shared.log.error(f'Saving failed: file="{fn}" format={image_format} {e}')
    # optional json log of all saved images
    if shared.opts.save_log_fn != '' and len(exifinfo) > 0:
        log_fn = os.path.join(paths.data_path, shared.opts.save_log_fn)
        if not log_fn.endswith('.json'):
            log_fn += '.json'
        entries = shared.readfile(log_fn, silent=True)
        idx = len(list(entries))
        if idx == 0:
            entries = []
        entry = { 'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo }
        entries.append(entry)
        shared.writefile(entries, log_fn, mode='w', silent=True)
        shared.log.info(f'Saving: json="{log_fn}" records={len(entries)}')


def atomically_save_image():
    """Worker loop: take save jobs from ``save_queue`` forever and write them to disk.

    Each job is a tuple ``(image, filename, extension, params, exifinfo, filename_txt)``.
    Every job is marked done even when saving raises, so a single failure can neither
    kill the worker nor leave ``save_queue.join()`` callers blocked forever.
    """
    Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
    while True:
        item = save_queue.get()
        try:
            _save_queued_image(*item)
        except Exception as e:
            shared.log.error(f'Saving failed: {e}')
        finally:
            save_queue.task_done()
# pending save jobs, consumed by atomically_save_image()
save_queue = queue.Queue()
# the worker is a daemon thread started as soon as this module is imported
save_thread = threading.Thread(target=atomically_save_image, daemon=True)
save_thread.start()
def save_image(image, path, basename='', seed=None, prompt=None, extension=shared.opts.samples_format, info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix='', save_to_dirs=None): # pylint: disable=unused-argument
    """Build the target filename for ``image``, queue it for saving and wait until it is written.

    Args:
        image: image to save; None or an image over the size limit returns ``(None, None)``.
        path: output directory; falls back to ``shared.opts.outdir_save`` when empty.
        basename: optional prefix joined to the file name with ``-``.
        seed, prompt, p, grid: passed to FilenameGenerator for pattern expansion.
        extension: file extension without the dot. The default is read from
            ``shared.opts.samples_format`` once, when the function is defined.
        info: generation info stored in the metadata under ``pnginfo_section_name``.
        existing_info: optional dict of metadata to start from.
        forced_filename: when given, used instead of the configured filename pattern.
        suffix: text appended to the file name before the extension.

    Returns:
        Tuple ``(filename, filename_txt)``; ``filename_txt`` is None unless text sidecars
        are enabled and there is metadata to write.
    """
    debug(f'Save: fn={sys._getframe(1).f_code.co_name}') # pylint: disable=protected-access
    if image is None:
        shared.log.warning('Image is none')
        return None, None
    if not check_grid_size([image]):
        return None, None
    if path is None or path == '': # set default path to avoid errors when functions are triggered manually or via api and param is not set
        path = shared.opts.outdir_save
    namegen = FilenameGenerator(p, seed, prompt, image, grid=grid)
    suffix = suffix if suffix is not None else ''
    basename = basename if basename is not None else ''
    if shared.opts.save_to_dirs:
        # optional subdirectory derived from the directory pattern
        dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]")
        path = os.path.join(path, dirname)
    if forced_filename is None:
        if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0:
            file_decoration = shared.opts.samples_filename_pattern
        else:
            file_decoration = "[seq]-[prompt_words]"
        file_decoration = namegen.apply(file_decoration)
        file_decoration += suffix if suffix is not None else ''
        filename = os.path.join(path, f"{file_decoration}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{file_decoration}.{extension}")
    else:
        forced_filename += suffix if suffix is not None else ''
        filename = os.path.join(path, f"{forced_filename}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{forced_filename}.{extension}")
    # NOTE(review): a non-empty existing_info is used directly, so the assignment below modifies the caller's dict - confirm intended
    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info
    params = script_callbacks.ImageSaveParams(image, p, filename, pnginfo)
    params.filename = namegen.sanitize(filename)
    dirname = os.path.dirname(params.filename)
    if dirname is not None and len(dirname) > 0:
        os.makedirs(dirname, exist_ok=True)
    # resolve [seq] against the files already present, then sanitize the final name again
    params.filename = namegen.sequence(params.filename, dirname, basename)
    params.filename = namegen.sanitize(params.filename)
    # callbacks
    script_callbacks.before_image_saved_callback(params)
    # metadata text: optional UserComment followed by the main section, comma separated
    exifinfo = params.pnginfo.get('UserComment', '')
    exifinfo = (exifinfo + ', ' if len(exifinfo) > 0 else '') + params.pnginfo.get(pnginfo_section_name, '')
    filename, extension = os.path.splitext(params.filename)
    filename_txt = f"{filename}.txt" if shared.opts.save_txt and len(exifinfo) > 0 else None
    save_queue.put((params.image, filename, extension, params, exifinfo, filename_txt)) # actual save is executed in a thread that polls data from queue
    save_queue.join() # block until the worker has marked every queued job as done
    if not hasattr(params.image, 'already_saved_as'):
        debug(f'Image marked: "{params.filename}"')
        params.image.already_saved_as = params.filename
    script_callbacks.image_saved_callback(params)
    return params.filename, filename_txt
def save_video_atomic(images, filename, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3):
    """Write ``images`` to ``filename`` as an MP4 video or an animated GIF/PNG.

    Args:
        images: list of frames (images).
        filename: target file path; its directory is created when missing.
        video_type: ``'mp4'``, ``'gif'`` or ``'png'`` (case-insensitive); other values write nothing.
        duration: total duration of the animation in seconds.
        loop: for GIF/PNG, append the frames in reverse order and loop forever.
        interpolate: for MP4, number of RIFE-interpolated frames; 0 disables interpolation.
        scale, pad, change: passed through to ``modules.rife.interpolate``.
    """
    try:
        import cv2
    except Exception as e:
        shared.log.error(f'Save video: cv2: {e}')
        return
    dirname = os.path.dirname(filename)
    if dirname: # a bare filename has no directory part and os.makedirs('') would raise
        os.makedirs(dirname, exist_ok=True)
    if video_type.lower() == 'mp4':
        frames = images
        if interpolate > 0:
            try:
                import modules.rife
                frames = modules.rife.interpolate(images, count=interpolate, scale=scale, pad=pad, change=change)
            except Exception as e:
                # best effort: fall back to the original frames
                shared.log.error(f'RIFE interpolation: {e}')
                errors.display(e, 'RIFE interpolation')
        video_frames = [np.array(frame) for frame in frames]
        fourcc = "mp4v"
        h, w, _c = video_frames[0].shape
        video_writer = cv2.VideoWriter(filename, fourcc=cv2.VideoWriter_fourcc(*fourcc), fps=len(frames)/duration, frameSize=(w, h))
        try:
            for frame in video_frames:
                video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            video_writer.release() # close the writer so the file is complete before its size is read
        size = os.path.getsize(filename)
        shared.log.info(f'Save video: file="{filename}" frames={len(frames)} duration={duration} fourcc={fourcc} size={size}')
    if video_type.lower() == 'gif' or video_type.lower() == 'png':
        append = images.copy()
        image = append.pop(0)
        if loop:
            append += append[::-1]
        frames = len(append) + 1
        image.save(
            filename,
            save_all = True,
            append_images = append,
            optimize = False,
            duration = 1000.0 * duration / frames,
            loop = 0 if loop else 1,
        )
        size = os.path.getsize(filename)
        shared.log.info(f'Save video: file="{filename}" frames={len(append) + 1} duration={duration} loop={loop} size={size}')
def save_video(p, images, filename = None, video_type: str = 'none', duration: float = 2.0, loop: bool = False, interpolate: int = 0, scale: float = 1.0, pad: int = 1, change: float = 0.3, sync: bool = False):
    """Resolve the output filename and save ``images`` as a video, optionally in the background.

    Args:
        p: processing object used for filename patterns; may be None when ``filename`` is given.
        images: list of frames; fewer than two frames is a no-op.
        filename: explicit file name; when None it is generated from the configured
            filename pattern, which requires ``p``.
        video_type: ``'mp4'``, ``'gif'`` or ``'png'``; None or ``'none'`` is a no-op.
        duration, loop, interpolate, scale, pad, change: passed to ``save_video_atomic``.
        sync: save in the calling thread instead of starting a background thread.

    Returns:
        The sanitized target filename, or None when nothing is saved.
    """
    if images is None or len(images) < 2 or video_type is None or video_type.lower() == 'none':
        return None
    if filename is None and p is None:
        # no explicit name and nothing to generate one from; this case used to raise a TypeError
        shared.log.error('Save video: no filename and no processing object to generate one from')
        return None
    image = images[0]
    if p is not None:
        namegen = FilenameGenerator(p, seed=p.all_seeds[0], prompt=p.all_prompts[0], image=image)
    else:
        namegen = FilenameGenerator(None, seed=0, prompt='', image=image)
    if filename is None:
        filename = namegen.apply(shared.opts.samples_filename_pattern if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0 else "[seq]-[prompt_words]")
        filename = os.path.join(shared.opts.outdir_video, filename)
        filename = namegen.sequence(filename, shared.opts.outdir_video, '')
    else:
        # NOTE(review): os.pathsep is the PATH-list separator (':' or ';'), not the directory separator - confirm os.sep was not intended
        if os.pathsep not in filename:
            filename = os.path.join(shared.opts.outdir_video, filename)
    if not filename.lower().endswith(video_type.lower()):
        filename += f'.{video_type.lower()}'
    filename = namegen.sanitize(filename)
    if not sync:
        threading.Thread(target=save_video_atomic, args=(images, filename, video_type, duration, loop, interpolate, scale, pad, change)).start()
    else:
        save_video_atomic(images, filename, video_type, duration, loop, interpolate, scale, pad, change)
    return filename
def safe_decode_string(s: bytes):
    """
    Decode a metadata byte string into text, trying several encodings in order.

    Before every attempt a leading `UNICODE` marker, a leading `ASCII` marker and one leading
    NUL byte are removed from the running value, so each failed attempt may shorten the data
    seen by the next one. The first encoding that decodes strictly wins; control characters
    0x00-0x09 are then removed and surrounding whitespace is trimmed.

    Returns:
        The decoded text, or None when the result is empty or no attempt succeeds
        (which includes input that is not a byte string).
    """
    def strip_leading(data, marker):
        if data.startswith(marker):
            return data[len(marker):]
        return data

    for codec in ('utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437'): # try different encodings
        try:
            for marker in (b'UNICODE', b'ASCII', b'\x00'):
                s = strip_leading(s, marker)
            decoded = s.decode(codec, errors="strict")
            cleaned = re.sub(r'[\x00-\x09]', '', decoded).strip() # remove remaining special characters
        except Exception:
            continue
        return cleaned if len(cleaned) > 0 else None # empty strings are reported as None
    return None
def read_info_from_image(image: Image):
    """
    Collect generation parameters and metadata from an image.

    Works directly on `image.info` (when it is non-empty), so entries of that dictionary are
    added, rewritten and removed as a side effect.

    Args:
        image: image object exposing `info`, `width`, `height` and `mode`.

    Returns:
        Tuple `(geninfo, items)`: `geninfo` is the generation-parameters value taken from the
        `parameters` entry, the `UserComment` entry, the EXIF UserComment tag or rebuilt from
        NovelAI metadata (None when nothing is found); `items` is the metadata dictionary.
    """
    items = image.info or {}
    geninfo = items.pop('parameters', None)
    if geninfo is None:
        geninfo = items.pop('UserComment', None)
    if isinstance(geninfo, bytes): # raw payload: decode first, the membership test below raised TypeError on bytes
        geninfo = safe_decode_string(geninfo)
    if geninfo is not None and len(geninfo) > 0:
        # only index into a real mapping: on plain text this was a substring test followed by a failing lookup
        if isinstance(geninfo, dict) and 'UserComment' in geninfo:
            geninfo = geninfo['UserComment']
        items['UserComment'] = geninfo

    if "exif" in items:
        try:
            exif = piexif.load(items["exif"])
        except Exception as e:
            shared.log.error(f'Error loading EXIF data: {e}')
            exif = {}
        for subkey in exif.values():
            if not isinstance(subkey, dict):
                continue
            for key, val in subkey.items():
                if isinstance(val, bytes): # decode bytestring
                    val = safe_decode_string(val)
                # convert camera ratios: only exact (numerator, denominator) pairs,
                # so short tuples cannot raise IndexError and multi-value tuples are kept intact
                if isinstance(val, tuple) and len(val) == 2 and isinstance(val[0], int) and isinstance(val[1], int) and val[1] > 0:
                    val = round(val[0] / val[1], 2)
                if val is None:
                    continue
                if key in ExifTags.TAGS: # add known tags
                    if ExifTags.TAGS[key] == 'UserComment': # add geninfo from UserComment
                        geninfo = val
                        items['parameters'] = val
                    else:
                        items[ExifTags.TAGS[key]] = val
                elif key in ExifTags.GPSTAGS:
                    items[ExifTags.GPSTAGS[key]] = val

    wm = get_watermark(image)
    if wm != '':
        items['watermark'] = wm

    for key, val in items.items(): # values only are replaced, so the dictionary size is stable during iteration
        if isinstance(val, bytes): # decode bytestring
            items[key] = safe_decode_string(val)

    for key in ['exif', 'ExifOffset', 'JpegIFOffset', 'JpegIFByteCount', 'ExifVersion', 'icc_profile', 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'adobe', 'photoshop', 'loop', 'duration', 'dpi']: # remove unwanted tags
        items.pop(key, None)

    if items.get("Software", None) == "NovelAI":
        # rebuild a parameters string from the JSON comment written by NovelAI
        try:
            json_info = json.loads(items["Comment"])
            sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
            geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
        except Exception as e:
            errors.display(e, 'novelai image parser')

    try:
        items['width'] = image.width
        items['height'] = image.height
        items['mode'] = image.mode
    except Exception:
        pass # best effort: size and mode are informational only

    return geninfo, items
def image_data(data):
    """
    Extract metadata from an uploaded binary object.

    The payload is first opened as an image and its generation info is read; when that
    fails and the payload is at most 10 KiB it is decoded as UTF-8 text instead.

    Returns:
        Tuple `(info, None)` on success, or `(gr.update(), None)` when there is no data,
        the payload is too long for the text fallback, or both attempts fail.
    """
    import gradio as gr
    if data is None:
        return gr.update(), None

    image_error = None
    text_error = None

    # first attempt: payload is an image with embedded metadata
    try:
        decoded = Image.open(io.BytesIO(data))
        decoded.load()
        metadata, _items = read_info_from_image(decoded)
        errors.log.debug(f'Decoded object: image={decoded} metadata={metadata}')
        return metadata, None
    except Exception as exc:
        image_error = exc

    # second attempt: payload is short plain text
    try:
        if len(data) > 1024 * 10:
            errors.log.warning(f'Error decoding object: data too long: {len(data)}')
            return gr.update(), None
        metadata = data.decode('utf8')
        errors.log.debug(f'Decoded object: data={len(data)} metadata={metadata}')
        return metadata, None
    except Exception as exc:
        text_error = exc

    errors.log.error(f'Error decoding object: {image_error or text_error}')
    return gr.update(), None
def flatten(img, bgcolor):
    """composites an RGBA image over a solid bgcolor canvas (example: "#ffffff") and returns the result in RGB mode; any other mode is converted to RGB directly"""
    if img.mode != "RGBA":
        return img.convert('RGB')
    canvas = Image.new('RGBA', img.size, bgcolor)
    canvas.paste(img, mask=img) # the image doubles as its own paste mask
    return canvas.convert('RGB')
def set_watermark(image, watermark):
    """
    Apply the configured visible and/or invisible watermark to an image.

    Visible watermark (`shared.opts.image_watermark_position` other than 'none'): every
    non-black pixel of the file named by `shared.opts.image_watermark_image` is copied onto
    `image` in place at the selected position; parts falling outside the image are clipped.

    Invisible watermark (`shared.opts.image_watermark_enabled`): the first 4 ASCII bytes of
    `watermark` (space padded) are embedded with the `dwtDctSvd` method, which produces a new
    image object that carries over `image.info`.

    Args:
        image: target image.
        watermark: text used for the invisible watermark.

    Returns:
        The watermarked image; the same object that was passed in unless an invisible
        watermark was encoded successfully.
    """
    if shared.opts.image_watermark_position != 'none': # visible watermark
        wm_image = None
        try:
            with Image.open(shared.opts.image_watermark_image) as wm_file: # close the file handle once pixels are loaded
                wm_image = wm_file.convert('RGBA') # guarantees 4-channel pixels for the unpacking below
        except Exception as e:
            shared.log.warning(f'Set image watermark: fn="{shared.opts.image_watermark_image}" {e}')
        if wm_image is not None:
            if shared.opts.image_watermark_position == 'top/left':
                position = (0, 0)
            elif shared.opts.image_watermark_position == 'top/right':
                position = (image.width - wm_image.width, 0)
            elif shared.opts.image_watermark_position == 'bottom/left':
                position = (0, image.height - wm_image.height)
            elif shared.opts.image_watermark_position == 'bottom/right':
                position = (image.width - wm_image.width, image.height - wm_image.height)
            elif shared.opts.image_watermark_position == 'center':
                position = ((image.width - wm_image.width) // 2, (image.height - wm_image.height) // 2)
            else:
                # random placement; clamp the range so a watermark larger than the image does not make randint raise
                position = (random.randint(0, max(0, image.width - wm_image.width)), random.randint(0, max(0, image.height - wm_image.height)))
            left, top = position
            try:
                # visit only watermark pixels that land inside the image: avoids aborting midway on
                # out-of-range coordinates and avoids wrap-around writes for negative offsets
                for x in range(max(0, -left), min(wm_image.width, image.width - left)):
                    for y in range(max(0, -top), min(wm_image.height, image.height - top)):
                        r, g, b, _a = wm_image.getpixel((x, y))
                        if not (r == 0 and g == 0 and b == 0): # pure black is treated as background and skipped
                            image.putpixel((x + left, y + top), (r, g, b))
                shared.log.debug(f'Set image watermark: fn="{shared.opts.image_watermark_image}" image={wm_image} position={position}')
            except Exception as e:
                shared.log.warning(f'Set image watermark: image={wm_image} {e}')
    if shared.opts.image_watermark_enabled: # invisible watermark
        from imwatermark import WatermarkEncoder
        wm_type = 'bytes'
        wm_method = 'dwtDctSvd'
        wm_length = 32
        length = wm_length // 8
        info = image.info
        data = np.asarray(image)
        encoder = WatermarkEncoder()
        text = '' if watermark is None else str(watermark)
        # pad/truncate after encoding so the payload is always exactly `length` bytes,
        # even when non-ascii characters are dropped
        bytearr = text.encode(encoding='ascii', errors='ignore')[:length].ljust(length, b' ')
        try:
            encoder.set_watermark(wm_type, bytearr)
            encoded = encoder.encode(data, wm_method)
            image = Image.fromarray(encoded)
            image.info = info
            shared.log.debug(f'Set invisible watermark: {watermark} method={wm_method} bits={wm_length}')
        except Exception as e:
            shared.log.warning(f'Set invisible watermark error: {watermark} method={wm_method} bits={wm_length} {e}')
    return image
def get_watermark(image):
    """
    Read the invisible watermark from an image.

    Decodes a 32-bit `bytes` watermark with the `dwtDctSvd` method and returns it as text
    with non-ascii bytes dropped; any decoding failure yields an empty string.
    """
    from imwatermark import WatermarkDecoder
    pixels = np.asarray(image)
    decoder = WatermarkDecoder('bytes', 32)
    try:
        payload = decoder.decode(pixels, 'dwtDctSvd')
        return payload.decode(encoding='ascii', errors='ignore')
    except Exception:
        return ''