|
from typing import List, Dict, Optional, Tuple |
|
from dataclasses import dataclass |
|
|
|
import torch |
|
|
|
from modules import devices |
|
|
|
import modules.scripts as scripts |
|
import gradio as gr |
|
|
|
from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised |
|
|
|
from modules.processing import StableDiffusionProcessing |
|
|
|
|
|
@dataclass()
class Division:
    """Grid used to slice a latent: ``y`` rows by ``x`` columns.

    ``Filter.create_tensor`` divides the latent height by ``y`` and the
    width by ``x`` to obtain the size of one grid cell.
    """

    # Number of grid rows (height divisor).
    y: float
    # Number of grid columns (width divisor).
    x: float
|
|
|
|
|
@dataclass()
class Position:
    """Rectangle measured in grid cells of a ``Division``.

    Rows span ``y`` to ``ey`` and columns span ``x`` to ``ex``.
    ``Filter.create_tensor`` scales these by the cell size and uses them as
    slice bounds, so the end values are exclusive.
    """

    # Start row (in cells).
    y: float
    # Start column (in cells).
    x: float
    # End row (in cells, exclusive).
    ey: float
    # End column (in cells, exclusive).
    ex: float
|
|
|
|
|
class Filter:
    """A rectangular weight mask defined on a ``Division`` grid.

    The region selected by ``position`` is filled with ``weight``; the rest
    of the mask is zero.
    """

    def __init__(self, division: Division, position: Position, weight: float):
        self.division = division
        self.position = position
        self.weight = weight

    def create_tensor(self, num_channels: int, height_b: int, width_b: int, device=None) -> torch.Tensor:
        """Build a ``(num_channels, height_b, width_b)`` mask tensor.

        Args:
            num_channels: Size of the leading dimension.
            height_b: Mask height.
            width_b: Mask width.
            device: Device to allocate the mask on. Defaults to
                ``devices.device``, which is where the original always put it.

        Returns:
            A tensor of zeros (torch default dtype). The slice
            ``[:, y1:y2, x1:x2]`` is set to ``self.weight``, where the bounds
            are ``self.position`` scaled by the grid cell size.
        """
        if device is None:
            device = devices.device

        # Allocate directly on the target device instead of building the
        # tensor on the CPU and then copying it with .to().
        x = torch.zeros(num_channels, height_b, width_b, device=device)

        # Multiply before dividing. With integer positions, height_b * pos is
        # exact, so a boundary that should fall exactly on a row/column is
        # exact too. The former (height_b / division) * pos could come out as
        # e.g. 63.999... and be truncated one row short by int().
        y1 = int(height_b * self.position.y / self.division.y)
        y2 = int(height_b * self.position.ey / self.division.y)
        x1 = int(width_b * self.position.x / self.division.x)
        x2 = int(width_b * self.position.ex / self.division.x)

        x[:, y1:y2, x1:x2] = self.weight
        return x
|
|
|
|
|
class Script(scripts.Script):
    """Always-visible "Latent Couple" script.

    During sampling, each sub-prompt's denoised latent is blended with the
    unconditional latent through a rectangular weight mask (one ``Filter``
    per sub-prompt). This lets different sub-prompts dominate different
    regions of the image.
    """

    def __init__(self):
        self.num_batches: int = 0
        self.end_at_step: int = 20
        self.filters: List[Filter] = []
        self.debug: bool = False
        # Read by denoised_callback on every step. It is initialised here so
        # the attribute always exists, and it is reset in postprocess() so the
        # permanently registered callback stays inert outside this instance's
        # own runs.
        self.enabled: bool = False
        # Ensures the cfg_denoised callback is registered at most once.
        self.callbacks_added: bool = False

    def title(self):
        return "Latent Couple extension"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def create_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
        """Parse the three comma-separated UI strings into ``Filter`` objects.

        Args:
            raw_divisions: ``"y:x"`` entries, e.g. ``"1:1,1:2,1:2"``.
            raw_positions: ``"y:x"`` entries. Each side is either a single
                cell index ``"n"`` (meaning ``n`` to ``n + 1``) or an explicit
                ``"start-end"`` range.
            raw_weights: One float per filter.

        Returns:
            The i-th division, position and weight combined into the i-th
            ``Filter``. Surplus entries in a longer list are ignored, as with
            ``zip``.

        Raises:
            ValueError: If an entry does not split into exactly two ``:``
                parts, or a number fails to parse.
        """
        divisions = []
        for division in raw_divisions.split(','):
            y, x = division.split(':')
            divisions.append(Division(float(y), float(x)))

        def start_and_end_position(raw: str):
            # "n" -> (n, n + 1); "a-b" -> (a, b).
            nums = [float(num) for num in raw.split('-')]
            if len(nums) == 1:
                return nums[0], nums[0] + 1.0
            return nums[0], nums[1]

        positions = []
        for position in raw_positions.split(','):
            y, x = position.split(':')
            y1, y2 = start_and_end_position(y)
            x1, x2 = start_and_end_position(x)
            positions.append(Position(y1, x1, y2, x2))

        weights = [float(w) for w in raw_weights.split(',')]

        return [Filter(division, position, weight) for division, position, weight in zip(divisions, positions, weights)]

    def do_visualize(self, raw_divisions: str, raw_positions: str, raw_weights: str):
        """Render every region mask as a 128x128 2-D array for the gallery.

        The filters are kept in a local variable rather than assigned to
        ``self.filters``. ``denoised_callback`` reads ``self.filters`` on every
        step, so clicking Visualize while a generation is running must not
        swap the regions that generation is using.
        """
        filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
        return [f.create_tensor(1, 128, 128).squeeze(dim=0).cpu().numpy() for f in filters]

    def do_apply(self, extra_generation_params: str):
        """Restore the UI values from a "Latent Couple" infotext string.

        Expects the format written by ``process()``:
        ``"divisions=... positions=... weights=... end at step=N"``.

        The string is split on spaces and every ``key=value`` token is
        collected. The tokens ``"end"`` and ``"at"`` contain no ``=`` and are
        skipped, which leaves ``"step=N"``. Missing keys fall back to the UI
        defaults.

        Returns:
            ``(divisions, positions, weights, end_at_step)`` for the four
            bound UI components.
        """
        raw_params = {}
        for assignment in extra_generation_params.split(' '):
            key, sep, value = assignment.partition('=')
            if not sep:
                continue
            raw_params[key] = value

        # Go through float() so a step recorded as e.g. "20.0" still parses.
        end_at_step = int(float(raw_params.get('step', '20')))

        return (
            raw_params.get('divisions', '1:1,1:2,1:2'),
            raw_params.get('positions', '0:0,0:0,0:1'),
            raw_params.get('weights', '0.2,0.8,0.8'),
            end_at_step,
        )

    def ui(self, is_img2img):
        id_part = "img2img" if is_img2img else "txt2img"

        with gr.Group():
            with gr.Accordion("Latent Couple", open=False):
                enabled = gr.Checkbox(value=False, label="Enabled")
                with gr.Row():
                    divisions = gr.Textbox(label="Divisions", elem_id=f"cd_{id_part}_divisions", value="1:1,1:2,1:2")
                    positions = gr.Textbox(label="Positions", elem_id=f"cd_{id_part}_positions", value="0:0,0:0,0:1")
                with gr.Row():
                    weights = gr.Textbox(label="Weights", elem_id=f"cd_{id_part}_weights", value="0.2,0.8,0.8")
                    end_at_step = gr.Slider(minimum=0, maximum=150, step=1, label="end at this step", elem_id=f"cd_{id_part}_end_at_this_step", value=20)

                visualize_button = gr.Button(value="Visualize")
                visual_regions = gr.Gallery(label="Regions").style(grid=(4, 4, 4, 8), height="auto")

                visualize_button.click(fn=self.do_visualize, inputs=[divisions, positions, weights], outputs=[visual_regions])

                extra_generation_params = gr.Textbox(label="Extra generation params")
                apply_button = gr.Button(value="Apply")

                apply_button.click(fn=self.do_apply, inputs=[extra_generation_params], outputs=[divisions, positions, weights, end_at_step])

        # Paste-from-infotext: the "Latent Couple" value lands in the
        # textbox, and the Apply button then parses it with do_apply().
        self.infotext_fields = [
            (extra_generation_params, "Latent Couple")
        ]

        return enabled, divisions, positions, weights, end_at_step

    def denoised_callback(self, params: CFGDenoisedParams):
        """Blend each sub-prompt's denoised latent with its batch's uncond row.

        For sub-prompt ``p`` of every batch item, the row is replaced in place
        by ``cond * filters[p] + uncond * (1 - filters[p])``. Sub-prompts
        beyond the number of configured filters are left untouched. This only
        runs while the script is enabled and before ``end_at_step``.
        """
        if self.enabled and params.sampling_step < self.end_at_step:
            x = params.x

            # Row layout of x that the offsets below rely on:
            #   batch #1: subprompt #1, subprompt #2, subprompt #3
            #   batch #2: subprompt #1, subprompt #2, subprompt #3
            #   uncond:   batch #1, batch #2
            # Each batch item therefore owns (num_prompts - 1) cond rows plus
            # one uncond row.
            num_batches = self.num_batches
            num_prompts = x.shape[0] // num_batches

            if self.debug:
                print(f"### Latent couple ###")
                print(f"denoised_callback x.shape={x.shape} num_batches={num_batches} num_prompts={num_prompts}")

            filters = [
                f.create_tensor(x.shape[1], x.shape[2], x.shape[3]) for f in self.filters
            ]
            neg_filters = [1.0 - f for f in filters]

            tensor_off = 0
            # The first uncond row follows all cond rows.
            uncond_off = num_batches * num_prompts - num_batches
            for b in range(num_batches):
                uncond = x[uncond_off, :, :, :]

                for p in range(num_prompts - 1):
                    if self.debug:
                        print(f"b={b} p={p}")
                    if p < len(filters):
                        tensor = x[tensor_off, :, :, :]
                        x[tensor_off, :, :, :] = tensor * filters[p] + uncond * neg_filters[p]

                    tensor_off += 1

                uncond_off += 1

    def process(self, p: StableDiffusionProcessing, enabled: bool, raw_divisions: str, raw_positions: str, raw_weights: str, raw_end_at_step: int):
        """Capture this run's settings and make sure the callback is registered.

        Raises:
            ValueError: If the UI strings are malformed (propagated from
                ``create_filters_from_ui_params``).
        """
        self.enabled = enabled
        if not self.enabled:
            return

        self.num_batches = p.batch_size
        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
        self.end_at_step = raw_end_at_step

        # Only record the settings when they can take effect. With
        # end_at_step == 0, the `sampling_step < end_at_step` test in
        # denoised_callback never passes, presumably because sampling_step
        # is never negative.
        if self.end_at_step != 0:
            p.extra_generation_params["Latent Couple"] = f"divisions={raw_divisions} positions={raw_positions} weights={raw_weights} end at step={raw_end_at_step}"

        if self.debug:
            print(f"### Latent couple ###")
            print(f"process num_batches={self.num_batches} end_at_step={self.end_at_step}")

        if not self.callbacks_added:
            on_cfg_denoised(self.denoised_callback)
            self.callbacks_added = True

    def postprocess(self, *args):
        """Disable blending once the run has finished.

        The cfg_denoised callback is registered once and never removed, so it
        also fires for generations this instance did not set up. One example
        is, presumably, runs from the other txt2img/img2img tab, which use a
        separate Script instance (confirm against the webui). Clearing
        ``enabled`` makes the callback a no-op until ``process()`` enables it
        again.
        """
        self.enabled = False
|
|
|
|
|
|