sunshineatnoon commited on
Commit
bff9aca
1 Parent(s): f08b5ea
Files changed (12) hide show
  1. app.py +21 -1
  2. data/masks/12003_0_label.png +0 -0
  3. tmp/0.png +0 -0
  4. tmp/1.png +0 -0
  5. tmp/2.png +0 -0
  6. tmp/3.png +0 -0
  7. tmp/4.png +0 -0
  8. tmp/5.png +0 -0
  9. tmp/6.png +0 -0
  10. tmp/7.png +0 -0
  11. tmp/8.png +0 -0
  12. tmp/9.png +0 -0
app.py CHANGED
@@ -230,8 +230,28 @@ class Tester(TesterBase):
230
  key=2
231
  )
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  st.markdown('<p class="big-font">Choose the texture segment for each group in the given mask below.</p>', unsafe_allow_html=True)
234
- given_mask = Image.open('data/masks/124084_0_label.png').convert("L")
235
  given_mask = np.asarray(given_mask)
236
  given_mask = torch.from_numpy(given_mask)
237
  H, W = given_mask.shape[0], given_mask.shape[1]
 
230
  key=2
231
  )
232
 
233
+ st.markdown('<p class="big-font">Choose one mask for texture editing.</p>', unsafe_allow_html=True)
234
+ mask_list = glob(os.path.join("data/masks/*.png"))
235
+ byte_mask_list = []
236
+ for img_path in mask_list:
237
+ seg = Image.open(img_path).convert("L")
238
+ seg = np.asarray(seg)
239
+ seg = torch.from_numpy(seg).view(1, 1, seg.shape[0], seg.shape[1])
240
+ color_vq = self.draw_color_seg(seg)
241
+ vutils.save_image(color_vq, 'tmp/tmp.png')
242
+ with open('tmp/tmp.png', "rb") as image:
243
+ encoded = base64.b64encode(image.read()).decode()
244
+ byte_mask_list.append(f"data:image/jpeg;base64,{encoded}")
245
+ img_idx = clickable_images(
246
+ byte_mask_list,
247
+ titles=[f"Group #{str(i)}" for i in range(len(byte_mask_list))],
248
+ div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
249
+ img_style={"margin": "5px", "height": "150px"},
250
+ )
251
+ mask_path = mask_list[img_idx]
252
+
253
  st.markdown('<p class="big-font">Choose the texture segment for each group in the given mask below.</p>', unsafe_allow_html=True)
254
+ given_mask = Image.open(mask_path).convert("L")
255
  given_mask = np.asarray(given_mask)
256
  given_mask = torch.from_numpy(given_mask)
257
  H, W = given_mask.shape[0], given_mask.shape[1]
data/masks/12003_0_label.png ADDED
tmp/0.png DELETED
Binary file (3.55 kB)
 
tmp/1.png DELETED
Binary file (5.52 kB)
 
tmp/2.png DELETED
Binary file (14.7 kB)
 
tmp/3.png DELETED
Binary file (10.3 kB)
 
tmp/4.png DELETED
Binary file (15.7 kB)
 
tmp/5.png DELETED
Binary file (20.3 kB)
 
tmp/6.png DELETED
Binary file (4.79 kB)
 
tmp/7.png DELETED
Binary file (20.2 kB)
 
tmp/8.png DELETED
Binary file (9.19 kB)
 
tmp/9.png DELETED
Binary file (25.9 kB)