supertori commited on
Commit
fa9f602
1 Parent(s): 1cfc196

Upload two_shot.py

Browse files
Files changed (1) hide show
  1. two_shot.py +227 -0
two_shot.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Tuple
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+
6
+ from modules import devices
7
+
8
+ import modules.scripts as scripts
9
+ import gradio as gr
10
+ # todo:
11
+ from modules.script_callbacks import CFGDenoisedParams, on_cfg_denoised
12
+
13
+ from modules.processing import StableDiffusionProcessing
14
+
15
+
16
+ @dataclass
17
+ class Division:
18
+ y: float
19
+ x: float
20
+
21
+
22
+ @dataclass
23
+ class Position:
24
+ y: float
25
+ x: float
26
+ ey: float
27
+ ex: float
28
+
29
+
30
+ class Filter:
31
+
32
+ def __init__(self, division: Division, position: Position, weight: float):
33
+ self.division = division
34
+ self.position = position
35
+ self.weight = weight
36
+
37
+ def create_tensor(self, num_channels: int, height_b: int, width_b: int) -> torch.Tensor:
38
+
39
+ x = torch.zeros(num_channels, height_b, width_b).to(devices.device)
40
+
41
+ division_height = height_b / self.division.y
42
+ division_width = width_b / self.division.x
43
+ y1 = int(division_height * self.position.y)
44
+ y2 = int(division_height * self.position.ey)
45
+ x1 = int(division_width * self.position.x)
46
+ x2 = int(division_width * self.position.ex)
47
+
48
+ x[:, y1:y2, x1:x2] = self.weight
49
+
50
+ return x
51
+
52
+
53
+ class Script(scripts.Script):
54
+
55
+ def __init__(self):
56
+ self.num_batches: int = 0
57
+ self.end_at_step: int = 20
58
+ self.filters: List[Filter] = []
59
+ self.debug: bool = False
60
+
61
+ def title(self):
62
+ return "Latent Couple extension"
63
+
64
+ def show(self, is_img2img):
65
+ return scripts.AlwaysVisible
66
+
67
+ def create_filters_from_ui_params(self, raw_divisions: str, raw_positions: str, raw_weights: str):
68
+
69
+ divisions = []
70
+ for division in raw_divisions.split(','):
71
+ y, x = division.split(':')
72
+ divisions.append(Division(float(y), float(x)))
73
+
74
+ def start_and_end_position(raw: str):
75
+ nums = [float(num) for num in raw.split('-')]
76
+ if len(nums) == 1:
77
+ return nums[0], nums[0] + 1.0
78
+ else:
79
+ return nums[0], nums[1]
80
+
81
+ positions = []
82
+ for position in raw_positions.split(','):
83
+ y, x = position.split(':')
84
+ y1, y2 = start_and_end_position(y)
85
+ x1, x2 = start_and_end_position(x)
86
+ positions.append(Position(y1, x1, y2, x2))
87
+
88
+ weights = []
89
+ for w in raw_weights.split(','):
90
+ weights.append(float(w))
91
+
92
+ # todo: assert len
93
+
94
+ return [Filter(division, position, weight) for division, position, weight in zip(divisions, positions, weights)]
95
+
96
+ def do_visualize(self, raw_divisions: str, raw_positions: str, raw_weights: str):
97
+
98
+ self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
99
+
100
+ return [f.create_tensor(1, 128, 128).squeeze(dim=0).cpu().numpy() for f in self.filters]
101
+
102
+ def do_apply(self, extra_generation_params: str):
103
+ #
104
+ # parse "Latent Couple" extra_generation_params
105
+ #
106
+ raw_params = {}
107
+
108
+ for assignment in extra_generation_params.split(' '):
109
+ pair = assignment.split('=', 1)
110
+ if len(pair) != 2:
111
+ continue
112
+ raw_params[pair[0]] = pair[1]
113
+
114
+ return raw_params.get('divisions', '1:1,1:2,1:2'), raw_params.get('positions', '0:0,0:0,0:1'), raw_params.get('weights', '0.2,0.8,0.8'), int(raw_params.get('step', '20'))
115
+
116
+ def ui(self, is_img2img):
117
+ id_part = "img2img" if is_img2img else "txt2img"
118
+
119
+ with gr.Group():
120
+ with gr.Accordion("Latent Couple", open=False):
121
+ enabled = gr.Checkbox(value=False, label="Enabled")
122
+ with gr.Row():
123
+ divisions = gr.Textbox(label="Divisions", elem_id=f"cd_{id_part}_divisions", value="1:1,1:2,1:2")
124
+ positions = gr.Textbox(label="Positions", elem_id=f"cd_{id_part}_positions", value="0:0,0:0,0:1")
125
+ with gr.Row():
126
+ weights = gr.Textbox(label="Weights", elem_id=f"cd_{id_part}_weights", value="0.2,0.8,0.8")
127
+ end_at_step = gr.Slider(minimum=0, maximum=150, step=1, label="end at this step", elem_id=f"cd_{id_part}_end_at_this_step", value=20)
128
+
129
+ visualize_button = gr.Button(value="Visualize")
130
+ visual_regions = gr.Gallery(label="Regions").style(grid=(4, 4, 4, 8), height="auto")
131
+
132
+ visualize_button.click(fn=self.do_visualize, inputs=[divisions, positions, weights], outputs=[visual_regions])
133
+
134
+ extra_generation_params = gr.Textbox(label="Extra generation params")
135
+ apply_button = gr.Button(value="Apply")
136
+
137
+ apply_button.click(fn=self.do_apply, inputs=[extra_generation_params], outputs=[divisions, positions, weights, end_at_step])
138
+
139
+ self.infotext_fields = [
140
+ (extra_generation_params, "Latent Couple")
141
+ ]
142
+ return enabled, divisions, positions, weights, end_at_step
143
+
144
+ def denoised_callback(self, params: CFGDenoisedParams):
145
+
146
+ if self.enabled and params.sampling_step < self.end_at_step:
147
+
148
+ x = params.x
149
+ # x.shape = [batch_size, C, H // 8, W // 8]
150
+
151
+ num_batches = self.num_batches
152
+ num_prompts = x.shape[0] // num_batches
153
+ # ex. num_batches = 3
154
+ # ex. num_prompts = 3 (tensor) + 1 (uncond)
155
+
156
+ if self.debug:
157
+ print(f"### Latent couple ###")
158
+ print(f"denoised_callback x.shape={x.shape} num_batches={num_batches} num_prompts={num_prompts}")
159
+
160
+ filters = [
161
+ f.create_tensor(x.shape[1], x.shape[2], x.shape[3]) for f in self.filters
162
+ ]
163
+ neg_filters = [1.0 - f for f in filters]
164
+
165
+ """
166
+ batch #1
167
+ subprompt #1
168
+ subprompt #2
169
+ subprompt #3
170
+ batch #2
171
+ subprompt #1
172
+ subprompt #2
173
+ subprompt #3
174
+ uncond
175
+ batch #1
176
+ batch #2
177
+ """
178
+
179
+ tensor_off = 0
180
+ uncond_off = num_batches * num_prompts - num_batches
181
+ for b in range(num_batches):
182
+ uncond = x[uncond_off, :, :, :]
183
+
184
+ for p in range(num_prompts - 1):
185
+ if self.debug:
186
+ print(f"b={b} p={p}")
187
+ if p < len(filters):
188
+ tensor = x[tensor_off, :, :, :]
189
+ x[tensor_off, :, :, :] = tensor * filters[p] + uncond * neg_filters[p]
190
+
191
+ tensor_off += 1
192
+
193
+ uncond_off += 1
194
+
195
+ def process(self, p: StableDiffusionProcessing, enabled: bool, raw_divisions: str, raw_positions: str, raw_weights: str, raw_end_at_step: int):
196
+
197
+ self.enabled = enabled
198
+
199
+ if not self.enabled:
200
+ return
201
+
202
+ self.num_batches = p.batch_size
203
+
204
+ self.filters = self.create_filters_from_ui_params(raw_divisions, raw_positions, raw_weights)
205
+
206
+ self.end_at_step = raw_end_at_step
207
+
208
+ #
209
+
210
+ if self.end_at_step != 0:
211
+ p.extra_generation_params["Latent Couple"] = f"divisions={raw_divisions} positions={raw_positions} weights={raw_weights} end at step={raw_end_at_step}"
212
+ # save params into the output file as PNG textual data.
213
+
214
+ if self.debug:
215
+ print(f"### Latent couple ###")
216
+ print(f"process num_batches={self.num_batches} end_at_step={self.end_at_step}")
217
+
218
+ if not hasattr(self, 'callbacks_added'):
219
+ on_cfg_denoised(self.denoised_callback)
220
+ self.callbacks_added = True
221
+
222
+ return
223
+
224
+ def postprocess(self, *args):
225
+ return
226
+
227
+