File size: 7,641 Bytes
fa9f602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass

import torch

from modules import devices

import modules.scripts as scripts
import gradio as gr
# todo:
from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised

from modules.processing import StableDiffusionProcessing


@dataclass(init=True, repr=True, eq=True)
class Division:
    """Grid used to split a latent into equally sized cells.

    ``y`` is the number of rows and ``x`` the number of columns. Filter divides
    the latent height by ``y`` and the width by ``x`` to get one cell's size.
    """

    y: float  # number of rows
    x: float  # number of columns


@dataclass(init=True, repr=True, eq=True)
class Position:
    """Rectangular span of grid cells: rows [y, ey) and columns [x, ex).

    Coordinates are measured in cells of the matching Division and may be
    fractional. Filter scales them by the cell size and truncates them to
    integer pixel indices.
    """

    y: float   # first row (inclusive)
    x: float   # first column (inclusive)
    ey: float  # end row (exclusive)
    ex: float  # end column (exclusive)


class Filter:
    """One rectangular region of a latent grid, filled with a constant weight.

    The latent is split into ``division.y`` rows and ``division.x`` columns of
    equal size. ``position`` selects the cell span [y, ey) x [x, ex) in those
    units, and fractional values are allowed.
    """

    def __init__(self, division: Division, position: Position, weight: float):
        self.division = division
        self.position = position
        self.weight = weight

    def create_tensor(self, num_channels: int, height_b: int, width_b: int,
                      device: Optional[torch.device] = None,
                      dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Build a (num_channels, height_b, width_b) mask for this region.

        Elements inside the region hold ``self.weight`` and everything else is 0.
        Region bounds are converted to indices with ``int()``, which truncates
        toward zero.

        Args:
            num_channels: size of the leading (channel) dimension.
            height_b: height of the mask.
            width_b: width of the mask.
            device: where to allocate the mask. Defaults to ``devices.device``.
            dtype: element type. Defaults to torch's default dtype.

        Returns:
            The mask tensor.
        """
        if device is None:
            device = devices.device

        # Allocate directly on the target device rather than building on the CPU
        # and copying. The CFG callback calls this for every filter on each step.
        mask = torch.zeros(num_channels, height_b, width_b, device=device, dtype=dtype)

        cell_height = height_b / self.division.y
        cell_width = width_b / self.division.x
        y1 = int(cell_height * self.position.y)
        y2 = int(cell_height * self.position.ey)
        x1 = int(cell_width * self.position.x)
        x2 = int(cell_width * self.position.ex)

        mask[:, y1:y2, x1:x2] = self.weight

        return mask


class Script(scripts.Script):
    """Latent Couple: regional masking of per-subprompt denoised latents.

    Each Filter describes one rectangular region and a weight. While the current
    sampling step is below ``end_at_step``, the denoised latent of the p-th
    subprompt is replaced by ``latent * mask_p + uncond * (1 - mask_p)``, where
    ``mask_p`` is the p-th filter's tensor (see ``denoised_callback``).
    """

    def __init__(self):
        self.num_batches: int = 0
        self.end_at_step: int = 20
        self.filters: List[Filter] = []
        self.debug: bool = False
        # Set here so denoised_callback never hits a missing attribute.
        # process() sets it for each run and postprocess() clears it again.
        self.enabled: bool = False

    def title(self):
        return "Latent Couple extension"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    @staticmethod
    def _split_items(raw: str) -> List[str]:
        """Split a comma-separated UI string into stripped, non-empty items."""
        return [item.strip() for item in raw.split(',') if item.strip()]

    def create_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
        """Parse the three comma-separated UI strings into a list of Filters.

        Args:
            raw_divisions: "rows:cols" entries, e.g. "1:1,1:2,1:2".
            raw_positions: "row:col" entries. Each coordinate is either a single
                cell index N (the span N..N+1) or a range "start-end".
            raw_weights: one float per filter.

        Blank entries (e.g. from a trailing comma) are ignored.

        Raises:
            ValueError: if an entry is malformed, or if the three lists do not
                have the same number of entries.
        """

        def start_and_end_position(raw: str) -> Tuple[float, float]:
            nums = [float(num) for num in raw.split('-')]
            if len(nums) == 1:
                return nums[0], nums[0] + 1.0
            return nums[0], nums[1]

        divisions = []
        for division in self._split_items(raw_divisions):
            y, x = division.split(':')
            divisions.append(Division(float(y), float(x)))

        positions = []
        for position in self._split_items(raw_positions):
            y, x = position.split(':')
            y1, y2 = start_and_end_position(y)
            x1, x2 = start_and_end_position(x)
            positions.append(Position(y1, x1, y2, x2))

        weights = [float(w) for w in self._split_items(raw_weights)]

        # zip() would silently drop the extra entries of the longer lists.
        if not len(divisions) == len(positions) == len(weights):
            raise ValueError(
                f"Latent Couple: got {len(divisions)} divisions, {len(positions)} positions "
                f"and {len(weights)} weights; the counts must match"
            )

        return [Filter(division, position, weight) for division, position, weight in zip(divisions, positions, weights)]

    def do_visualize(self, raw_divisions: str, raw_positions: str, raw_weights: str):
        """Render each region's mask as a 128x128 numpy array for the UI gallery.

        Uses local filters so previewing never replaces the filters a running
        generation's callback is reading.
        """
        filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)

        return [f.create_tensor(1, 128, 128).squeeze(dim=0).cpu().numpy() for f in filters]

    def do_apply(self, extra_generation_params: str):
        """Parse a saved "Latent Couple" parameter string back into the UI fields.

        The expected input is the string process() stores:
        "divisions=... positions=... weights=... end at step=N".
        Tokens are split on spaces and tokens without "=" are ignored, so
        "end at step=N" yields the key "step". Missing keys fall back to the UI
        defaults.

        Returns:
            (divisions, positions, weights, end_at_step) for the four UI outputs.
        """
        raw_params = {}

        for assignment in extra_generation_params.split(' '):
            key, sep, value = assignment.partition('=')
            if sep:
                raw_params[key] = value

        return (
            raw_params.get('divisions', '1:1,1:2,1:2'),
            raw_params.get('positions', '0:0,0:0,0:1'),
            raw_params.get('weights', '0.2,0.8,0.8'),
            # float() first so a step saved as e.g. "20.0" still parses.
            int(float(raw_params.get('step', '20'))),
        )

    def ui(self, is_img2img):
        id_part = "img2img" if is_img2img else "txt2img"

        with gr.Group():
            with gr.Accordion("Latent Couple", open=False):
                enabled = gr.Checkbox(value=False, label="Enabled")
                with gr.Row():
                    divisions = gr.Textbox(label="Divisions", elem_id=f"cd_{id_part}_divisions", value="1:1,1:2,1:2")
                    positions = gr.Textbox(label="Positions", elem_id=f"cd_{id_part}_positions", value="0:0,0:0,0:1")
                with gr.Row():
                    weights = gr.Textbox(label="Weights", elem_id=f"cd_{id_part}_weights", value="0.2,0.8,0.8")
                    end_at_step = gr.Slider(minimum=0, maximum=150, step=1, label="end at this step", elem_id=f"cd_{id_part}_end_at_this_step", value=20)

                visualize_button = gr.Button(value="Visualize")
                visual_regions = gr.Gallery(label="Regions").style(grid=(4, 4, 4, 8), height="auto")

                visualize_button.click(fn=self.do_visualize, inputs=[divisions, positions, weights], outputs=[visual_regions])

                extra_generation_params = gr.Textbox(label="Extra generation params")
                apply_button = gr.Button(value="Apply")

                apply_button.click(fn=self.do_apply, inputs=[extra_generation_params], outputs=[divisions, positions, weights, end_at_step])

        # Lets the "Latent Couple" infotext entry populate the textbox above.
        self.infotext_fields = [
            (extra_generation_params, "Latent Couple")
        ]
        return enabled, divisions, positions, weights, end_at_step

    def denoised_callback(self, params: CFGDenoisedParams):
        """Blend each subprompt's denoised latent with its batch's uncond latent.

        Modifies ``params.x`` in place. Does nothing unless this instance is
        enabled and ``params.sampling_step`` is below ``end_at_step``. Subprompts
        beyond the number of filters are left untouched.
        """
        if not self.enabled or params.sampling_step >= self.end_at_step:
            return

        x = params.x
        # x.shape = [N, C, H // 8, W // 8] with N = num_batches * num_prompts

        num_batches = self.num_batches
        num_prompts = x.shape[0] // num_batches
        # ex. num_batches = 3
        # ex. num_prompts = 3 (tensor) + 1 (uncond)

        if self.debug:
            print(f"### Latent couple ###")
            print(f"denoised_callback x.shape={x.shape} num_batches={num_batches} num_prompts={num_prompts}")

        filters = [
            f.create_tensor(x.shape[1], x.shape[2], x.shape[3]) for f in self.filters
        ]
        neg_filters = [1.0 - f for f in filters]

        # Row layout of x along dim 0 (example: 2 batches, 3 subprompts):
        #   batch #1: subprompt #1, #2, #3
        #   batch #2: subprompt #1, #2, #3
        #   uncond:   batch #1, batch #2
        tensor_off = 0
        uncond_off = num_batches * num_prompts - num_batches  # first uncond row
        for b in range(num_batches):
            uncond = x[uncond_off, :, :, :]

            for p in range(num_prompts - 1):
                if self.debug:
                    print(f"b={b} p={p}")
                if p < len(filters):
                    tensor = x[tensor_off, :, :, :]
                    x[tensor_off, :, :, :] = tensor * filters[p] + uncond * neg_filters[p]

                tensor_off += 1

            uncond_off += 1

    def process(self, p: StableDiffusionProcessing, enabled: bool, raw_divisions: str, raw_positions: str, raw_weights: str, raw_end_at_step: int):
        """Configure this instance for the upcoming generation.

        Parses the UI parameters and records them in ``p.extra_generation_params``
        when ``end_at_step`` is non-zero. Registers ``denoised_callback`` once per
        instance.

        Raises:
            ValueError: propagated from create_filters_from_ui_params on malformed
                input. The instance stays disabled in that case.
        """
        # Stay disabled until the new settings are fully parsed, so malformed input
        # cannot leave the callback running with stale filters.
        self.enabled = False

        if not enabled:
            return

        self.num_batches = p.batch_size
        self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
        self.end_at_step = raw_end_at_step
        self.enabled = True

        if self.end_at_step != 0:
            # Save params into the output file as PNG textual data; do_apply() parses this string.
            p.extra_generation_params["Latent Couple"] = f"divisions={raw_divisions} positions={raw_positions} weights={raw_weights} end at step={raw_end_at_step}"

        if self.debug:
            print(f"### Latent couple ###")
            print(f"process num_batches={self.num_batches} end_at_step={self.end_at_step}")

        if not hasattr(self, 'callbacks_added'):
            on_cfg_denoised(self.denoised_callback)
            self.callbacks_added = True

    def postprocess(self, *args):
        """Disable this instance's callback once the generation has finished.

        The callback registered through on_cfg_denoised is never removed here,
        and each Script instance keeps its own ``enabled`` flag. Without this
        reset, a stale ``enabled=True`` could make the callback apply these
        filters to a later generation whose process() ran on a different
        instance.
        """
        self.enabled = False