from typing import List
from dataclasses import dataclass

import torch
import gradio as gr

import modules.scripts as scripts
from modules import devices
from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised
from modules.processing import StableDiffusionProcessing
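
# "Latent Couple" blends several subprompts into one image by masking the
# denoised latent: each subprompt's influence is confined to a rectangular
# region of the latent grid, and everything outside that region falls back
# to the unconditional prediction (see Script.denoised_callback below).
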
@dataclass
class Division:
    y: float
    x: float

@dataclass
class Position:
    y: float
    x: float
    ey: float
    ex: float
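
# A Division splits the latent canvas into a grid of y-by-x cells; a Position
# selects a (possibly fractional) cell range [y, ey) x [x, ex) inside that
# grid. Filter.create_tensor below turns the pair into a per-pixel weight
# mask over the latent.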

class Filter:

    def __init__(self, division: Division, position: Position, weight: float):
        self.division = division
        self.position = position
        self.weight = weight

    def create_tensor(self, num_channels: int, height_b: int, width_b: int) -> torch.Tensor:
        x = torch.zeros(num_channels, height_b, width_b).to(devices.device)

        # Size of one grid cell in latent pixels.
        division_height = height_b / self.division.y
        division_width = width_b / self.division.x

        # Convert the fractional cell range into latent-pixel coordinates.
        y1 = int(division_height * self.position.y)
        y2 = int(division_height * self.position.ey)
        x1 = int(division_width * self.position.x)
        x2 = int(division_width * self.position.ex)

        # Weight inside the selected rectangle, zero everywhere else.
        x[:, y1:y2, x1:x2] = self.weight

        return x
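
# A quick illustration (not part of the extension's control flow): with the
# default UI values used below, the second filter covers the left half of
# the canvas at weight 0.8.
#
#   f = Filter(Division(1.0, 2.0), Position(0.0, 0.0, 1.0, 1.0), 0.8)
#   t = f.create_tensor(1, 64, 64)
#   # t[:, :, :32] == 0.8 (left half), t[:, :, 32:] == 0.0 (right half)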

class Script(scripts.Script):

    def __init__(self):
        self.num_batches: int = 0
        self.end_at_step: int = 20
        self.filters: List[Filter] = []
        self.debug: bool = False

    def title(self):
        return "Latent Couple extension"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def create_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
        divisions = []
        for division in raw_divisions.split(','):
            y, x = division.split(':')
            divisions.append(Division(float(y), float(x)))

        def start_and_end_position(raw: str):
            nums = [float(num) for num in raw.split('-')]
            if len(nums) == 1:
                # A single number selects exactly one grid cell: "0" means [0, 1).
                return nums[0], nums[0] + 1.0
            else:
                return nums[0], nums[1]

        positions = []
        for position in raw_positions.split(','):
            y, x = position.split(':')
            y1, y2 = start_and_end_position(y)
            x1, x2 = start_and_end_position(x)
            positions.append(Position(y1, x1, y2, x2))

        weights = []
        for w in raw_weights.split(','):
            weights.append(float(w))

        # The three comma-separated lists must describe the same number of regions;
        # zip() would otherwise silently drop the extras.
        assert len(divisions) == len(positions) == len(weights), \
            "Divisions, Positions and Weights must have the same number of entries"

        return [Filter(division, position, weight)
                for division, position, weight in zip(divisions, positions, weights)]
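
    # For the default UI values this produces three region filters:
    #   divisions "1:1,1:2,1:2", positions "0:0,0:0,0:1", weights "0.2,0.8,0.8"
    #   -> the whole canvas at weight 0.2, the left half at 0.8, and the
    #      right half at 0.8.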

    def do_visualize(self, raw_divisions: str, raw_positions: str, raw_weights: str):
        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
        return [f.create_tensor(1, 128, 128).squeeze(dim=0).cpu().numpy() for f in self.filters]

    def do_apply(self, extra_generation_params: str):
        #
        # parse "Latent Couple" extra_generation_params
        #
        raw_params = {}

        for assignment in extra_generation_params.split(' '):
            pair = assignment.split('=', 1)
            if len(pair) != 2:
                continue
            raw_params[pair[0]] = pair[1]

        # Note: splitting on spaces turns "end at step=20" into the key "step",
        # which is why the last lookup reads 'step' rather than 'end at step'.
        return (raw_params.get('divisions', '1:1,1:2,1:2'),
                raw_params.get('positions', '0:0,0:0,0:1'),
                raw_params.get('weights', '0.2,0.8,0.8'),
                int(raw_params.get('step', '20')))
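
    # Example (illustrative): pasting the infotext written by process() below,
    #   "divisions=1:1,1:2,1:2 positions=0:0,0:0,0:1 weights=0.2,0.8,0.8 end at step=20"
    # restores ('1:1,1:2,1:2', '0:0,0:0,0:1', '0.2,0.8,0.8', 20) into the UI.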

    def ui(self, is_img2img):
        id_part = "img2img" if is_img2img else "txt2img"

        with gr.Group():
            with gr.Accordion("Latent Couple", open=False):
                enabled = gr.Checkbox(value=False, label="Enabled")
                with gr.Row():
                    divisions = gr.Textbox(label="Divisions", elem_id=f"cd_{id_part}_divisions", value="1:1,1:2,1:2")
                    positions = gr.Textbox(label="Positions", elem_id=f"cd_{id_part}_positions", value="0:0,0:0,0:1")
                with gr.Row():
                    weights = gr.Textbox(label="Weights", elem_id=f"cd_{id_part}_weights", value="0.2,0.8,0.8")
                    end_at_step = gr.Slider(minimum=0, maximum=150, step=1, label="end at this step", elem_id=f"cd_{id_part}_end_at_this_step", value=20)
                visualize_button = gr.Button(value="Visualize")
                visual_regions = gr.Gallery(label="Regions").style(grid=(4, 4, 4, 8), height="auto")

                visualize_button.click(fn=self.do_visualize, inputs=[divisions, positions, weights], outputs=[visual_regions])

                extra_generation_params = gr.Textbox(label="Extra generation params")
                apply_button = gr.Button(value="Apply")

                apply_button.click(fn=self.do_apply, inputs=[extra_generation_params], outputs=[divisions, positions, weights, end_at_step])

        self.infotext_fields = [
            (extra_generation_params, "Latent Couple")
        ]

        return enabled, divisions, positions, weights, end_at_step
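
    # The tuple returned by ui() is handed back by the webui as the positional
    # arguments of process() below (after `p`), in the same order.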

    def denoised_callback(self, params: CFGDenoisedParams):
        if self.enabled and params.sampling_step < self.end_at_step:

            x = params.x
            # x.shape = [batch_size, C, H // 8, W // 8]

            num_batches = self.num_batches
            num_prompts = x.shape[0] // num_batches
            # ex. num_batches = 3
            # ex. num_prompts = 3 (tensor) + 1 (uncond)

            if self.debug:
                print("### Latent couple ###")
                print(f"denoised_callback x.shape={x.shape} num_batches={num_batches} num_prompts={num_prompts}")

            # One mask per region, at the latent's channel count and resolution.
            filters = [
                f.create_tensor(x.shape[1], x.shape[2], x.shape[3]) for f in self.filters
            ]
            neg_filters = [1.0 - f for f in filters]
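
            # Blend rule applied per subprompt p in the loop below:
            #   x[i] = x[i] * filters[p] + uncond * (1 - filters[p])
            # Inside its region a subprompt keeps its own denoised latent;
            # outside it, the unconditional latent takes over.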
"""
batch #1
subprompt #1
subprompt #2
subprompt #3
batch #2
subprompt #1
subprompt #2
subprompt #3
uncond
batch #1
batch #2
"""

            tensor_off = 0
            # The uncond predictions are the last num_batches entries of the batch.
            uncond_off = num_batches * num_prompts - num_batches
            for b in range(num_batches):
                uncond = x[uncond_off, :, :, :]

                for p in range(num_prompts - 1):
                    if self.debug:
                        print(f"b={b} p={p}")
                    if p < len(filters):
                        tensor = x[tensor_off, :, :, :]
                        x[tensor_off, :, :, :] = tensor * filters[p] + uncond * neg_filters[p]
                    tensor_off += 1

                uncond_off += 1

    def process(self, p: StableDiffusionProcessing, enabled: bool, raw_divisions: str, raw_positions: str, raw_weights: str, raw_end_at_step: int):
        self.enabled = enabled

        if not self.enabled:
            return

        self.num_batches = p.batch_size
        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
        self.end_at_step = raw_end_at_step

        if self.end_at_step != 0:
            p.extra_generation_params["Latent Couple"] = f"divisions={raw_divisions} positions={raw_positions} weights={raw_weights} end at step={raw_end_at_step}"
            # save params into the output file as PNG textual data.

        if self.debug:
            print("### Latent couple ###")
            print(f"process num_batches={self.num_batches} end_at_step={self.end_at_step}")

        # Register the denoise hook only once; it stays registered across
        # generations and is gated by self.enabled above.
        if not hasattr(self, 'callbacks_added'):
            on_cfg_denoised(self.denoised_callback)
            self.callbacks_added = True

        return

    def postprocess(self, *args):
        return