sswam commited on
Commit
6fa34a8
1 Parent(s): 42318b3

add notebook and my classify script

Browse files
Files changed (2) hide show
  1. classify +113 -0
  2. photo-checker.ipynb +0 -0
classify ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ from fastbook import *
4
+ from fastai.vision import *
5
+ import argparse
6
+ import sys
7
+ import os
8
+ from contextlib import redirect_stdout
9
+ #from PIL import Image
10
+
11
+ import torch
12
+
13
+ def main():
14
+ global args, out
15
+
16
+ out = sys.stdout
17
+ # out = debogrify_stdout()
18
+
19
+ parser = argparse.ArgumentParser(description='Classify images with a trained neural network.')
20
+ parser.add_argument('-v', '--verbose', action='store_true', help='debug')
21
+ parser.add_argument('--path', type=str, default='.', help='path to model')
22
+ parser.add_argument('--model', type=str, default='export.pkl', help='model')
23
+ parser.add_argument('--cpu', action='store_true', help='run on cpu')
24
+ parser.add_argument('--batch', type=int, default=100, help='batch size')
25
+ parser.add_argument('--move', type=float, help='move files with certainty above the threshold')
26
+ # parser.add_argument('--train', action='store_true', help='train the model')
27
+ # parser.add_argument('--watch', type=str, help='folder to watch for images to classify')
28
+ # parser.add_argument('--folder', type=str, default='.', help='folder with images to classify (or train)')
29
+ # parser.add_argument('--certain', type=float, default=0.9, help='minimum certainty to classify image')
30
+
31
+ args = parser.parse_args()
32
+
33
+ if args.verbose:
34
+ print(torch.cuda.get_device_name(0), file=sys.stderr)
35
+
36
+ model_path = Path(args.path)/Path(args.model)
37
+ if not model_path.exists():
38
+ model_path = Path.home()/Path(args.model)
39
+
40
+ if args.verbose:
41
+ print(model_path, file=sys.stderr)
42
+
43
+ learn = load_learner(model_path)
44
+ if not args.cpu:
45
+ try:
46
+ learn.dls.to('cuda')
47
+ except Exception as e:
48
+ print(e, file=sys.stderr)
49
+
50
+ if args.verbose:
51
+ print(f'{learn.dls.device=}', file=sys.stderr)
52
+
53
+ if args.batch == 1:
54
+ predict_one_at_time(learn)
55
+ else:
56
+ predict_in_batches(learn)
57
+
58
+
59
+ #def debogrify_stdout():
60
+ # """ deal with some bogus output to stdout from pytorch or something """
61
+ # stdout_fd = sys.stdout.fileno()
62
+ # stderr_fd = sys.stderr.fileno()
63
+ # new_stdout_fd = os.dup(stdout_fd)
64
+ # os.dup2(stderr_fd, stdout_fd)
65
+ # out = os.fdopen(new_stdout_fd, "w")
66
+ # return out
67
+
68
+
69
+ def predict_one_at_time(learn):
70
+ for filename in sys.stdin:
71
+ filename = filename.rstrip()
72
+ pred,i,probs = learn.predict(filename)
73
+ print(f"{probs[i]:.10f}\t{pred}\t{filename}", file=out)
74
+ # for a in zip(learn.dls.vocab,[f'{x:.10f}' for x in probs]):
75
+ # print(a)
76
+
77
+
78
+ def predict_in_batches(learn):
79
+ if args.verbose:
80
+ print(f'batch size {args.batch}', file=sys.stderr)
81
+ batch = []
82
+ for filename in sys.stdin:
83
+ filename = filename.rstrip()
84
+ batch.append(filename)
85
+ if len(batch) >= args.batch:
86
+ predict_batch(learn, batch)
87
+ batch = []
88
+ if len(batch) > 0:
89
+ predict_batch(learn, batch)
90
+ batch = []
91
+
92
+
93
+ def predict_batch(learn, batch):
94
+ vocab = learn.dls.vocab
95
+ with redirect_stdout(sys.stderr):
96
+ # In some versions the following call writes the progress bar to stdout. We can't have that!
97
+ dl = learn.dls.test_dl(batch)
98
+ preds_all = learn.get_preds(dl=dl)
99
+ for i in range(0, len(batch)):
100
+ filename = batch[i]
101
+ preds = preds_all[0][i]
102
+ i = preds.argmax(dim=0)
103
+ label = vocab[i]
104
+ prob = preds[i]
105
+ print(f"{prob:.10f}\t{label}\t{filename}", file=out)
106
+ if args.move is not None and prob >= args.move:
107
+ Path(label).mkdir(exist_ok=True)
108
+ shutil.move(filename, label)
109
+
110
+ # TODO load ai_helper.py
111
+ #def is_cat(x): return x[0].isupper()
112
+
113
+ main()
photo-checker.ipynb ADDED
The diff for this file is too large to render. See raw diff