LituRout committed on
Commit
bb55bee
1 Parent(s): 614cf7f
Files changed (2)
  1. app.py +546 -86
  2. requirements.txt +42 -5
app.py CHANGED
@@ -1,44 +1,381 @@
 import gradio as gr
- import numpy as np
- import torch
- from diffusers import StableDiffusionPipeline
- from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
- from diffusers import LMSDiscreteScheduler
 from share_btn import community_icon_html, loading_icon_html
 from tqdm.auto import tqdm
 from PIL import Image

- # PARAMS
- MANUAL_SEED = 42
- HEIGHT = 512
- WIDTH = 512
- ETA = 1e-1

- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
- torch_device = "cuda" if torch.cuda.is_available() else "cpu"
- pipe = pipe.to(torch_device)

- # 1. Load the autoencoder model which will be used to decode the latents into image space.
- vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

- # 2. Load the tokenizer and text encoder to tokenize and encode the text.
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
- text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

- # 3. The UNet model for generating the latents.
- unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
- scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

- vae = vae.to(torch_device)
- text_encoder = text_encoder.to(torch_device)
- unet = unet.to(torch_device)

- generator = torch.manual_seed(MANUAL_SEED)  # Seed generator to create the initial latent noise

 def read_content(file_path: str) -> str:
     """read the content of target file
@@ -54,73 +391,183 @@ def read_content(file_path: str) -> str:
 # output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
 # return output.images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

- def predict(dict, prompt=""):
-     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-
-     max_length = text_input.input_ids.shape[-1]
-     uncond_input = tokenizer(
-         [""], padding="max_length", max_length=max_length, return_tensors="pt"
-     )
-     with torch.no_grad():
-         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

     init_image = dict["image"].convert("RGB").resize((512, 512))
     mask = dict["mask"].convert("RGB").resize((512, 512))

     # convert input image to array in [-1, 1]
-     init_image = torch.tensor(2 * (np.asarray(init_image) / 255) - 1, device=torch_device)
-     mask = torch.tensor((np.asarray(mask) / 255), device=torch_device)
     # add one dimension for the batch and bring channels first
     init_image = init_image.permute(2, 0, 1).unsqueeze(0)
     mask = mask.permute(2, 0, 1).unsqueeze(0)

-     latents = torch.randn(
-         (1, unet.in_channels, HEIGHT // 8, WIDTH // 8),
-         generator=generator,
-     )
-     latents = latents.to(torch_device)
-
-     for i, t in enumerate(tqdm(scheduler.timesteps)):
-         t = scheduler.timesteps[i]
-         z_t = torch.clone(latents.detach())
-         z_t.requires_grad = True
-
-         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-         latent_model_input = scheduler.scale_model_input(z_t, t)
-
-         # predict the noise residual
-         noise_pred = unet(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample
-         # compute z_0 using Tweedie's formula
-         indx = scheduler.num_inference_steps - i - 1
-         z_0 = (1/torch.sqrt(scheduler.alphas_cumprod[indx]))\
-             *(z_t + (1-scheduler.alphas_cumprod[indx]) * noise_pred )
-
-         # pass through the decoder
-         z_0 = 1 / 0.18215 * z_0
-         image_pred = vae.decode(z_0).sample
-         # clip
-         image_pred = torch.clamp(image_pred, min=-1.0, max=1.0)
-         inpainted_image = (1 - mask) * init_image + mask * image_pred
-         error_measurement = (1/2) * torch.linalg.norm((1-mask) * (init_image - image_pred))**2
-         # TODO(giannisdaras): add LPIPS?
-         error = error_measurement
-         gradients = torch.autograd.grad(error, inputs=z_t)[0]
-         # compute the previous noisy sample x_t -> x_t-1
-         z_t_next = scheduler.step(noise_pred, t, z_t).prev_sample
-
-         latents = z_t_next - ETA * gradients

-     # scale and decode the image latents with vae
-     latents = 1 / 0.18215 * latents

-     with torch.no_grad():
-         image = vae.decode(latents).sample

-     image = (image / 2 + 0.5).clamp(0, 1)
-     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-     images = (image * 255).round().astype("uint8")
-     return images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


 css = '''
@@ -170,21 +617,34 @@ with image_blocks as demo:
     with gr.Row():
         with gr.Column():
             image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
             with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
-                 prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
                 btn = gr.Button("Inpaint!").style(
                     margin=False,
                     rounded=(False, True, True, False),
                     full_width=False,
                 )
         with gr.Column():
-             image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
             with gr.Group(elem_id="share-btn-container"):
                 community_icon = gr.HTML(community_icon_html, visible=False)
                 loading_icon = gr.HTML(loading_icon_html, visible=False)

-     btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out, community_icon, loading_icon])

- image_blocks.launch()

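For reference, a minimal self-contained sketch (not the committed code) of the measurement-guided update that the removed predict() above implements: estimate the clean latent with Tweedie's formula, decode it, penalize disagreement with the known pixels, and correct the next latent with the gradient of that penalty. noise_pred_fn, decode_fn and scheduler_step_fn are toy placeholders standing in for the UNet, the VAE decoder and the scheduler step, and the Tweedie expression below is the standard form (the removed loop writes it slightly differently).

import torch

def guided_step(z_t, alpha_bar_t, noise_pred_fn, decode_fn, scheduler_step_fn,
                init_image, mask, eta=1e-1):
    z_t = z_t.detach().requires_grad_(True)
    eps = noise_pred_fn(z_t)                         # predicted noise residual
    # Tweedie / DDIM estimate of the clean latent
    z_0 = (z_t - torch.sqrt(1.0 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
    x_0 = decode_fn(z_0).clamp(-1.0, 1.0)            # decode to image space
    # squared error on the known (unmasked) pixels, as in the removed loop
    error = 0.5 * torch.linalg.norm((1 - mask) * (init_image - x_0)) ** 2
    grad = torch.autograd.grad(error, z_t)[0]
    return scheduler_step_fn(z_t, eps) - eta * grad  # scheduler step, then gradient correction

# Toy stand-ins so the sketch runs end to end.
z = torch.randn(1, 4, 8, 8)
image = torch.rand(1, 3, 64, 64) * 2 - 1             # known image in [-1, 1]
mask = (torch.rand(1, 3, 64, 64) > 0.5).float()      # 1 = region to inpaint
decode = lambda z0: torch.nn.functional.interpolate(z0[:, :3], size=(64, 64))
z = guided_step(z, torch.tensor(0.5), lambda zt: 0.1 * zt, decode,
                lambda zt, eps: zt - 0.01 * eps, image, mask)
print(z.shape)  # torch.Size([1, 4, 8, 8])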
 import gradio as gr
 from share_btn import community_icon_html, loading_icon_html
 from tqdm.auto import tqdm
+
+ import argparse, os, sys, glob
+ import cv2
+ import torch
+ import numpy as np
+ from omegaconf import OmegaConf
 from PIL import Image
+ from tqdm import trange
+ from imwatermark import WatermarkEncoder
+ from itertools import islice
+ from einops import rearrange
+ from torchvision.utils import make_grid
+ import time
+ from pytorch_lightning import seed_everything
+ from torch import autocast
+ from contextlib import contextmanager, nullcontext
+
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.psld import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+
+ # from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from transformers import AutoFeatureExtractor
+
+ ## lr
+ import torchvision
+ import pdb
+ os.environ['CUDA_VISIBLE_DEVICES']='1'
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ ##
+
+ # load safety model
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+ # safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+ def chunk(it, size):
+     it = iter(it)
+     return iter(lambda: tuple(islice(it, size)), ())
+
+
+ def numpy_to_pil(images):
+     """
+     Convert a numpy image or a batch of images to a PIL image.
+     """
+     if images.ndim == 3:
+         images = images[None, ...]
+     images = (images * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+
+     return pil_images
+
+
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     pl_sd = torch.load(ckpt, map_location="cpu")
+     if "global_step" in pl_sd:
+         print(f"Global Step: {pl_sd['global_step']}")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+
+     model.cuda()
+     model.eval()
+     return model
+
+
+ def put_watermark(img, wm_encoder=None):
+     if wm_encoder is not None:
+         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+         img = wm_encoder.encode(img, 'dwtDct')
+         img = Image.fromarray(img[:, :, ::-1])
+     return img
+
+
+ def load_replacement(x):
+     try:
+         hwc = x.shape
+         y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
+         y = (np.array(y)/255.0).astype(x.dtype)
+         assert y.shape == x.shape
+         return y
+     except Exception:
+         return x
+
+
+ def check_safety(x_image):
+     safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+     x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+     assert x_checked_image.shape[0] == len(has_nsfw_concept)
+     for i in range(len(has_nsfw_concept)):
+         if has_nsfw_concept[i]:
+             x_checked_image[i] = load_replacement(x_checked_image[i])
+     return x_checked_image, has_nsfw_concept
+
+
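The chunk() helper above relies on the two-argument form of iter(): the lambda is called repeatedly until it returns the sentinel (an empty tuple), which happens once the underlying iterator is exhausted. A standalone illustration (the helper is restated so the snippet runs on its own):

from itertools import islice

def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

print(list(chunk(["a", "b", "c", "d", "e"], 2)))  # [('a', 'b'), ('c', 'd'), ('e',)]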
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+     "--prompt",
+     type=str,
+     nargs="?",
+     default="",
+     help="the prompt to render"
+ )
+ parser.add_argument(
+     "--outdir",
+     type=str,
+     nargs="?",
+     help="dir to write results to",
+     default="outputs/txt2img-samples"
+ )
+ parser.add_argument(
+     "--skip_grid",
+     action='store_false',
+     help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+ )
+ parser.add_argument(
+     "--skip_save",
+     action='store_true',
+     help="do not save individual samples. For speed measurements.",
+ )
+ parser.add_argument(
+     "--ddim_steps",
+     type=int,
+     default=200,
+     help="number of ddim sampling steps",
+ )
+ parser.add_argument(
+     "--plms",
+     action='store_true',
+     help="use plms sampling",
+ )
+ parser.add_argument(
+     "--dpm_solver",
+     action='store_true',
+     help="use dpm_solver sampling",
+ )
+ parser.add_argument(
+     "--laion400m",
+     action='store_true',
+     help="uses the LAION400M model",
+ )
+ parser.add_argument(
+     "--fixed_code",
+     action='store_true',
+     help="if enabled, uses the same starting code across samples ",
+ )
+ parser.add_argument(
+     "--ddim_eta",
+     type=float,
+     default=0.0,
+     help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+ )
+ parser.add_argument(
+     "--n_iter",
+     type=int,
+     default=1,
+     help="sample this often",
+ )
+ parser.add_argument(
+     "--H",
+     type=int,
+     default=512,
+     help="image height, in pixel space",
+ )
+ parser.add_argument(
+     "--W",
+     type=int,
+     default=512,
+     help="image width, in pixel space",
+ )
+ parser.add_argument(
+     "--C",
+     type=int,
+     default=4,
+     help="latent channels",
+ )
+ parser.add_argument(
+     "--f",
+     type=int,
+     default=8,
+     help="downsampling factor",
+ )
+ parser.add_argument(
+     "--n_samples",
+     type=int,
+     default=1,
+     help="how many samples to produce for each given prompt. A.k.a. batch size",
+ )
+ parser.add_argument(
+     "--n_rows",
+     type=int,
+     default=0,
+     help="rows in the grid (default: n_samples)",
+ )
+ parser.add_argument(
+     "--scale",
+     type=float,
+     default=7.5,
+     help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+ )
+ parser.add_argument(
+     "--from-file",
+     type=str,
+     help="if specified, load prompts from this file",
+ )
+ parser.add_argument(
+     "--config",
+     type=str,
+     default="configs/stable-diffusion/v1-inference.yaml",
+     help="path to config which constructs model",
+ )
+ parser.add_argument(
+     "--ckpt",
+     type=str,
+     default="models/ldm/stable-diffusion-v1/model.ckpt",
+     help="path to checkpoint of model",
+ )
+ parser.add_argument(
+     "--seed",
+     type=int,
+     default=42,
+     help="the seed (for reproducible sampling)",
+ )
+ parser.add_argument(
+     "--precision",
+     type=str,
+     help="evaluate at this precision",
+     choices=["full", "autocast"],
+     default="autocast"
+ )
+ ##
+ parser.add_argument(
+     "--dps_path",
+     type=str,
+     default='diffusion-posterior-sampling/',
+     help="DPS codebase path",
+ )
+ parser.add_argument(
+     "--task_config",
+     type=str,
+     default='configs/inpainting_config.yaml',
+     help="task config yml file",
+ )
+ parser.add_argument(
+     "--diffusion_config",
+     type=str,
+     default='configs/diffusion_config.yaml',
+     help="diffusion config yml file",
+ )
+ parser.add_argument(
+     "--model_config",
+     type=str,
+     default='configs/model_config.yaml',
+     help="model config yml file",
+ )
+ parser.add_argument(
+     "--gamma",
+     type=float,
+     default=1e-1,
+     help="inpainting error",
+ )
+ parser.add_argument(
+     "--omega",
+     type=float,
+     default=1.0,
+     help="measurement error",
+ )
+ parser.add_argument(
+     "--inpainting",
+     type=int,
+     default=1,
+     help="inpainting",
+ )
+ parser.add_argument(
+     "--general_inverse",
+     type=int,
+     default=0,
+     help="general inverse",
+ )
+ parser.add_argument(
+     "--file_id",
+     type=str,
+     default='00014.png',
+     help='input image',
+ )
+ parser.add_argument(
+     "--skip_low_res",
+     action='store_true',
+     help='downsample result to 256',
+ )
+ parser.add_argument(
+     "--ffhq256",
+     action='store_true',
+     help='load SD weights trained on FFHQ',
+ )
+ parser.add_argument(
+     "--sd_path",
+     type=str,
+     default='stable-diffusion/',
+     help="SD codebase path",
+ )
+ ##

+ opt,_ = parser.parse_known_args()
+ # pdb.set_trace()

+ if opt.laion400m:
+     print("Falling back to LAION 400M model...")
+     opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
+     opt.ckpt = "models/ldm/text2img-large/model.ckpt"
+
+ ##
+ if opt.ffhq256:
+     print("Using FFHQ 256 finetuned model...")
+     opt.config = "models/ldm/ffhq256/config.yaml"
+     opt.ckpt = "models/ldm/ffhq256/model.ckpt"
+
+ sys.path.append(opt.sd_path)
+
+ opt.outdir = opt.sd_path+opt.outdir
+ opt.config = opt.sd_path+opt.config
+ opt.ckpt = opt.sd_path+opt.ckpt
+ ##
+
+ seed_everything(opt.seed)

+ # pdb.set_trace()  # leftover debugger call; kept disabled so the app can start

+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")

+ model = model.to(device)

+ if opt.dpm_solver:
+     sampler = DPMSolverSampler(model)
+ elif opt.plms:
+     sampler = PLMSSampler(model)
+ else:
+     # pdb.set_trace()
+     sampler = DDIMSampler(model)

+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir

+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "StableDiffusionV1"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

+ batch_size = opt.n_samples
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ if not opt.from_file:
+     prompt = opt.prompt
+     assert prompt is not None
+     data = [batch_size * [prompt]]

+ else:
+     print(f"reading prompts from {opt.from_file}")
+     with open(opt.from_file, "r") as f:
+         data = f.read().splitlines()
+         data = list(chunk(data, batch_size))

+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+ grid_count = len(os.listdir(outpath)) - 1

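The module-level setup above relies on parser.parse_known_args(), which, unlike parse_args(), hands back unrecognized argv entries instead of raising; this is what lets the script tolerate extra flags injected by a hosting environment. A small illustration with a toy parser (the two flags mirror options defined above; the host flag is made up):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--ddim_steps", type=int, default=200)
p.add_argument("--gamma", type=float, default=1e-1)
opt, unknown = p.parse_known_args(["--ddim_steps", "100", "--some-host-flag", "x"])
print(opt.ddim_steps, opt.gamma, unknown)  # 100 0.1 ['--some-host-flag', 'x']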
 def read_content(file_path: str) -> str:
     """read the content of target file

 # output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
 # return output.images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

+ #########################################################
+ # Sampler
+ #########################################################

+ def predict(ddim_steps, gamma, gluing_kernel_size, gluing_kernel_sigma, omega, dict, prompt=""):
+     opt.ddim_steps = ddim_steps
+     opt.gamma = gamma
+     opt.omega = omega

+     opt.prompt = prompt
     init_image = dict["image"].convert("RGB").resize((512, 512))
+     # pdb.set_trace()
     mask = dict["mask"].convert("RGB").resize((512, 512))

     # convert input image to array in [-1, 1]
+     init_image = torch.tensor(2 * (np.asarray(init_image) / 255) - 1, device=device)
+     mask = torch.tensor((np.asarray(mask) / 255), device=device)
+
+     init_image = init_image.type(torch.float32)
+     # mask = mask.type(torch.float32)
+
     # add one dimension for the batch and bring channels first
     init_image = init_image.permute(2, 0, 1).unsqueeze(0)
     mask = mask.permute(2, 0, 1).unsqueeze(0)
+     mask[mask>=0.5] = 1.0
+     mask[mask<0.5] = 0.0
+     mask = 1-mask
+     # check whether gradio passes only the mask or the masked image as argument
+
+
+
+     #########################################################
+     ## DPS configs
+     #########################################################
+     sys.path.append(opt.dps_path)
+
+     import yaml
+     from guided_diffusion.measurements import get_noise, get_operator
+     from util.img_utils import clear_color, mask_generator
+     import torch.nn.functional as f
+     import matplotlib.pyplot as plt
+
+
+     def load_yaml(file_path: str) -> dict:
+         with open(file_path) as f:
+             config = yaml.load(f, Loader=yaml.FullLoader)
+         return config

+     model_config=opt.dps_path+opt.model_config
+     diffusion_config=opt.dps_path+opt.diffusion_config
+     task_config=opt.dps_path+opt.task_config
+
+     # pdb.set_trace()

+     # Load configurations
+     model_config = load_yaml(model_config)
+     diffusion_config = load_yaml(diffusion_config)
+     task_config = load_yaml(task_config)
+     task_config['measurement']['mask_opt']['image_size']=opt.H
+
+     # Prepare Operator and noise
+     measure_config = task_config['measurement']
+     operator = get_operator(device=device, **measure_config['operator'])
+     noiser = get_noise(**measure_config['noise'])
+
+     # Exception) In case of inpainting, we need to generate a mask
+     if measure_config['operator']['name'] == 'inpainting':
+         mask_gen = mask_generator(
+             **measure_config['mask_opt']
+         )
+     # print(init_image.shape)
+     # Exception) In case of inpainting,
+     if measure_config['operator']['name'] == 'inpainting':
+         dps_mask = mask_gen(init_image) # dps mask
+         # dps_mask = torch.ones_like(org_image) # no mask
+         dps_mask[:,0,:,:] = mask[:,0,:,:]
+         dps_mask = dps_mask[:, 0, :, :].unsqueeze(dim=0)
+         # Forward measurement model (Ax + n)
+         y = operator.forward(init_image, mask=dps_mask)
+         y_n = noiser(y)
+
+     else:
+         # Forward measurement model (Ax + n)
+         y = operator.forward(init_image)
+         y_n = noiser(y)
+         mask = None
+     #########################################################
+     # pdb.set_trace()
+     start_code = None
+     if opt.fixed_code:
+         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+
+     precision_scope = autocast if opt.precision=="autocast" else nullcontext
+     with precision_scope("cuda"):
+         with model.ema_scope():
+             uc = None
+             if opt.ffhq256:
+                 shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                 samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                  batch_size=opt.n_samples,
+                                                  shape=shape,
+                                                  verbose=False,
+                                                  eta=opt.ddim_eta,
+                                                  x_T=start_code,
+                                                  ip_mask = mask,
+                                                  measurements = y_n,
+                                                  operator = operator,
+                                                  gamma = opt.gamma,
+                                                  inpainting = opt.inpainting,
+                                                  omega = opt.omega,
+                                                  general_inverse=opt.general_inverse,
+                                                  noiser=noiser,
+                                                  ffhq256=opt.ffhq256)
+             else:
+                 # pdb.set_trace()
+                 if opt.scale != 1.0 :
+                     uc = model.get_learned_conditioning(batch_size * [""])
+                 if isinstance(opt.prompt, tuple):
+                     opt.prompt = list(opt.prompt)
+                 c = model.get_learned_conditioning(opt.prompt)
+                 shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                 samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                  conditioning=c,
+                                                  batch_size=opt.n_samples,
+                                                  shape=shape,
+                                                  verbose=False,
+                                                  unconditional_guidance_scale=opt.scale,
+                                                  unconditional_conditioning=uc,
+                                                  eta=opt.ddim_eta,
+                                                  x_T=start_code,
+                                                  ip_mask = mask,
+                                                  measurements = y_n,
+                                                  operator = operator,
+                                                  gamma = opt.gamma,
+                                                  inpainting = opt.inpainting,
+                                                  omega = opt.omega,
+                                                  general_inverse=opt.general_inverse,
+                                                  noiser=noiser)
+
+             x_samples_ddim = model.decode_first_stage(samples_ddim)
+             # pdb.set_trace()
+             # final step
+             if gluing_kernel_size > 0 and gluing_kernel_sigma > 0:
+                 blur = torchvision.transforms.GaussianBlur(gluing_kernel_size, sigma=gluing_kernel_sigma)
+                 mask = blur(mask)
+             x_samples_ddim = mask * init_image + (1-mask) * x_samples_ddim
+
+             x_samples_ddim1 = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+             x_samples_ddim1 = x_samples_ddim1.cpu().permute(0, 2, 3, 1).numpy()
+             x_checked_image_torch = torch.from_numpy(x_samples_ddim1).permute(0, 3, 1, 2)
+             x_sample1 = 255. * rearrange(x_checked_image_torch[0].cpu().numpy(), 'c h w -> h w c')
+
+
+             ## no need to enc-dec again
+             encoded_z_0 = model.encode_first_stage(x_samples_ddim.float())
+             encoded_z_0 = model.get_first_stage_encoding(encoded_z_0)
+             x_samples_ddim = model.decode_first_stage(encoded_z_0)
+
+             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+             x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+
+             # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
+             # x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+             # pdb.set_trace()
+             x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)
+
+             x_sample2 = 255. * rearrange(x_checked_image_torch[0].cpu().numpy(), 'c h w -> h w c')
+             # img = Image.fromarray(x_sample2.astype(np.uint8))
+             # img = put_watermark(img, wm_encoder)
+
+             image1 = x_sample1.astype("uint8")
+             image2 = x_sample2.astype("uint8")
+             # pdb.set_trace()
+             return image1, image2, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

  css = '''
 
     with gr.Row():
         with gr.Column():
             image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
+
+             ddim_steps = gr.Slider(minimum = 1, maximum = 1000, step = 1, label = 'Number of diffusion steps', default=200, interactive=True)
+             gamma = gr.Slider(minimum = 0, maximum = 1, step=0.01, label = 'Gluing factor', default=1e-1, interactive=True)
+             gluing_kernel_size = gr.Slider(minimum = 0, maximum = 100, step=1, label = 'Gluing kernel size', default=15, interactive=True)
+             gluing_kernel_sigma = gr.Slider(minimum = 0, maximum = 25, step=1, label = 'Gluing kernel sigma', default=7, interactive=True)
+             omega = gr.Slider(minimum = 0, maximum = 2, step=0.1, label = 'Measurement factor', default=1, interactive=True)
+
             with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
+                 prompt = gr.Textbox(placeholder = 'Your prompt (leave empty for posterior sampling)', show_label=False, elem_id="input-text")
                 btn = gr.Button("Inpaint!").style(
                     margin=False,
                     rounded=(False, True, True, False),
                     full_width=False,
                 )
+
         with gr.Column():
+             image_out1 = gr.Image(label="Output1", elem_id="output-img-1").style(height=400)
             with gr.Group(elem_id="share-btn-container"):
                 community_icon = gr.HTML(community_icon_html, visible=False)
                 loading_icon = gr.HTML(loading_icon_html, visible=False)
+
+             image_out2 = gr.Image(label="Output2", elem_id="output-img-2").style(height=400)
+             with gr.Group(elem_id="share-btn-container"):
+                 community_icon = gr.HTML(community_icon_html, visible=False)
+                 loading_icon = gr.HTML(loading_icon_html, visible=False)

+     btn.click(fn=predict, inputs=[ddim_steps, gamma, gluing_kernel_size, gluing_kernel_sigma, omega, image, prompt], outputs=[image_out1, image_out2, community_icon, loading_icon])

+ image_blocks.queue()
+ image_blocks.launch(share=True)
+ # image_blocks.launch()
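A standalone sketch (not part of the commit) of the mask handling and the final gluing composite that predict() performs, run on dummy tensors. The kernel size 15 and sigma 7 mirror the slider defaults above; every tensor here is a random placeholder rather than a real image or model output.

import torch
import torchvision

H = W = 512
init_image = torch.rand(1, 3, H, W) * 2 - 1          # image in [-1, 1], NCHW
mask = (torch.rand(1, 3, H, W) > 0.5).float()        # sketch mask from the UI, 1 = painted region

# predict() binarizes the mask and inverts it so that 1 marks the known pixels
mask[mask >= 0.5] = 1.0
mask[mask < 0.5] = 0.0
mask = 1 - mask

# final step: blur the mask and composite the diffusion sample with the original image
sample = torch.rand(1, 3, H, W) * 2 - 1              # stands in for the decoded diffusion sample
blur = torchvision.transforms.GaussianBlur(15, sigma=7)
soft_mask = blur(mask)
glued = soft_mask * init_image + (1 - soft_mask) * sample
print(glued.shape)                                   # torch.Size([1, 3, 512, 512])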
requirements.txt CHANGED
@@ -1,13 +1,50 @@
 --extra-index-url https://download.pytorch.org/whl/cu113
- torch
- torchvision
 git+https://github.com/huggingface/diffusers.git
- transformers
 ftfy
- numpy
 matplotlib
 uuid
 opencv-python
 scipy
 accelerate
- git+https://github.com/openai/CLIP.git
 --extra-index-url https://download.pytorch.org/whl/cu113
+ torch==1.11.0
+ torchvision==0.12.0
 git+https://github.com/huggingface/diffusers.git
 ftfy
+ numpy==1.19.2
 matplotlib
 uuid
 opencv-python
+ opencv-contrib-python
 scipy
 accelerate
+ git+https://github.com/openai/CLIP.git
+ certifi==2022.9.14
+ charset-normalizer==2.1.1
+ contourpy==1.0.5
+ cycler==0.11.0
+ fonttools==4.37.2
+ idna==3.4
+ kiwisolver==1.4.4
+ matplotlib==3.6.0
+ numpy==1.23.3
+ packaging==21.3
+ Pillow==9.2.0
+ pyparsing==3.0.9
+ python-dateutil==2.8.2
+ PyYAML==6.0
+ requests==2.28.1
+ scipy==1.9.1
+ six==1.16.0
+ tqdm==4.64.1
+ typing-extensions==4.3.0
+ urllib3==1.26.12
+ albumentations==0.4.3
+ diffusers
+ opencv-python==4.1.2.30
+ pudb==2019.2
+ invisible-watermark
+ imageio==2.9.0
+ imageio-ffmpeg==0.4.2
+ pytorch-lightning==1.4.2
+ omegaconf==2.1.1
+ test-tube>=0.7.5
+ streamlit>=0.73.1
+ einops==0.3.0
+ torch-fidelity==0.3.0
+ transformers==4.19.2
+ torchmetrics==0.6.0
+ kornia==0.6
+ git+https://github.com/CompVis/taming-transformers.git
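An optional sanity check (not part of the commit, assumes Python 3.8+) to confirm that a few of the pins above resolved as expected in the installed environment:

import importlib.metadata as metadata

for name, pinned in [("torch", "1.11.0"), ("pytorch-lightning", "1.4.2"), ("omegaconf", "2.1.1")]:
    try:
        print(f"{name}: installed {metadata.version(name)} (pinned {pinned})")
    except metadata.PackageNotFoundError:
        print(f"{name}: not installed")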