stevengrove commited on
Commit
2ed4a37
1 Parent(s): e789019

Update tools/demo.py

Browse files
Files changed (1) hide show
  1. tools/demo.py +5 -6
tools/demo.py CHANGED
@@ -116,10 +116,9 @@ def export_model(runner,
116
  # dry run
117
  deploy_model(fake_input)
118
 
119
- os.makedirs(args.work_dir, exist_ok=True)
120
  save_onnx_path = os.path.join(
121
- args.work_dir,
122
- os.path.basename(args.checkpoint).replace('pth', 'onnx'))
123
  # export onnx
124
  with BytesIO() as f:
125
  output_names = ['num_dets', 'boxes', 'scores', 'labels']
@@ -142,7 +141,7 @@ def export_model(runner,
142
  return gr.update(visible=True), save_onnx_path
143
 
144
 
145
- def demo(runner, args):
146
  with gr.Blocks(title="YOLO-World") as demo:
147
  with gr.Row():
148
  gr.Markdown('<h1><center>YOLO-World: Real-Time Open-Vocabulary '
@@ -195,7 +194,7 @@ def demo(runner, args):
195
  [output_image])
196
  clear.click(lambda: [[], '', ''], None,
197
  [image, input_text, output_image])
198
- export.click(partial(export_model, runner, args.checkpoint),
199
  [input_text, max_num_boxes, score_thr, nms_thr],
200
  [out_download, out_download])
201
  demo.launch(server_name='0.0.0.0')
@@ -228,4 +227,4 @@ if __name__ == '__main__':
228
  pipeline = cfg.test_dataloader.dataset.pipeline
229
  runner.pipeline = Compose(pipeline)
230
  runner.model.eval()
231
- demo(runner, args)
 
116
  # dry run
117
  deploy_model(fake_input)
118
 
119
+ os.makedirs('work_dirs', exist_ok=True)
120
  save_onnx_path = os.path.join(
121
+ 'work_dirs', 'yolow-l.onnx')
 
122
  # export onnx
123
  with BytesIO() as f:
124
  output_names = ['num_dets', 'boxes', 'scores', 'labels']
 
141
  return gr.update(visible=True), save_onnx_path
142
 
143
 
144
+ def demo(runner, args, cfg):
145
  with gr.Blocks(title="YOLO-World") as demo:
146
  with gr.Row():
147
  gr.Markdown('<h1><center>YOLO-World: Real-Time Open-Vocabulary '
 
194
  [output_image])
195
  clear.click(lambda: [[], '', ''], None,
196
  [image, input_text, output_image])
197
+ export.click(partial(export_model, runner, cfg.checkpoint),
198
  [input_text, max_num_boxes, score_thr, nms_thr],
199
  [out_download, out_download])
200
  demo.launch(server_name='0.0.0.0')
 
227
  pipeline = cfg.test_dataloader.dataset.pipeline
228
  runner.pipeline = Compose(pipeline)
229
  runner.model.eval()
230
+ demo(runner, args, cfg)