#!/usr/bin/env python3
"""Classify images with a trained fastai model.

Reads image filenames from stdin (one per line) and writes tab-separated lines
(probability, predicted label, filename) to stdout.
"""
from fastai.vision.all import *
import argparse
import sys
import os
import shutil
from contextlib import redirect_stdout
#from PIL import Image
import torch
import multiprocessing

nproc = multiprocessing.cpu_count()


def main():
    global args, out
    out = sys.stdout
    # out = debogrify_stdout()

    parser = argparse.ArgumentParser(description='Classify images with a trained neural network.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-v', '--verbose', action='store_true', help='debug')
    parser.add_argument('--path', type=str, default='.', help='path to model')
    parser.add_argument('--model', type=str, default='export.pkl', help='model')
    parser.add_argument('--cpu', action='store_true', help='run on cpu')
    parser.add_argument('--batch', type=int, default=10000, help='batch size')
    parser.add_argument('--move', type=float, help='move files with certainty above the threshold')
    parser.add_argument('--workers', type=int, default=nproc, help='max number of data loader workers')
    # parser.add_argument('--train', action='store_true', help='train the model')
    # parser.add_argument('--watch', type=str, help='folder to watch for images to classify')
    # parser.add_argument('--folder', type=str, default='.', help='folder with images to classify (or train)')
    # parser.add_argument('--certain', type=float, default=0.9, help='minimum certainty to classify image')
    args = parser.parse_args()

    if args.verbose:
        print(torch.cuda.get_device_name(0), file=sys.stderr)

    # Look for the model under --path first, then fall back to the home directory.
    model_path = Path(args.path)/Path(args.model)
    if not model_path.exists():
        model_path = Path.home()/Path(args.model)
    if args.verbose:
        print(model_path, file=sys.stderr)
    learn = load_learner(model_path)

    if args.cpu:
        torch.set_num_threads(nproc)
    else:
        try:
            learn.dls.to('cuda')
        except Exception as e:
            print(e, file=sys.stderr)
    if args.verbose:
        print(f'{learn.dls.device=}', file=sys.stderr)

    if args.batch == 1:
        predict_one_at_time(learn)
    else:
        predict_in_batches(learn)


def predict_one_at_time(learn):
    """Classify each filename read from stdin individually."""
    for filename in sys.stdin:
        filename = filename.rstrip()
        pred, i, probs = learn.predict(filename)
        print(f"{probs[i]:.10f}\t{pred}\t{filename}", file=out)
        # for a in zip(learn.dls.vocab, [f'{x:.10f}' for x in probs]):
        #     print(a)


def predict_in_batches(learn):
    """Accumulate filenames from stdin and classify them --batch at a time."""
    if args.verbose:
        print(f'batch size {args.batch}', file=sys.stderr)
    batch = []
    for filename in sys.stdin:
        filename = filename.rstrip()
        batch.append(filename)
        if len(batch) >= args.batch:
            predict_batch(learn, batch)
            batch = []
    if len(batch) > 0:
        predict_batch(learn, batch)
        batch = []


def predict_batch(learn, batch):
    """Run inference on a list of filenames and print one result line per file."""
    vocab = learn.dls.vocab
    with redirect_stdout(sys.stderr):
        # In some versions the following call writes the progress bar to stdout. We can't have that!
        dl = learn.dls.test_dl(batch, num_workers=min(args.workers, args.batch))
        preds_all = learn.get_preds(dl=dl)
    for i in range(len(batch)):
        filename = batch[i]
        preds = preds_all[0][i]
        top = preds.argmax(dim=0)
        label = vocab[top]
        prob = preds[top]
        print(f"{prob:.10f}\t{label}\t{filename}", file=out)
        if args.move is not None and prob >= args.move:
            # Move confidently classified images into a folder named after the label.
            Path(label).mkdir(exist_ok=True)
            shutil.move(filename, label)


# TODO load ai_helper.py
#def is_cat(x): return x[0].isupper()

if __name__ == '__main__':
    main()
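
# Example invocation (a sketch; the script name "classify.py" and the "photos"
# folder are assumptions, and the model file depends on how the classifier was
# exported):
#
#   find photos -name '*.jpg' | python3 classify.py --batch 500 --move 0.95
#
# Each input line is an image path; each output line is tab-separated as
# probability, label, filename. With --move, images whose top probability meets
# the threshold are moved into a subdirectory named after the predicted label.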