52Hz commited on
Commit
dd6626e
1 Parent(s): ec25784

Update app_predict.py

Browse files
Files changed (1) hide show
  1. app_predict.py +3 -3
app_predict.py CHANGED
@@ -21,7 +21,7 @@ model = ['swin',
21
  'vit',
22
  'convnext']
23
 
24
- def main():
25
  parser = argparse.ArgumentParser(description='Quick demo Image Classification')
26
  parser.add_argument('--input_dir', default='./test/', type=str, help='Input images root')
27
  parser.add_argument('--result_dir', default='./result/', type=str, help='Results images root')
@@ -29,7 +29,7 @@ def main():
29
  parser.add_argument('--model', default='convnext', type=str, help='Classifier')
30
 
31
  args = parser.parse_args()
32
-
33
  inp_dir = args.input_dir
34
  out_dir = args.result_dir
35
  os.makedirs(out_dir, exist_ok=True)
@@ -290,4 +290,4 @@ def build_ensemble_model(model: dict, pretrained: bool):
290
  return ensemble_model
291
 
292
  if __name__ == '__main__':
293
- main()
 
21
  'vit',
22
  'convnext']
23
 
24
+ def main(input_model=None):
25
  parser = argparse.ArgumentParser(description='Quick demo Image Classification')
26
  parser.add_argument('--input_dir', default='./test/', type=str, help='Input images root')
27
  parser.add_argument('--result_dir', default='./result/', type=str, help='Results images root')
 
29
  parser.add_argument('--model', default='convnext', type=str, help='Classifier')
30
 
31
  args = parser.parse_args()
32
+ args.model = input_model
33
  inp_dir = args.input_dir
34
  out_dir = args.result_dir
35
  os.makedirs(out_dir, exist_ok=True)
 
290
  return ensemble_model
291
 
292
  if __name__ == '__main__':
293
+ main(model=None)