bigjoker's picture
Duplicate from user238921933/stable-diffusion-webui
55cc64a
raw
history blame contribute delete
No virus
7.61 kB
# At the moment there are three types of masks: mask from variable, file mask and word mask
# Variable masks include init_mask for the predefined whole-video mask, frame_mask from video-masking system
# and human_mask for a model which better segments people in the background video
# They are put in {}-brackets
# Word masks are framed with <>-brackets, like: <cat>, <anime girl>
# File masks are put in []-brackets
# Empty strings are counted as the whole frame
# We want to put them all into a sequence of boolean operations
# Example:
# \ <armor>
# (({human_mask} & [mask1.png]) ^ <apple>)
# Writing the parser for the boolean sequence
# using regex and PIL operations
import re
from .load_images import get_mask_from_file, check_mask_for_errors, blank_if_none
from .word_masking import get_word_mask
from torch import Tensor
import PIL
from PIL import Image, ImageChops
# val_masks: dict of mask name -> PIL Image mask
# Returns the "{name}" reference of the resulting mask; the image itself is stored in
# val_masks[name] in mode '1' (needed for bool ops), convert to 'L' in the sender function
def compose_mask(root, args, mask_seq, val_masks, frame_image, inner_idx:int = 0):
    """Evaluate a composable mask expression and return a reference to the result.

    The expression is reduced in four steps: parenthesised sub-expressions are
    evaluated recursively, file masks ``[...]`` and word masks ``<...>`` are
    loaded and replaced by ``{name}`` variables, the boolean operators are
    applied in the fixed order ``!``, ``&``, ``|``, ``^``, ``\\``, and finally
    the single remaining variable is returned.

    Args:
        root: Object whose ``mask_preset_names`` lists the variables that must
            never be overwritten; it is also forwarded to ``get_word_mask``.
        args: Forwarded to ``get_mask_from_file``.
        mask_seq: The mask expression string.
        val_masks: Dict mapping variable names to masks. It is mutated: loaded
            masks and intermediate results are stored in it.
        frame_image: Forwarded to ``get_word_mask``.
        inner_idx: Starting value of the counter used to name temporary masks.

    Returns:
        The string ``"{name}"`` where ``val_masks[name]`` holds the result.

    Raises:
        Exception: If the parentheses are unbalanced, or if the expression does
            not reduce to exactly one variable.
    """

    def new_name():
        # val_masks is shared by every recursion level while inner_idx is not,
        # so probe the dict to guarantee a name no other level has handed out.
        nonlocal inner_idx
        inner_idx += 1
        while str(inner_idx) in val_masks:
            inner_idx += 1
        return str(inner_idx)

    def target_name(source):
        # Preset masks must survive for later use, so their results go to a
        # fresh slot; temporaries are consumed once and are overwritten in place.
        if source in root.mask_preset_names:
            return new_name()
        return source

    def rewrite_until_stable(pattern, repl, text):
        # One re.sub pass only handles non-overlapping matches, so chains like
        # "{a} & {b} & {c}" need repeated passes until nothing changes.
        while True:
            rewritten = re.sub(pattern, repl, text)
            if rewritten == text:
                return rewritten
            text = rewritten

    def reduce_binary(operator, combine, text):
        # Variable names exclude braces so that a match can never stretch
        # across a neighbouring variable joined by a different operator.
        pattern = (r'\{(?P<inner1>[^{}]*)\}\s*' + operator
                   + r'\s*\{(?P<inner2>[^{}]*)\}')

        def apply(match_object):
            left = match_object.group('inner1')
            right = match_object.group('inner2')
            target = target_name(left)
            val_masks[target] = combine(val_masks[left], val_masks[right])
            return f"{{{target}}}"

        return rewrite_until_stable(pattern, apply, text)

    # Step 1:
    # recursive parenthesis pass (nesting cannot be expressed with a regex)
    outer_parts = []
    inner_parts = []
    depth = 0
    for c in mask_seq:
        if c == ')':
            depth -= 1
        if depth > 0:
            inner_parts.append(c)
        if c == '(':
            depth += 1
        if depth == 0:
            if inner_parts:
                # A top-level group has just closed: replace it by its result variable
                outer_parts.append(compose_mask(root, args, ''.join(inner_parts),
                                                val_masks, frame_image, inner_idx))
                inner_parts = []
            else:
                outer_parts.append(c)

    if depth != 0:
        raise Exception(f'Mismatched parentheses in {mask_seq}!')

    mask_seq = ''.join(outer_parts)

    # Step 2:
    # Load the word masks and file masks as vars

    # File masks
    def load_file_mask(match_object):
        name = new_name()
        val_masks[name] = get_mask_from_file(match_object.group('inner'), args).convert('1') # TODO: add caching
        return f"{{{name}}}"

    mask_seq = re.sub(r'\[(?P<inner>[\S\s]*?)\]', load_file_mask, mask_seq)

    # Word masks
    def load_word_mask(match_object):
        name = new_name()
        val_masks[name] = get_word_mask(root, frame_image, match_object.group('inner')).convert('1')
        return f"{{{name}}}"

    mask_seq = re.sub(r'<(?P<inner>[\S\s]*?)>', load_word_mask, mask_seq)

    # Now that all inner parenthesis are eliminated we're left with a linear string

    # Step 3:
    # Boolean operations with masks
    # Operators: invert !, and &, or |, xor ^, difference \

    # Invert vars with '!'
    # Only whitespace may separate '!' from its variable, otherwise the
    # inversion would be applied to a later variable of the expression.
    def invert(match_object):
        source = match_object.group('inner')
        target = target_name(source)
        val_masks[target] = ImageChops.invert(val_masks[source])
        return f"{{{target}}}"

    mask_seq = rewrite_until_stable(r'!\s*\{(?P<inner>[^{}]*)\}', invert, mask_seq)

    # The ImageChops calls are wrapped in lambdas so that they are only looked
    # up when an operator is actually present in the expression.

    # Multiply neighbouring vars with '&'
    mask_seq = reduce_binary(r'&', lambda a, b: ImageChops.logical_and(a, b), mask_seq)

    # Add neighbouring vars with '|'
    mask_seq = reduce_binary(r'\|', lambda a, b: ImageChops.logical_or(a, b), mask_seq)

    # Mutually exclude neighbouring vars with '^'
    mask_seq = reduce_binary(r'\^', lambda a, b: ImageChops.logical_xor(a, b), mask_seq)

    # Set-difference the regions with '\'
    mask_seq = reduce_binary(
        r'\\',
        lambda a, b: ImageChops.logical_and(a, ImageChops.invert(b)),
        mask_seq,
    )

    # Step 4:
    # Output
    # Now we should have a single var left to return. If not, raise an error message
    matches = re.findall(r'\{(?P<inner>[^{}]*)\}', mask_seq)

    if len(matches) != 1:
        raise Exception(f'Wrong composable mask expression format! Broken mask sequence: {mask_seq}')

    return f"{{{matches[0]}}}"
def compose_mask_with_check(root, args, mask_seq, val_masks, frame_image):
    """Prepare the variable masks, evaluate ``mask_seq`` and return the checked result.

    Every entry of ``val_masks`` is first passed through ``blank_if_none`` with
    the frame size ``args.W`` x ``args.H`` and converted to mode '1'
    (presumably a missing mask is replaced by a blank one; confirm against
    ``blank_if_none``). The dict is updated in place. The expression is then
    evaluated with ``compose_mask``, and the resulting mask is converted to
    mode 'L' and handed to ``check_mask_for_errors``, whose return value is
    returned.
    """
    # Only existing keys are reassigned, so the dict keeps its size while iterated
    for name in val_masks:
        prepared = blank_if_none(val_masks[name], args.W, args.H, '1')
        val_masks[name] = prepared.convert('1')

    # compose_mask answers with "{name}"; strip the braces to get the dict key
    result_reference = compose_mask(root, args, mask_seq, val_masks, frame_image, 0)
    result_name = result_reference[1:-1]

    grayscale_mask = val_masks[result_name].convert('L')
    return check_mask_for_errors(grayscale_mask)