afeng commited on
Commit
a900192
1 Parent(s): f4f90db
Files changed (3) hide show
  1. app copy 2.py +385 -0
  2. app.py +53 -26
  3. segment.py +3 -2
app copy 2.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import copy
4
+ from PIL import Image
5
+ import matplotlib
6
+ import numpy as np
7
+ import gradio as gr
8
+ from utils import load_mask, load_mask_edit
9
+ from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
10
+ from pathlib import Path
11
+ import subprocess
12
+ from PIL import Image
13
+ from functools import partial
14
+ from main import run_main
15
+ LENGTH=512 #length of the square area displaying/editing images
16
+ TRANSPARENCY = 150 # transparency of the mask in display
17
+
18
+ def add_mask(mask_np_list_updated, mask_label_list):
19
+ mask_new = np.zeros_like(mask_np_list_updated[0])
20
+ mask_np_list_updated.append(mask_new)
21
+ mask_label_list.append("new")
22
+ return mask_np_list_updated, mask_label_list
23
+
24
+ def create_segmentation(mask_np_list):
25
+ viridis = matplotlib.pyplot.get_cmap(name = 'viridis', lut = len(mask_np_list))
26
+ segmentation = 0
27
+ for i, m in enumerate(mask_np_list):
28
+ color = matplotlib.colors.to_rgb(viridis(i))
29
+ color_mat = np.ones_like(m)
30
+ color_mat = np.stack([color_mat*color[0], color_mat*color[1],color_mat*color[2] ], axis = 2)
31
+ color_mat = color_mat * m[:,:,np.newaxis]
32
+ segmentation += color_mat
33
+ segmentation = Image.fromarray(np.uint8(segmentation*255))
34
+ return segmentation
35
+
36
+ def load_mask_ui(input_folder="example_tmp",load_edit = False):
37
+ if not load_edit:
38
+ mask_list, mask_label_list = load_mask(input_folder)
39
+ else:
40
+ mask_list, mask_label_list = load_mask_edit(input_folder)
41
+
42
+ mask_np_list = []
43
+ for m in mask_list:
44
+ mask_np_list. append( m.cpu().numpy())
45
+
46
+ return mask_np_list, mask_label_list
47
+
48
+ def load_image_ui(load_edit, input_folder="example_tmp"):
49
+ try:
50
+ for img_path in Path(input_folder).iterdir():
51
+ if img_path.name in ["img_512.png"]:
52
+ image = Image.open(img_path)
53
+ mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
54
+ image = image.convert('RGB')
55
+ segmentation = create_segmentation(mask_np_list)
56
+ print("!!", len(mask_np_list))
57
+ return image, segmentation, mask_np_list, mask_label_list, image
58
+ except:
59
+ print("Image folder invalid: The folder should contain image.png")
60
+ return None, None, None, None, None
61
+
62
+ def run_edit_text(
63
+ num_tokens,
64
+ num_sampling_steps,
65
+ strength,
66
+ edge_thickness,
67
+ tgt_prompt,
68
+ tgt_idx,
69
+ guidance_scale,
70
+ input_folder="example_tmp"
71
+ ):
72
+ subprocess.run(["python",
73
+ "main.py" ,
74
+ "--text",
75
+ "--name={}".format(input_folder),
76
+ "--dpm={}".format("sd"),
77
+ "--resolution={}".format(512),
78
+ "--load_trained",
79
+ "--num_tokens={}".format(num_tokens),
80
+ "--seed={}".format(2024),
81
+ "--guidance_scale={}".format(guidance_scale),
82
+ "--num_sampling_step={}".format(num_sampling_steps),
83
+ "--strength={}".format(strength),
84
+ "--edge_thickness={}".format(edge_thickness),
85
+ "--num_imgs={}".format(2),
86
+ "--tgt_prompt={}".format(tgt_prompt) ,
87
+ "--tgt_index={}".format(tgt_idx)
88
+ ])
89
+
90
+ return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
91
+
92
+
93
+ def run_optimization(
94
+ num_tokens,
95
+ embedding_learning_rate,
96
+ max_emb_train_steps,
97
+ diffusion_model_learning_rate,
98
+ max_diffusion_train_steps,
99
+ train_batch_size,
100
+ gradient_accumulation_steps,
101
+ input_folder = "example_tmp"
102
+ ):
103
+ subprocess.run(["python",
104
+ "main.py" ,
105
+ "--name={}".format(input_folder),
106
+ "--dpm={}".format("sd"),
107
+ "--resolution={}".format(512),
108
+ "--num_tokens={}".format(num_tokens),
109
+ "--embedding_learning_rate={}".format(embedding_learning_rate),
110
+ "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
111
+ "--max_emb_train_steps={}".format(max_emb_train_steps),
112
+ "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
113
+ "--train_batch_size={}".format(train_batch_size),
114
+ "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
115
+
116
+ ])
117
+ return
118
+
119
+
120
+ def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
121
+ backimg_solid_np = np.array(backimg)
122
+ bimg = backimg.copy()
123
+ fimg = foreimg.copy()
124
+ fimg.putalpha(transparency)
125
+ bimg.paste(fimg, (0,0), fimg)
126
+
127
+ bimg_np = np.array(bimg)
128
+ mask_np = mask_np[:,:,np.newaxis]
129
+
130
+ try:
131
+ new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
132
+ return Image.fromarray(new_img_np)
133
+ except:
134
+ import pdb; pdb.set_trace()
135
+
136
+ def show_segmentation(image, segmentation, flag):
137
+ if flag is False:
138
+ flag = True
139
+ mask_np = np.ones([image.size[0],image.size[1]]).astype(np.uint8)
140
+ image_edit = transparent_paste_with_mask(image, segmentation, mask_np ,transparency = TRANSPARENCY)
141
+ return image_edit, flag
142
+ else:
143
+ flag = False
144
+ return image,flag
145
+
146
+ def edit_mask_add(canvas, image, idx, mask_np_list):
147
+ mask_sel = mask_np_list[idx]
148
+ mask_new = np.uint8(canvas["mask"][:, :, 0]/ 255.)
149
+ mask_np_list_updated = []
150
+ for midx, m in enumerate(mask_np_list):
151
+ if midx == idx:
152
+ mask_np_list_updated.append(mask_union(mask_sel, mask_new))
153
+ else:
154
+ mask_np_list_updated.append(m)
155
+
156
+ priority_list = [0 for _ in range(len(mask_np_list_updated))]
157
+ priority_list[idx] = 1
158
+ mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
159
+ mask_ones = np.ones([mask_sel.shape[0], mask_sel.shape[1]]).astype(np.uint8)
160
+ segmentation = create_segmentation(mask_np_list_updated)
161
+ image_edit = transparent_paste_with_mask(image, segmentation, mask_ones ,transparency = TRANSPARENCY)
162
+ return mask_np_list_updated, image_edit
163
+
164
+ def slider_release(index, image, mask_np_list_updated, mask_label_list):
165
+
166
+ if index > len(mask_np_list_updated):
167
+ return image, "out of range"
168
+ else:
169
+ mask_np = mask_np_list_updated[index]
170
+ mask_label = mask_label_list[index]
171
+ segmentation = create_segmentation(mask_np_list_updated)
172
+ new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
173
+ return new_image, mask_label
174
+
175
+ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
176
+ try:
177
+ assert np.all(sum(mask_np_list_updated)==1)
178
+ except:
179
+ print("please check mask")
180
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
181
+ import pdb; pdb.set_trace()
182
+
183
+ for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
184
+ # np.save(os.path.join(input_folder, "maskEDIT{}_{}.npy".format(midx, mask_label)),mask )
185
+ np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)),mask )
186
+ savepath = os.path.join(input_folder, "seg_current.png")
187
+ visualize_mask_list_clean(mask_np_list_updated, savepath)
188
+
189
+ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
190
+ try:
191
+ assert np.all(sum(mask_np_list_updated)==1)
192
+ except:
193
+ print("please check mask")
194
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
195
+ import pdb; pdb.set_trace()
196
+ for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
197
+ np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
198
+ savepath = os.path.join(input_folder, "seg_edited.png")
199
+ visualize_mask_list_clean(mask_np_list_updated, savepath)
200
+
201
+
202
+ import shutil
203
+ if os.path.isdir("./example_tmp"):
204
+ shutil.rmtree("./example_tmp")
205
+
206
+ from segment import run_segmentation
207
+ with gr.Blocks() as demo:
208
+ image = gr.State() # store mask
209
+ image_loaded = gr.State()
210
+ segmentation = gr.State()
211
+
212
+ mask_np_list = gr.State([])
213
+ mask_label_list = gr.State([])
214
+ mask_np_list_updated = gr.State([])
215
+ true = gr.State(True)
216
+ false = gr.State(False)
217
+
218
+ with gr.Row():
219
+ gr.Markdown("""# D-Edit""")
220
+
221
+ with gr.Tab(label="1 Edit mask"):
222
+ with gr.Row():
223
+ with gr.Column():
224
+ canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
225
+
226
+ segment_button = gr.Button("1.1 Run segmentation")
227
+ segment_button.click(run_segmentation,
228
+ [canvas] ,
229
+ [] )
230
+
231
+ text_button = gr.Button("1.2 Load original masks")
232
+ text_button.click(load_image_ui,
233
+ [ false] ,
234
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
+
236
+ load_edit_button = gr.Button("1.2 Load edited masks")
237
+ load_edit_button.click(load_image_ui,
238
+ [ true] ,
239
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
+
241
+ show_segment = gr.Checkbox(label = "Show Segmentation")
242
+ flag = gr.State(False)
243
+ show_segment.select(show_segmentation,
244
+ [image_loaded, segmentation, flag],
245
+ [canvas, flag])
246
+
247
+ # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
248
+ mask_np_list_updated = mask_np_list
249
+ with gr.Column():
250
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
251
+ slider = gr.Slider(0, 20, step=1, interactive=True)
252
+ label = gr.Textbox()
253
+ slider.release(slider_release,
254
+ inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
255
+ outputs= [canvas, label]
256
+ )
257
+ add_button = gr.Button("Add")
258
+ add_button.click( edit_mask_add,
259
+ [canvas, image_loaded, slider, mask_np_list_updated] ,
260
+ [mask_np_list_updated, canvas]
261
+ )
262
+
263
+ save_button2 = gr.Button("Set and Save as edited masks")
264
+ save_button2.click( save_as_edit_mask,
265
+ [mask_np_list_updated, mask_label_list] ,
266
+ [] )
267
+
268
+ save_button = gr.Button("Set and Save as original masks")
269
+ save_button.click( save_as_orig_mask,
270
+ [mask_np_list_updated, mask_label_list] ,
271
+ [] )
272
+
273
+ back_button = gr.Button("Back to current seg")
274
+ back_button.click( load_mask_ui,
275
+ [] ,
276
+ [ mask_np_list_updated,mask_label_list] )
277
+
278
+ add_mask_button = gr.Button("Add new empty mask")
279
+ add_mask_button.click(add_mask,
280
+ [mask_np_list_updated, mask_label_list] ,
281
+ [mask_np_list_updated, mask_label_list] )
282
+
283
+ with gr.Tab(label="2 Optimization"):
284
+ with gr.Row():
285
+
286
+ with gr.Column():
287
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
288
+ num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
289
+ embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
290
+ max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
291
+
292
+ diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
293
+ max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )
294
+
295
+ train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
296
+ gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
297
+
298
+ add_button = gr.Button("Run optimization")
299
+ def run_optimization_wrapper (
300
+ num_tokens,
301
+ embedding_learning_rate ,
302
+ max_emb_train_steps ,
303
+ diffusion_model_learning_rate ,
304
+ max_diffusion_train_steps,
305
+ train_batch_size,
306
+ gradient_accumulation_steps
307
+ ):
308
+ run_optimization = partial(
309
+ run_main,
310
+ num_tokens=int(num_tokens),
311
+ embedding_learning_rate = float(embedding_learning_rate),
312
+ max_emb_train_steps = int(max_emb_train_steps),
313
+ diffusion_model_learning_rate= float(diffusion_model_learning_rate),
314
+ max_diffusion_train_steps = int(max_diffusion_train_steps),
315
+ train_batch_size=int(train_batch_size),
316
+ gradient_accumulation_steps=int(gradient_accumulation_steps)
317
+ )
318
+ run_optimization()
319
+
320
+ add_button.click(run_optimization_wrapper,
321
+ inputs = [
322
+ num_tokens,
323
+ embedding_learning_rate ,
324
+ max_emb_train_steps ,
325
+ diffusion_model_learning_rate ,
326
+ max_diffusion_train_steps,
327
+ train_batch_size,
328
+ gradient_accumulation_steps
329
+ ],
330
+ outputs = []
331
+ )
332
+
333
+
334
+ with gr.Tab(label="3 Editing"):
335
+ with gr.Tab(label="3.1 Text-based editing"):
336
+
337
+ with gr.Row():
338
+ with gr.Column():
339
+ canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)
340
+ # canvas_text_edit = gr.Gallery(label = "Edited results")
341
+
342
+ with gr.Column():
343
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
344
+
345
+ tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
346
+ tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
347
+ guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
348
+ num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
349
+ edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
350
+ strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
351
+
352
+ add_button = gr.Button("Run Editing")
353
+ run_edit_text = partial(
354
+ run_main,
355
+ load_trained=True,
356
+ text=True,
357
+ num_tokens = int(num_tokens.value),
358
+ guidance_scale = float(guidance_scale.value),
359
+ num_sampling_steps = int(num_sampling_steps.value),
360
+ strength = float(strength.value),
361
+ edge_thickness = int(edge_thickness.value),
362
+ num_imgs = 1,
363
+ tgt_prompt = tgt_prompt.value,
364
+ tgt_index = int(tgt_index.value)
365
+ )
366
+
367
+ add_button.click(run_edit_text,
368
+ inputs = [],
369
+ outputs = [canvas_text_edit]
370
+ )
371
+
372
+ def load_pil_img():
373
+ from PIL import Image
374
+ return Image.open("example_tmp/text/out_text_0.png")
375
+
376
+ load_button = gr.Button("Load results")
377
+ load_button.click(load_pil_img,
378
+ inputs = [],
379
+ outputs = [canvas_text_edit]
380
+ )
381
+
382
+
383
+
384
+
385
+ demo.queue().launch(share=True, debug=True)
app.py CHANGED
@@ -214,7 +214,7 @@ with gr.Blocks() as demo:
214
  mask_np_list_updated = gr.State([])
215
  true = gr.State(True)
216
  false = gr.State(False)
217
-
218
  with gr.Row():
219
  gr.Markdown("""# D-Edit""")
220
 
@@ -225,29 +225,33 @@ with gr.Blocks() as demo:
225
 
226
  segment_button = gr.Button("1.1 Run segmentation")
227
  segment_button.click(run_segmentation,
228
- [canvas] ,
229
- [] )
230
-
231
- text_button = gr.Button("1.2 Load original masks")
232
  text_button.click(load_image_ui,
233
  [ false] ,
234
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
 
236
- load_edit_button = gr.Button("1.2 Load edited masks")
237
  load_edit_button.click(load_image_ui,
238
  [ true] ,
239
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
 
241
- show_segment = gr.Checkbox(label = "Show Segmentation")
242
  flag = gr.State(False)
243
  show_segment.select(show_segmentation,
244
  [image_loaded, segmentation, flag],
245
  [canvas, flag])
246
-
 
 
 
 
247
  # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
248
  mask_np_list_updated = mask_np_list
249
  with gr.Column():
250
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
251
  slider = gr.Slider(0, 20, step=1, interactive=True)
252
  label = gr.Textbox()
253
  slider.release(slider_release,
@@ -282,8 +286,11 @@ with gr.Blocks() as demo:
282
 
283
  with gr.Tab(label="2 Optimization"):
284
  with gr.Row():
285
-
286
  with gr.Column():
 
 
 
 
287
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
288
  num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
289
  embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
@@ -296,7 +303,8 @@ with gr.Blocks() as demo:
296
  gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
297
 
298
  add_button = gr.Button("Run optimization")
299
- def run_optimization_wrapper (
 
300
  num_tokens,
301
  embedding_learning_rate ,
302
  max_emb_train_steps ,
@@ -316,9 +324,11 @@ with gr.Blocks() as demo:
316
  gradient_accumulation_steps=int(gradient_accumulation_steps)
317
  )
318
  run_optimization()
 
319
 
320
  add_button.click(run_optimization_wrapper,
321
  inputs = [
 
322
  num_tokens,
323
  embedding_learning_rate ,
324
  max_emb_train_steps ,
@@ -327,9 +337,15 @@ with gr.Blocks() as demo:
327
  train_batch_size,
328
  gradient_accumulation_steps
329
  ],
330
- outputs = []
331
  )
332
-
 
 
 
 
 
 
333
 
334
  with gr.Tab(label="3 Editing"):
335
  with gr.Tab(label="3.1 Text-based editing"):
@@ -350,19 +366,30 @@ with gr.Blocks() as demo:
350
  strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
351
 
352
  add_button = gr.Button("Run Editing")
353
- run_edit_text = partial(
354
- run_main,
355
- load_trained=True,
356
- text=True,
357
- num_tokens = int(num_tokens.value),
358
- guidance_scale = float(guidance_scale.value),
359
- num_sampling_steps = int(num_sampling_steps.value),
360
- strength = float(strength.value),
361
- edge_thickness = int(edge_thickness.value),
362
- num_imgs = 1,
363
- tgt_prompt = tgt_prompt.value,
364
- tgt_index = int(tgt_index.value)
365
- )
 
 
 
 
 
 
 
 
 
 
 
366
 
367
  add_button.click(run_edit_text,
368
  inputs = [],
 
214
  mask_np_list_updated = gr.State([])
215
  true = gr.State(True)
216
  false = gr.State(False)
217
+ block_flag = gr.State(0)
218
  with gr.Row():
219
  gr.Markdown("""# D-Edit""")
220
 
 
225
 
226
  segment_button = gr.Button("1.1 Run segmentation")
227
  segment_button.click(run_segmentation,
228
+ [canvas, block_flag] ,
229
+ [block_flag] )
230
+
231
+ text_button = gr.Button("Waiting 1.1 to complete")
232
  text_button.click(load_image_ui,
233
  [ false] ,
234
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
 
236
+ load_edit_button = gr.Button("Waiting 1.1 to complete")
237
  load_edit_button.click(load_image_ui,
238
  [ true] ,
239
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
 
241
+ show_segment = gr.Checkbox(label = "Waiting 1.1 to complete")
242
  flag = gr.State(False)
243
  show_segment.select(show_segmentation,
244
  [image_loaded, segmentation, flag],
245
  [canvas, flag])
246
+ def show_more_buttons():
247
+ return gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks") , gr.Checkbox(label = "Show Segmentation")
248
+ block_flag.change(show_more_buttons, [], [text_button,load_edit_button,show_segment ])
249
+
250
+
251
  # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
252
  mask_np_list_updated = mask_np_list
253
  with gr.Column():
254
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
255
  slider = gr.Slider(0, 20, step=1, interactive=True)
256
  label = gr.Textbox()
257
  slider.release(slider_release,
 
286
 
287
  with gr.Tab(label="2 Optimization"):
288
  with gr.Row():
 
289
  with gr.Column():
290
+
291
+ txt_box = gr.Textbox("Click to start optimization...", interactive = False)
292
+
293
+ opt_flag = gr.State(0)
294
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
295
  num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
296
  embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
 
303
  gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
304
 
305
  add_button = gr.Button("Run optimization")
306
+ def run_optimization_wrapper (
307
+ opt_flag,
308
  num_tokens,
309
  embedding_learning_rate ,
310
  max_emb_train_steps ,
 
324
  gradient_accumulation_steps=int(gradient_accumulation_steps)
325
  )
326
  run_optimization()
327
+ return opt_flag+1
328
 
329
  add_button.click(run_optimization_wrapper,
330
  inputs = [
331
+ opt_flag,
332
  num_tokens,
333
  embedding_learning_rate ,
334
  max_emb_train_steps ,
 
337
  train_batch_size,
338
  gradient_accumulation_steps
339
  ],
340
+ outputs = [opt_flag]
341
  )
342
+
343
+ def change_text(txt_box):
344
+ return gr.Textbox("Optimization Finished!", interactive = False)
345
+ def change_text2(txt_box):
346
+ return gr.Textbox("Start optimization, check logs for progress...", interactive = False)
347
+ add_button.click(change_text2, txt_box, txt_box)
348
+ opt_flag.change(change_text, txt_box, txt_box)
349
 
350
  with gr.Tab(label="3 Editing"):
351
  with gr.Tab(label="3.1 Text-based editing"):
 
366
  strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
367
 
368
  add_button = gr.Button("Run Editing")
369
+ def run_edit_text_wrapper(
370
+ num_tokens,
371
+ guidance_scale,
372
+ num_sampling_steps ,
373
+ strength ,
374
+ edge_thickness,
375
+ tgt_prompt ,
376
+ tgt_index
377
+ ):
378
+
379
+ run_edit_text = partial(
380
+ run_main,
381
+ load_trained=True,
382
+ text=True,
383
+ num_tokens = int(num_tokens),
384
+ guidance_scale = float(guidance_scale),
385
+ num_sampling_steps = int(num_sampling_steps),
386
+ strength = float(strength),
387
+ edge_thickness = int(edge_thickness),
388
+ num_imgs = 1,
389
+ tgt_prompt = tgt_prompt,
390
+ tgt_index = int(tgt_index)
391
+ )
392
+ return run_edit_text()
393
 
394
  add_button.click(run_edit_text,
395
  inputs = [],
segment.py CHANGED
@@ -89,7 +89,7 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
89
 
90
 
91
 
92
- def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
93
 
94
  base_folder_path = "."
95
 
@@ -115,4 +115,5 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
115
  os.makedirs(save_folder, exist_ok=True)
116
  draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
117
  print("Finish segment")
118
- return
 
 
89
 
90
 
91
 
92
+ def run_segmentation(image, block_flag, name="example_tmp", size = 512, noseg=False):
93
 
94
  base_folder_path = "."
95
 
 
115
  os.makedirs(save_folder, exist_ok=True)
116
  draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
117
  print("Finish segment")
118
+ block_flag += 1
119
+ return block_flag