liuyizhang commited on
Commit
2e4e1c8
1 Parent(s): 81fed1b

update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -209,7 +209,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
209
 
210
  size = image_pil.size
211
 
212
- if task_type == 'seg' or task_type == 'inpainting':
213
  # initialize SAM
214
  predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
215
  image = np.array(image_path)
@@ -233,7 +233,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
233
 
234
  # masks: [1, 1, 512, 512]
235
 
236
- if task_type == 'det':
237
  pred_dict = {
238
  "boxes": boxes_filt,
239
  "size": [size[1], size[0]], # H,W
@@ -245,7 +245,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
245
  image_with_box.save(image_path)
246
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
247
  return image_result
248
- elif task_type == 'seg':
249
  assert sam_checkpoint, 'sam_checkpoint is not found!'
250
 
251
  # draw output image
@@ -302,8 +302,8 @@ if __name__ == "__main__":
302
  with gr.Column():
303
  input_image = gr.Image(source='upload', type="pil")
304
  task_type = gr.Radio(["detection", "segment", "inpainting"], value="detection",
305
- label='Task type:',interactive=True, visible=True)
306
- text_prompt = gr.Textbox(label="Detection Prompt")
307
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt", visible=True)
308
  run_button = gr.Button(label="Run")
309
  with gr.Accordion("Advanced options", open=False):
 
209
 
210
  size = image_pil.size
211
 
212
+ if task_type == 'segment' or task_type == 'inpainting':
213
  # initialize SAM
214
  predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
215
  image = np.array(image_path)
 
233
 
234
  # masks: [1, 1, 512, 512]
235
 
236
+ if task_type == 'detection':
237
  pred_dict = {
238
  "boxes": boxes_filt,
239
  "size": [size[1], size[0]], # H,W
 
245
  image_with_box.save(image_path)
246
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
247
  return image_result
248
+ elif task_type == 'segment':
249
  assert sam_checkpoint, 'sam_checkpoint is not found!'
250
 
251
  # draw output image
 
302
  with gr.Column():
303
  input_image = gr.Image(source='upload', type="pil")
304
  task_type = gr.Radio(["detection", "segment", "inpainting"], value="detection",
305
+ label='Task type',interactive=True, visible=True)
306
+ text_prompt = gr.Textbox(label="Detection Prompt", placeholder="Cannot be empty")
307
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt", visible=True)
308
  run_button = gr.Button(label="Run")
309
  with gr.Accordion("Advanced options", open=False):