photo-checker / classify
sswam's picture
add notebook and my classify script
6fa34a8
raw
history blame
3.33 kB
#!/usr/bin/env python3
from fastbook import *
from fastai.vision import *
import argparse
import sys
import os
from contextlib import redirect_stdout
#from PIL import Image
import torch
def main():
    """Command-line entry point: parse options, load the model, classify stdin.

    Image filenames are read from stdin, one per line; results are written to
    stdout as ``<probability>\t<label>\t<filename>`` lines by the predict_*
    functions.  Sets the module-level globals ``args`` (parsed options) and
    ``out`` (the stream results are written to), which those functions read.
    """
    global args, out
    out = sys.stdout
    # out = debogrify_stdout()
    parser = argparse.ArgumentParser(description='Classify images with a trained neural network.')
    parser.add_argument('-v', '--verbose', action='store_true', help='debug')
    parser.add_argument('--path', type=str, default='.', help='path to model')
    parser.add_argument('--model', type=str, default='export.pkl', help='model')
    parser.add_argument('--cpu', action='store_true', help='run on cpu')
    parser.add_argument('--batch', type=int, default=100, help='batch size')
    parser.add_argument('--move', type=float, help='move files with certainty above the threshold')
    # parser.add_argument('--train', action='store_true', help='train the model')
    # parser.add_argument('--watch', type=str, help='folder to watch for images to classify')
    # parser.add_argument('--folder', type=str, default='.', help='folder with images to classify (or train)')
    # parser.add_argument('--certain', type=float, default=0.9, help='minimum certainty to classify image')
    args = parser.parse_args()

    # Only try the GPU when it was not disabled and torch can actually see one;
    # get_device_name() raises on machines without a CUDA device.
    use_cuda = not args.cpu and torch.cuda.is_available()
    if args.verbose:
        print(torch.cuda.get_device_name(0) if use_cuda else 'cpu', file=sys.stderr)

    # Look for the model under --path first, then fall back to the home directory.
    model_path = Path(args.path)/Path(args.model)
    if not model_path.exists():
        model_path = Path.home()/Path(args.model)
    if args.verbose:
        print(model_path, file=sys.stderr)
    learn = load_learner(model_path)

    if use_cuda:
        try:
            learn.dls.to('cuda')
        except Exception as e:
            # Best effort: report the problem and keep running on the CPU.
            # NOTE(review): the move may have been partly applied before the
            # failure, so put the loaders back on the CPU explicitly.
            print(e, file=sys.stderr)
            learn.dls.to('cpu')
    if args.verbose:
        print(f'{learn.dls.device=}', file=sys.stderr)

    if args.batch == 1:
        predict_one_at_time(learn)
    else:
        predict_in_batches(learn)
#def debogrify_stdout():
# """ deal with some bogus output to stdout from pytorch or something """
# stdout_fd = sys.stdout.fileno()
# stderr_fd = sys.stderr.fileno()
# new_stdout_fd = os.dup(stdout_fd)
# os.dup2(stderr_fd, stdout_fd)
# out = os.fdopen(new_stdout_fd, "w")
# return out
def predict_one_at_time(learn):
    """Classify images one by one, reading filenames from stdin.

    For each non-blank line of stdin (a filename), writes
    ``<probability>\t<label>\t<filename>`` to ``out``, where label is the
    predicted class and probability is the model's score for it.  Like the
    batched path, honours ``--move``: when the score is at least the threshold
    the file is moved into a directory named after the label (created in the
    current working directory if needed).

    Args:
        learn: a fastai Learner, as returned by ``load_learner``.
    """
    import shutil  # explicit, rather than relying on a star import

    for line in sys.stdin:
        # Strip only the line terminator so filenames with trailing spaces survive.
        filename = line.rstrip('\r\n')
        if not filename:
            continue  # blank line: nothing to classify
        with redirect_stdout(sys.stderr):
            # Prediction may print a progress bar to stdout, which carries our
            # results; ``out`` still refers to the real stdout.
            pred, idx, probs = learn.predict(filename)
        prob = probs[idx]
        print(f"{prob:.10f}\t{pred}\t{filename}", file=out)
        if args.move is not None and prob >= args.move:
            try:
                Path(pred).mkdir(exist_ok=True)
                shutil.move(filename, pred)
            except OSError as e:  # includes shutil.Error, e.g. destination exists
                print(f"cannot move {filename} to {pred}: {e}", file=sys.stderr)
        # Debug aid: per-class probabilities.
        # for a in zip(learn.dls.vocab,[f'{x:.10f}' for x in probs]):
        #     print(a)
def predict_in_batches(learn):
    """Classify images listed on stdin in groups of ``args.batch`` files.

    Reads one filename per line from stdin, collects them into batches of
    ``args.batch`` and passes each full batch, plus any final partial batch,
    to ``predict_batch``.  Blank lines are ignored.

    Args:
        learn: a fastai Learner, as returned by ``load_learner``.
    """
    if args.verbose:
        print(f'batch size {args.batch}', file=sys.stderr)
    batch = []
    for line in sys.stdin:
        # Strip only the line terminator so filenames with trailing spaces survive.
        filename = line.rstrip('\r\n')
        if not filename:
            continue  # blank line: nothing to classify
        batch.append(filename)
        if len(batch) >= args.batch:
            predict_batch(learn, batch)
            batch = []
    if batch:
        # Flush the final, possibly short, batch.
        predict_batch(learn, batch)
def predict_batch(learn, batch):
    """Classify a list of image filenames in one pass and report the results.

    Writes one line per file to ``out``:
    ``<probability>\t<label>\t<filename>``, where label is the vocab entry with
    the highest score and probability is that score.  If ``--move`` was given
    and the score is at least the threshold, the file is moved into a directory
    named after the label (created in the current working directory if needed).
    A failed move is reported on stderr and does not stop the run.

    Args:
        learn: a fastai Learner, as returned by ``load_learner``.
        batch: list of image filenames to classify.
    """
    import shutil  # explicit, rather than relying on a star import

    vocab = learn.dls.vocab
    with redirect_stdout(sys.stderr):
        # In some versions the following calls write the progress bar to
        # stdout. We can't have that! (``out`` still refers to the real stdout.)
        dl = learn.dls.test_dl(batch)
        preds_all = learn.get_preds(dl=dl)
    scores = preds_all[0]  # one row of class scores per input file
    for filename, row in zip(batch, scores):
        best = int(row.argmax(dim=0))
        label = vocab[best]
        prob = row[best]
        print(f"{prob:.10f}\t{label}\t{filename}", file=out)
        if args.move is not None and prob >= args.move:
            try:
                Path(label).mkdir(exist_ok=True)
                shutil.move(filename, label)
            except OSError as e:  # includes shutil.Error, e.g. destination exists
                print(f"cannot move {filename} to {label}: {e}", file=sys.stderr)
# TODO load ai_helper.py
#def is_cat(x): return x[0].isupper()
# Run only when executed as a script, so importing this module does not parse
# sys.argv or start reading stdin.
if __name__ == "__main__":
    main()