RSPMetaAdmin committed
Commit
be2e1d9
1 parent: c67abf8

Upload 2 files

Files changed (2)
  1. outpainting_mk_2.py +295 -0
  2. poor_mans_outpainting.py +146 -0
outpainting_mk_2.py ADDED
@@ -0,0 +1,295 @@
import math

import numpy as np
import skimage

import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw

from modules import images
from modules.processing import Processed, process_images
from modules.shared import opts, state


# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
    # helper fft routines that keep ortho normalization and auto-shift before and after fft
    def _fft2(data):
        if data.ndim > 2:  # has channels
            out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
                out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
        else:  # one channel
            out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
            out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
            out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])

        return out_fft

    def _ifft2(data):
        if data.ndim > 2:  # has channels
            out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
                out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
        else:  # one channel
            out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
            out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
            out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])

        return out_ifft

    def _get_gaussian_window(width, height, std=3.14, mode=0):
        window_scale_x = float(width / min(width, height))
        window_scale_y = float(height / min(width, height))

        window = np.zeros((width, height))
        x = (np.arange(width) / width * 2. - 1.) * window_scale_x
        for y in range(height):
            fy = (y / height * 2. - 1.) * window_scale_y
            if mode == 0:
                window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
            else:
                window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14)  # hey wait a minute that's not gaussian

        return window

    def _get_masked_window_rgb(np_mask_grey, hardness=1.):
        np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
        if hardness != 1.:
            hardened = np_mask_grey[:] ** hardness
        else:
            hardened = np_mask_grey[:]
        for c in range(3):
            np_mask_rgb[:, :, c] = hardened[:]
        return np_mask_rgb

    width = _np_src_image.shape[0]
    height = _np_src_image.shape[1]
    num_channels = _np_src_image.shape[2]

    _np_src_image[:] * (1. - np_mask_rgb)
    np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
    img_mask = np_mask_grey > 1e-6
    ref_mask = np_mask_grey < 1e-3

    windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
    windowed_image /= np.max(windowed_image)
    windowed_image += np.average(_np_src_image) * np_mask_rgb  # / (1.-np.average(np_mask_rgb))  # rather than leave the masked area black, we get better results from fft by filling the average unmasked color

    src_fft = _fft2(windowed_image)  # get feature statistics from masked src img
    src_dist = np.absolute(src_fft)
    src_phase = src_fft / src_dist

    # create a generator with a static seed to make outpainting deterministic / only follow global seed
    rng = np.random.default_rng(0)

    noise_window = _get_gaussian_window(width, height, mode=1)  # start with simple gaussian noise
    noise_rgb = rng.random((width, height, num_channels))
    noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
    noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter
    for c in range(num_channels):
        noise_rgb[:, :, c] += (1. - color_variation) * noise_grey

    noise_fft = _fft2(noise_rgb)
    for c in range(num_channels):
        noise_fft[:, :, c] *= noise_window
    noise_rgb = np.real(_ifft2(noise_fft))
    shaped_noise_fft = _fft2(noise_rgb)
    shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase  # perform the actual shaping

    brightness_variation = 0.  # color_variation  # todo: temporarily tieing brightness variation to color variation for now
    contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.

    # scikit-image is used for histogram matching, very convenient!
    shaped_noise = np.real(_ifft2(shaped_noise_fft))
    shaped_noise -= np.min(shaped_noise)
    shaped_noise /= np.max(shaped_noise)
    shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
    shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb

    matched_noise = shaped_noise[:]

    return np.clip(matched_noise, 0., 1.)


class Script(scripts.Script):
    def title(self):
        return "Outpainting mk2"

    def show(self, is_img2img):
        return is_img2img

    def ui(self, is_img2img):
        if not is_img2img:
            return None

        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")

        pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
        direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
        noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
        color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))

        return [info, pixels, mask_blur, direction, noise_q, color_variation]

    def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
        initial_seed_and_info = [None, None]

        process_width = p.width
        process_height = p.height

        p.inpaint_full_res = False
        p.inpainting_fill = 1
        p.do_not_save_samples = True
        p.do_not_save_grid = True

        left = pixels if "left" in direction else 0
        right = pixels if "right" in direction else 0
        up = pixels if "up" in direction else 0
        down = pixels if "down" in direction else 0

        if left > 0 or right > 0:
            mask_blur_x = mask_blur
        else:
            mask_blur_x = 0

        if up > 0 or down > 0:
            mask_blur_y = mask_blur
        else:
            mask_blur_y = 0

        p.mask_blur_x = mask_blur_x*4
        p.mask_blur_y = mask_blur_y*4

        init_img = p.init_images[0]
        target_w = math.ceil((init_img.width + left + right) / 64) * 64
        target_h = math.ceil((init_img.height + up + down) / 64) * 64

        if left > 0:
            left = left * (target_w - init_img.width) // (left + right)

        if right > 0:
            right = target_w - init_img.width - left

        if up > 0:
            up = up * (target_h - init_img.height) // (up + down)

        if down > 0:
            down = target_h - init_img.height - up

        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
            is_horiz = is_left or is_right
            is_vert = is_top or is_bottom
            pixels_horiz = expand_pixels if is_horiz else 0
            pixels_vert = expand_pixels if is_vert else 0

            images_to_process = []
            output_images = []
            for n in range(count):
                res_w = init[n].width + pixels_horiz
                res_h = init[n].height + pixels_vert
                process_res_w = math.ceil(res_w / 64) * 64
                process_res_h = math.ceil(res_h / 64) * 64

                img = Image.new("RGB", (process_res_w, process_res_h))
                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
                mask = Image.new("RGB", (process_res_w, process_res_h), "white")
                draw = ImageDraw.Draw(mask)
                draw.rectangle((
                    expand_pixels + mask_blur_x if is_left else 0,
                    expand_pixels + mask_blur_y if is_top else 0,
                    mask.width - expand_pixels - mask_blur_x if is_right else res_w,
                    mask.height - expand_pixels - mask_blur_y if is_bottom else res_h,
                ), fill="black")

                np_image = (np.asarray(img) / 255.0).astype(np.float64)
                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
                output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))

                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
                p.width = target_width if is_horiz else img.width
                p.height = target_height if is_vert else img.height

                crop_region = (
                    0 if is_left else output_images[n].width - target_width,
                    0 if is_top else output_images[n].height - target_height,
                    target_width if is_left else output_images[n].width,
                    target_height if is_top else output_images[n].height,
                )
                mask = mask.crop(crop_region)
                p.image_mask = mask

                image_to_process = output_images[n].crop(crop_region)
                images_to_process.append(image_to_process)

            p.init_images = images_to_process

            latent_mask = Image.new("RGB", (p.width, p.height), "white")
            draw = ImageDraw.Draw(latent_mask)
            draw.rectangle((
                expand_pixels + mask_blur_x * 2 if is_left else 0,
                expand_pixels + mask_blur_y * 2 if is_top else 0,
                mask.width - expand_pixels - mask_blur_x * 2 if is_right else res_w,
                mask.height - expand_pixels - mask_blur_y * 2 if is_bottom else res_h,
            ), fill="black")
            p.latent_mask = latent_mask

            proc = process_images(p)

            if initial_seed_and_info[0] is None:
                initial_seed_and_info[0] = proc.seed
                initial_seed_and_info[1] = proc.info

            for n in range(count):
                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))

            return output_images

        batch_count = p.n_iter
        batch_size = p.batch_size
        p.n_iter = 1
        state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
        all_processed_images = []

        for i in range(batch_count):
            imgs = [init_img] * batch_size
            state.job = f"Batch {i + 1} out of {batch_count}"

            if left > 0:
                imgs = expand(imgs, batch_size, left, is_left=True)
            if right > 0:
                imgs = expand(imgs, batch_size, right, is_right=True)
            if up > 0:
                imgs = expand(imgs, batch_size, up, is_top=True)
            if down > 0:
                imgs = expand(imgs, batch_size, down, is_bottom=True)

            all_processed_images += imgs

        all_images = all_processed_images

        combined_grid_image = images.image_grid(all_processed_images)
        unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
        if opts.return_grid and not unwanted_grid_because_of_img_count:
            all_images = [combined_grid_image] + all_processed_images

        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])

        if opts.samples_save:
            for img in all_processed_images:
                images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p)

        if opts.grid_save and not unwanted_grid_because_of_img_count:
            images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)

        return res
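
Note (not part of the committed file): the spectral trick inside get_matched_noise can be hard to see through the channel loops and windowing. The following is a minimal, single-channel NumPy sketch of just the shaping step, for illustration only; it omits the fftshift, the Gaussian window and the histogram matching, and the array sizes and the noise_q = 1 case are assumptions.

import numpy as np

# Illustration only: shape white noise so its magnitude spectrum follows the
# source image's spectrum while keeping the source phase, mirroring the idea
# in get_matched_noise above (without window or histogram matching).
rng = np.random.default_rng(0)
src = rng.random((64, 64))      # stand-in for the windowed source image
noise = rng.random((64, 64))    # flat white noise to be shaped

src_fft = np.fft.fft2(src, norm="ortho")
src_dist = np.abs(src_fft)            # source magnitude spectrum
src_phase = src_fft / src_dist        # unit-magnitude phase

noise_fft = np.fft.fft2(noise, norm="ortho")
shaped_fft = np.abs(noise_fft) ** 2 * src_dist * src_phase   # noise_q = 1
shaped = np.real(np.fft.ifft2(shaped_fft, norm="ortho"))

# Normalise to [0, 1] the same way the script does before histogram matching.
shaped = (shaped - shaped.min()) / (shaped.max() - shaped.min())
print(shaped.shape, float(shaped.min()), float(shaped.max()))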
poor_mans_outpainting.py ADDED
@@ -0,0 +1,146 @@
import math

import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw

from modules import images, devices
from modules.processing import Processed, process_images
from modules.shared import opts, state


class Script(scripts.Script):
    def title(self):
        return "Poor man's outpainting"

    def show(self, is_img2img):
        return is_img2img

    def ui(self, is_img2img):
        if not is_img2img:
            return None

        pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
        inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
        direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))

        return [pixels, mask_blur, inpainting_fill, direction]

    def run(self, p, pixels, mask_blur, inpainting_fill, direction):
        initial_seed = None
        initial_info = None

        p.mask_blur = mask_blur * 2
        p.inpainting_fill = inpainting_fill
        p.inpaint_full_res = False

        left = pixels if "left" in direction else 0
        right = pixels if "right" in direction else 0
        up = pixels if "up" in direction else 0
        down = pixels if "down" in direction else 0

        init_img = p.init_images[0]
        target_w = math.ceil((init_img.width + left + right) / 64) * 64
        target_h = math.ceil((init_img.height + up + down) / 64) * 64

        if left > 0:
            left = left * (target_w - init_img.width) // (left + right)
        if right > 0:
            right = target_w - init_img.width - left

        if up > 0:
            up = up * (target_h - init_img.height) // (up + down)

        if down > 0:
            down = target_h - init_img.height - up

        img = Image.new("RGB", (target_w, target_h))
        img.paste(init_img, (left, up))

        mask = Image.new("L", (img.width, img.height), "white")
        draw = ImageDraw.Draw(mask)
        draw.rectangle((
            left + (mask_blur * 2 if left > 0 else 0),
            up + (mask_blur * 2 if up > 0 else 0),
            mask.width - right - (mask_blur * 2 if right > 0 else 0),
            mask.height - down - (mask_blur * 2 if down > 0 else 0)
        ), fill="black")

        latent_mask = Image.new("L", (img.width, img.height), "white")
        latent_draw = ImageDraw.Draw(latent_mask)
        latent_draw.rectangle((
            left + (mask_blur//2 if left > 0 else 0),
            up + (mask_blur//2 if up > 0 else 0),
            mask.width - right - (mask_blur//2 if right > 0 else 0),
            mask.height - down - (mask_blur//2 if down > 0 else 0)
        ), fill="black")

        devices.torch_gc()

        grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
        grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
        grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)

        p.n_iter = 1
        p.batch_size = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        work = []
        work_mask = []
        work_latent_mask = []
        work_results = []

        for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
            for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
                x, w = tiledata[0:2]

                if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
                    continue

                work.append(tiledata[2])
                work_mask.append(tiledata_mask[2])
                work_latent_mask.append(tiledata_latent_mask[2])

        batch_count = len(work)
        print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")

        state.job_count = batch_count

        for i in range(batch_count):
            p.init_images = [work[i]]
            p.image_mask = work_mask[i]
            p.latent_mask = work_latent_mask[i]

            state.job = f"Batch {i + 1} out of {batch_count}"
            processed = process_images(p)

            if initial_seed is None:
                initial_seed = processed.seed
                initial_info = processed.info

            p.seed = processed.seed + 1
            work_results += processed.images

        image_index = 0
        for y, h, row in grid.tiles:
            for tiledata in row:
                x, w = tiledata[0:2]

                if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
                    continue

                tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
                image_index += 1

        combined_image = images.combine_grid(grid)

        if opts.samples_save:
            images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p)

        processed = Processed(p, [combined_image], initial_seed, initial_info)

        return processed
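
Note (not part of the committed files): both scripts share the same canvas arithmetic. The expanded size is rounded up to a multiple of 64, the first direction of a pair then gets its proportional share of the extra pixels via integer division, and the opposite direction absorbs the remainder. A standalone check of that arithmetic with hypothetical numbers:

import math

# Hypothetical example: a 512 px wide image expanded by 100 px on the left
# and 60 px on the right, as the target_w / left / right block above does.
init_w, left, right = 512, 100, 60

target_w = math.ceil((init_w + left + right) / 64) * 64   # 672 rounds up to 704
extra = target_w - init_w                                 # 192 px to distribute
if left > 0:
    left = left * extra // (left + right)                 # 100 * 192 // 160 = 120
if right > 0:
    right = target_w - init_w - left                      # 192 - 120 = 72

print(target_w, left, right)   # 704 120 72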