sswam commited on
Commit
379cfa5
1 Parent(s): ccfac5c

improved classify script

Browse files
Files changed (1) hide show
  1. classify +11 -6
classify CHANGED
@@ -1,7 +1,6 @@
1
  #!/usr/bin/env python3
2
 
3
- from fastbook import *
4
- from fastai.vision import *
5
  import argparse
6
  import sys
7
  import os
@@ -10,19 +9,23 @@ from contextlib import redirect_stdout
10
 
11
  import torch
12
 
 
 
 
13
  def main():
14
  global args, out
15
 
16
  out = sys.stdout
17
  # out = debogrify_stdout()
18
 
19
- parser = argparse.ArgumentParser(description='Classify images with a trained neural network.')
20
  parser.add_argument('-v', '--verbose', action='store_true', help='debug')
21
  parser.add_argument('--path', type=str, default='.', help='path to model')
22
  parser.add_argument('--model', type=str, default='export.pkl', help='model')
23
  parser.add_argument('--cpu', action='store_true', help='run on cpu')
24
- parser.add_argument('--batch', type=int, default=100, help='batch size')
25
  parser.add_argument('--move', type=float, help='move files with certainty above the threshold')
 
26
  # parser.add_argument('--train', action='store_true', help='train the model')
27
  # parser.add_argument('--watch', type=str, help='folder to watch for images to classify')
28
  # parser.add_argument('--folder', type=str, default='.', help='folder with images to classify (or train)')
@@ -41,7 +44,9 @@ def main():
41
  print(model_path, file=sys.stderr)
42
 
43
  learn = load_learner(model_path)
44
- if not args.cpu:
 
 
45
  try:
46
  learn.dls.to('cuda')
47
  except Exception as e:
@@ -94,7 +99,7 @@ def predict_batch(learn, batch):
94
  vocab = learn.dls.vocab
95
  with redirect_stdout(sys.stderr):
96
  # In some versions the following call writes the progress bar to stdout. We can't have that!
97
- dl = learn.dls.test_dl(batch)
98
  preds_all = learn.get_preds(dl=dl)
99
  for i in range(0, len(batch)):
100
  filename = batch[i]
 
1
  #!/usr/bin/env python3
2
 
3
+ from fastai.vision.all import *
 
4
  import argparse
5
  import sys
6
  import os
 
9
 
10
  import torch
11
 
12
+ import multiprocessing
13
+ nproc = multiprocessing.cpu_count()
14
+
15
  def main():
16
  global args, out
17
 
18
  out = sys.stdout
19
  # out = debogrify_stdout()
20
 
21
+ parser = argparse.ArgumentParser(description='Classify images with a trained neural network.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22
  parser.add_argument('-v', '--verbose', action='store_true', help='debug')
23
  parser.add_argument('--path', type=str, default='.', help='path to model')
24
  parser.add_argument('--model', type=str, default='export.pkl', help='model')
25
  parser.add_argument('--cpu', action='store_true', help='run on cpu')
26
+ parser.add_argument('--batch', type=int, default=10000, help='batch size')
27
  parser.add_argument('--move', type=float, help='move files with certainty above the threshold')
28
+ parser.add_argument('--workers', type=int, help='max number of data loader workers', default=nproc)
29
  # parser.add_argument('--train', action='store_true', help='train the model')
30
  # parser.add_argument('--watch', type=str, help='folder to watch for images to classify')
31
  # parser.add_argument('--folder', type=str, default='.', help='folder with images to classify (or train)')
 
44
  print(model_path, file=sys.stderr)
45
 
46
  learn = load_learner(model_path)
47
+ if args.cpu:
48
+ torch.set_num_threads(nproc)
49
+ else:
50
  try:
51
  learn.dls.to('cuda')
52
  except Exception as e:
 
99
  vocab = learn.dls.vocab
100
  with redirect_stdout(sys.stderr):
101
  # In some versions the following call writes the progress bar to stdout. We can't have that!
102
+ dl = learn.dls.test_dl(batch, num_workers=min(args.workers, args.batch))
103
  preds_all = learn.get_preds(dl=dl)
104
  for i in range(0, len(batch)):
105
  filename = batch[i]