ZYMPKU committed on
Commit
8841787
1 Parent(s): 5418bad
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +4 -4
  2. demo/examples/{Peaceful_0_0.jpeg → CREEK_0.jpg} +2 -2
  3. demo/examples/{FAVOURITE_0_0.jpeg → Delivery_0.jpg} +2 -2
  4. demo/examples/{FRONTIER_0_0.png → HAPPEN_0.jpg} +2 -2
  5. demo/examples/{TREE_0_0.png → WORDS_0.jpg} +2 -2
  6. demo/examples/{better_0_0.jpg → better_0.jpg} +0 -0
  7. sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc +0 -0
  8. sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc +0 -0
  9. sgm/modules/diffusionmodules/guiders.py +2 -1
  10. sgm/modules/diffusionmodules/sampling.py +8 -43
  11. temp/attn_map/attn_map_1.png +0 -0
  12. temp/attn_map/attn_map_10.png +0 -0
  13. temp/attn_map/attn_map_11.png +0 -0
  14. temp/attn_map/attn_map_12.png +0 -0
  15. temp/attn_map/attn_map_13.png +0 -0
  16. temp/attn_map/attn_map_14.png +0 -0
  17. temp/attn_map/attn_map_15.png +0 -0
  18. temp/attn_map/attn_map_16.png +0 -0
  19. temp/attn_map/attn_map_17.png +0 -0
  20. temp/attn_map/attn_map_18.png +0 -0
  21. temp/attn_map/attn_map_19.png +0 -0
  22. temp/attn_map/attn_map_2.png +0 -0
  23. temp/attn_map/attn_map_20.png +0 -0
  24. temp/attn_map/attn_map_21.png +0 -0
  25. temp/attn_map/attn_map_22.png +0 -0
  26. temp/attn_map/attn_map_23.png +0 -0
  27. temp/attn_map/attn_map_24.png +0 -0
  28. temp/attn_map/attn_map_25.png +0 -0
  29. temp/attn_map/attn_map_26.png +0 -0
  30. temp/attn_map/attn_map_27.png +0 -0
  31. temp/attn_map/attn_map_28.png +0 -0
  32. temp/attn_map/attn_map_29.png +0 -0
  33. temp/attn_map/attn_map_3.png +0 -0
  34. temp/attn_map/attn_map_4.png +0 -0
  35. temp/attn_map/attn_map_5.png +0 -0
  36. temp/attn_map/attn_map_6.png +0 -0
  37. temp/attn_map/attn_map_7.png +0 -0
  38. temp/attn_map/attn_map_8.png +0 -0
  39. temp/attn_map/attn_map_9.png +0 -0
  40. temp/seg_map/seg_1.npy +3 -0
  41. temp/seg_map/seg_1.png +0 -0
  42. temp/seg_map/seg_10.npy +3 -0
  43. temp/seg_map/seg_11.npy +3 -0
  44. temp/seg_map/seg_12.npy +3 -0
  45. temp/seg_map/seg_13.npy +3 -0
  46. temp/seg_map/seg_14.npy +3 -0
  47. temp/seg_map/seg_15.npy +3 -0
  48. temp/seg_map/seg_16.npy +3 -0
  49. temp/seg_map/seg_17.npy +3 -0
  50. temp/seg_map/seg_18.npy +3 -0
app.py CHANGED
@@ -171,7 +171,7 @@ if __name__ == "__main__":
171
  model = init_model(cfgs)
172
  sampler = init_sampling(cfgs)
173
  global_index = 0
174
- resize = Resize((cfgs.H, cfgs.W))
175
 
176
  block = gr.Blocks().queue()
177
  with block:
@@ -202,7 +202,7 @@ if __name__ == "__main__":
202
  with gr.Column():
203
 
204
  input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
205
- gr.Markdown("Notice: please draw horizontally to indicate only **one** masked area.")
206
  text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
207
  run_button = gr.Button(variant="primary")
208
 
@@ -210,9 +210,9 @@ if __name__ == "__main__":
210
 
211
  num_samples = gr.Slider(label="Images", info="number of generated images, locked as 1", minimum=1, maximum=1, value=1, step=1)
212
  steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1)
213
- scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=5.0, step=0.1)
214
  seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True)
215
- show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=False)
216
 
217
  with gr.Column():
218
 
 
171
  model = init_model(cfgs)
172
  sampler = init_sampling(cfgs)
173
  global_index = 0
174
+ resize = Resize((cfgs.H, cfgs.W), antialias=True)
175
 
176
  block = gr.Blocks().queue()
177
  with block:
 
202
  with gr.Column():
203
 
204
  input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
205
+ gr.Markdown("Notice: please draw horizontally to indicate only **one** masked area. The image may be cropped automatically into a proper scale.")
206
  text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
207
  run_button = gr.Button(variant="primary")
208
 
 
210
 
211
  num_samples = gr.Slider(label="Images", info="number of generated images, locked as 1", minimum=1, maximum=1, value=1, step=1)
212
  steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1)
213
+ scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=4.0, step=0.1)
214
  seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True)
215
+ show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=True)
216
 
217
  with gr.Column():
218
 
demo/examples/{Peaceful_0_0.jpeg → CREEK_0.jpg} RENAMED
File without changes
demo/examples/{FAVOURITE_0_0.jpeg → Delivery_0.jpg} RENAMED
File without changes
demo/examples/{FRONTIER_0_0.png → HAPPEN_0.jpg} RENAMED
File without changes
demo/examples/{TREE_0_0.png → WORDS_0.jpg} RENAMED
File without changes
demo/examples/{better_0_0.jpg → better_0.jpg} RENAMED
File without changes
sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc CHANGED
Binary files a/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc and b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc differ
 
sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc CHANGED
Binary files a/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc and b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc differ
 
sgm/modules/diffusionmodules/guiders.py CHANGED
@@ -13,6 +13,7 @@ class VanillaCFG:
13
  def __init__(self, scale, dyn_thresh_config=None):
14
  scale_schedule = lambda scale, sigma: scale # independent of step
15
  self.scale_schedule = partial(scale_schedule, scale)
 
16
  self.dyn_thresh = instantiate_from_config(
17
  default(
18
  dyn_thresh_config,
@@ -24,7 +25,7 @@ class VanillaCFG:
24
 
25
  def __call__(self, x, sigma):
26
  x_u, x_c = x.chunk(2)
27
- scale_value = self.scale_schedule(sigma)
28
  x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29
  return x_pred
30
 
 
13
  def __init__(self, scale, dyn_thresh_config=None):
14
  scale_schedule = lambda scale, sigma: scale # independent of step
15
  self.scale_schedule = partial(scale_schedule, scale)
16
+ self.scale_value = scale
17
  self.dyn_thresh = instantiate_from_config(
18
  default(
19
  dyn_thresh_config,
 
25
 
26
  def __call__(self, x, sigma):
27
  x_u, x_c = x.chunk(2)
28
+ scale_value = self.scale_value
29
  x_pred = self.dyn_thresh(x_u, x_c, scale_value)
30
  return x_pred
31
 
sgm/modules/diffusionmodules/sampling.py CHANGED
@@ -7,7 +7,6 @@ from typing import Dict, Union
7
 
8
  import imageio
9
  import torch
10
- import json
11
  import numpy as np
12
  import torch.nn.functional as F
13
  from omegaconf import ListConfig, OmegaConf
@@ -252,47 +251,15 @@ class EulerEDMSampler(EDMSampler):
252
 
253
  return x
254
 
255
- def create_pascal_label_colormap(self):
256
- """
257
- PASCAL VOC 分割数据集的类别标签颜色映射label colormap
258
 
259
- 返回:
260
- 可视化分割结果的颜色映射Colormap
261
- """
262
- colormap = np.zeros((256, 3), dtype=int)
263
- ind = np.arange(256, dtype=int)
264
-
265
- for shift in reversed(range(8)):
266
- for channel in range(3):
267
- colormap[:, channel] |= ((ind >> channel) & 1) << shift
268
- ind >>= 3
269
-
270
- return colormap
271
-
272
- def save_segment_map(self, image, attn_maps, tokens=None, save_name=None):
273
-
274
- colormap = self.create_pascal_label_colormap()
275
- H, W = image.shape[-2:]
276
-
277
- image_ = image*0.3
278
  sections = []
279
  for i in range(len(tokens)):
280
  attn_map = attn_maps[i]
281
- attn_map_t = np.tile(attn_map[None], (1,3,1,1)) # b, 3, h, w
282
- attn_map_t = torch.from_numpy(attn_map_t)
283
- attn_map_t = F.interpolate(attn_map_t, (W, H))
284
-
285
- color = torch.from_numpy(colormap[i+1][None,:,None,None] / 255.0)
286
- colored_attn_map = attn_map_t * color
287
- colored_attn_map = colored_attn_map.to(device=image_.device)
288
-
289
- image_ += colored_attn_map*0.7
290
  sections.append(attn_map)
291
 
292
  section = np.stack(sections)
293
- np.save(f"temp/seg_map/seg_{save_name}.npy", section)
294
-
295
- save_image(image_, f"temp/seg_map/seg_{save_name}.png", normalize=True)
296
 
297
  def get_init_noise(self, cfgs, model, cond, batch, uc=None):
298
 
@@ -376,8 +343,7 @@ class EulerEDMSampler(EDMSampler):
376
  local_loss = torch.zeros(1)
377
  if save_attn:
378
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
379
- denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
380
- self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
381
 
382
  d = to_d(x, sigma_hat, denoised)
383
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
@@ -410,7 +376,7 @@ class EulerEDMSampler(EDMSampler):
410
 
411
  alpha = 20 * np.sqrt(scales[i])
412
  update = aae_enabled
413
- save_loss = detailed
414
  save_attn = detailed and (i == (num_sigmas-1)//2)
415
  save_inter = aae_enabled
416
 
@@ -452,7 +418,7 @@ class EulerEDMSampler(EDMSampler):
452
  imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
453
 
454
  return x
455
-
456
 
457
  class EulerEDMDualSampler(EulerEDMSampler):
458
 
@@ -557,9 +523,8 @@ class EulerEDMDualSampler(EulerEDMSampler):
557
  else:
558
  local_loss = torch.zeros(1)
559
  if save_attn:
560
- attn_map = model.model.diffusion_model.save_attn_map(save_name=name, save_single=True)
561
- denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
562
- self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
563
 
564
  d = to_d(x, sigma_hat, denoised)
565
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
@@ -632,7 +597,7 @@ class EulerEDMDualSampler(EulerEDMSampler):
632
  print(f"Local losses: {local_losses}")
633
 
634
  if len(inters) > 0:
635
- imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.1)
636
 
637
  return x
638
 
 
7
 
8
  import imageio
9
  import torch
 
10
  import numpy as np
11
  import torch.nn.functional as F
12
  from omegaconf import ListConfig, OmegaConf
 
251
 
252
  return x
253
 
254
+ def save_segment_map(self, attn_maps, tokens=None, save_name=None):
 
 
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  sections = []
257
  for i in range(len(tokens)):
258
  attn_map = attn_maps[i]
 
 
 
 
 
 
 
 
 
259
  sections.append(attn_map)
260
 
261
  section = np.stack(sections)
262
+ np.save(f"./temp/seg_map/seg_{save_name}.npy", section)
 
 
263
 
264
  def get_init_noise(self, cfgs, model, cond, batch, uc=None):
265
 
 
343
  local_loss = torch.zeros(1)
344
  if save_attn:
345
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
346
+ self.save_segment_map(attn_map, tokens=batch["label"][0], save_name=name)
 
347
 
348
  d = to_d(x, sigma_hat, denoised)
349
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
 
376
 
377
  alpha = 20 * np.sqrt(scales[i])
378
  update = aae_enabled
379
+ save_loss = aae_enabled
380
  save_attn = detailed and (i == (num_sigmas-1)//2)
381
  save_inter = aae_enabled
382
 
 
418
  imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
419
 
420
  return x
421
+
422
 
423
  class EulerEDMDualSampler(EulerEDMSampler):
424
 
 
523
  else:
524
  local_loss = torch.zeros(1)
525
  if save_attn:
526
+ attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
527
+ self.save_segment_map(attn_map, tokens=batch["label"][0], save_name=name)
 
528
 
529
  d = to_d(x, sigma_hat, denoised)
530
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
 
597
  print(f"Local losses: {local_losses}")
598
 
599
  if len(inters) > 0:
600
+ imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
601
 
602
  return x
603
 
temp/attn_map/attn_map_1.png ADDED
temp/attn_map/attn_map_10.png ADDED
temp/attn_map/attn_map_11.png ADDED
temp/attn_map/attn_map_12.png ADDED
temp/attn_map/attn_map_13.png ADDED
temp/attn_map/attn_map_14.png ADDED
temp/attn_map/attn_map_15.png ADDED
temp/attn_map/attn_map_16.png ADDED
temp/attn_map/attn_map_17.png ADDED
temp/attn_map/attn_map_18.png ADDED
temp/attn_map/attn_map_19.png ADDED
temp/attn_map/attn_map_2.png ADDED
temp/attn_map/attn_map_20.png ADDED
temp/attn_map/attn_map_21.png ADDED
temp/attn_map/attn_map_22.png ADDED
temp/attn_map/attn_map_23.png ADDED
temp/attn_map/attn_map_24.png ADDED
temp/attn_map/attn_map_25.png ADDED
temp/attn_map/attn_map_26.png ADDED
temp/attn_map/attn_map_27.png ADDED
temp/attn_map/attn_map_28.png ADDED
temp/attn_map/attn_map_29.png ADDED
temp/attn_map/attn_map_3.png ADDED
temp/attn_map/attn_map_4.png ADDED
temp/attn_map/attn_map_5.png ADDED
temp/attn_map/attn_map_6.png ADDED
temp/attn_map/attn_map_7.png ADDED
temp/attn_map/attn_map_8.png ADDED
temp/attn_map/attn_map_9.png ADDED
temp/seg_map/seg_1.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fed32518482697a99ffa93f572123251ff1a1ce344e6c87108c8e6e5344428cb
3
+ size 32896
temp/seg_map/seg_1.png ADDED
temp/seg_map/seg_10.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63ad403c856b5eb96cb1512d84fc5b030ea3fc0a043d3a80d3eeb37472cf343e
3
+ size 24704
temp/seg_map/seg_11.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2181e1e25621b1dcb55d2e59df2b168d786021936a99e6e954f0825474414714
3
+ size 24704
temp/seg_map/seg_12.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3ddb1524d276a3c1bdb101f9d6239adde117f441205ff44c325e96f11c2cef6
3
+ size 20608
temp/seg_map/seg_13.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2f93300e573de546ae8ef84db4f49010b31df98561bd658f8da017cf277cad3
3
+ size 24704
temp/seg_map/seg_14.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bc4905adb15d7bf5b7b902f7ed553bbf465e0e875ad376e3d75ddd1109ba800
3
+ size 24704
temp/seg_map/seg_15.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4f3815fe49562a34bf0de78eacc0c0508300d35b8edbe27bbe3a58f10779747
3
+ size 24704
temp/seg_map/seg_16.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c61b774d7ebf02ba7b6bfcda062cdf600f6d39f4d499b72d63bfa4ab566cb7f6
3
+ size 24704
temp/seg_map/seg_17.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86f445d11dc69ea7c45cd409637f630abb56859cedd5204cd761f1a75aece7e9
3
+ size 24704
temp/seg_map/seg_18.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19ef25b3f6cd2d76d62c96ccaec8f83c315eadbbba8c40cafb37c9a273d9cc8f
3
+ size 24704