Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import argparse
import os
import shutil
import sys
from contextlib import redirect_stdout

import torch
from fastbook import *
from fastai.vision import *
#from PIL import Image
def main():
    """Parse the command line, load the exported learner and classify images.

    Image filenames are read from stdin, one per line. Results are written to
    the module-level ``out`` stream (stdout) as
    ``<probability>\\t<label>\\t<filename>`` lines; diagnostics go to stderr.
    Sets the module globals ``args`` (parsed options) and ``out``.
    """
    global args, out
    out = sys.stdout
    # out = debogrify_stdout()
    parser = argparse.ArgumentParser(description='Classify images with a trained neural network.')
    parser.add_argument('-v', '--verbose', action='store_true', help='debug')
    parser.add_argument('--path', type=str, default='.', help='path to model')
    parser.add_argument('--model', type=str, default='export.pkl', help='model')
    parser.add_argument('--cpu', action='store_true', help='run on cpu')
    parser.add_argument('--batch', type=int, default=100, help='batch size')
    parser.add_argument('--move', type=float, help='move files with certainty above the threshold')
    # parser.add_argument('--train', action='store_true', help='train the model')
    # parser.add_argument('--watch', type=str, help='folder to watch for images to classify')
    # parser.add_argument('--folder', type=str, default='.', help='folder with images to classify (or train)')
    # parser.add_argument('--certain', type=float, default=0.9, help='minimum certainty to classify image')
    args = parser.parse_args()

    if args.verbose:
        # get_device_name raises when no CUDA device exists, so only ask for
        # it when CUDA is actually available (e.g. not with -v on a CPU box).
        if torch.cuda.is_available():
            print(torch.cuda.get_device_name(0), file=sys.stderr)
        else:
            print('CUDA not available', file=sys.stderr)

    # Look for the model under --path first, then fall back to $HOME.
    model_path = Path(args.path)/Path(args.model)
    if not model_path.exists():
        model_path = Path.home()/Path(args.model)
    if args.verbose:
        print(model_path, file=sys.stderr)
    learn = load_learner(model_path)

    if not args.cpu:
        # Best effort: if the GPU move fails we keep running on the CPU.
        # NOTE(review): this moves only the DataLoaders; confirm the model
        # itself ends up on the GPU too (e.g. load_learner(..., cpu=False)).
        try:
            learn.dls.to('cuda')
        except Exception as e:
            print(e, file=sys.stderr)
    if args.verbose:
        print(f'{learn.dls.device=}', file=sys.stderr)

    if args.batch == 1:
        predict_one_at_time(learn)
    else:
        predict_in_batches(learn)
#def debogrify_stdout(): | |
# """ deal with some bogus output to stdout from pytorch or something """ | |
# stdout_fd = sys.stdout.fileno() | |
# stderr_fd = sys.stderr.fileno() | |
# new_stdout_fd = os.dup(stdout_fd) | |
# os.dup2(stderr_fd, stdout_fd) | |
# out = os.fdopen(new_stdout_fd, "w") | |
# return out | |
def predict_one_at_time(learn):
    """Classify image filenames from stdin one at a time with ``learn.predict``.

    Each non-blank stdin line is treated as an image path. For each one, a
    ``<probability>\\t<label>\\t<filename>`` line is written to ``out``. When
    ``--move`` is given, files whose top probability is at least the threshold
    are moved into a directory named after the predicted label. This matches
    the batched code path.
    """
    for line in sys.stdin:
        filename = line.rstrip()
        if not filename:
            # Blank lines (e.g. a trailing empty line) are not images.
            continue
        pred, idx, probs = learn.predict(filename)
        prob = float(probs[idx])
        print(f"{prob:.10f}\t{pred}\t{filename}", file=out)
        if args.move is not None and prob >= args.move:
            # Labels need not be strings (e.g. bool vocab), so convert first.
            dest = str(pred)
            Path(dest).mkdir(exist_ok=True)
            try:
                shutil.move(filename, dest)
            except OSError as e:  # shutil.Error is a subclass of OSError
                print(f'cannot move {filename}: {e}', file=sys.stderr)
def predict_in_batches(learn):
    """Read image filenames from stdin and classify them in groups.

    Filenames are accumulated until ``args.batch`` of them have been collected.
    Each group is then handed to ``predict_batch``, and any final partial group
    is classified as well. Blank stdin lines are skipped.
    """
    if args.verbose:
        print(f'batch size {args.batch}', file=sys.stderr)
    batch = []
    for line in sys.stdin:
        filename = line.rstrip()
        if not filename:
            # Skip blank lines rather than feeding an empty path to the model.
            continue
        batch.append(filename)
        if len(batch) >= args.batch:
            predict_batch(learn, batch)
            batch = []
    if batch:
        predict_batch(learn, batch)
def predict_batch(learn, batch):
    """Classify a list of image filenames in a single ``get_preds`` pass.

    For each file, writes ``<probability>\\t<label>\\t<filename>`` to ``out``,
    where label is the highest-scoring vocab entry. When ``--move`` is set,
    files whose top probability is at least the threshold are moved into a
    directory (relative to the CWD) named after that label. A failed move is
    reported on stderr and does not abort the run.
    """
    vocab = learn.dls.vocab
    # In some fastai versions these calls print a progress bar to stdout, which
    # would be mixed into the tab-separated results; divert it to stderr.
    with redirect_stdout(sys.stderr):
        dl = learn.dls.test_dl(batch)
        all_probs = learn.get_preds(dl=dl)[0]
    for filename, probs in zip(batch, all_probs):
        idx = int(probs.argmax(dim=0))
        label = vocab[idx]
        prob = float(probs[idx])
        print(f"{prob:.10f}\t{label}\t{filename}", file=out)
        if args.move is not None and prob >= args.move:
            # Vocab entries are not necessarily strings (e.g. bool labels),
            # and Path()/shutil.move need a str destination.
            dest = str(label)
            Path(dest).mkdir(exist_ok=True)
            try:
                shutil.move(filename, dest)
            except OSError as e:  # e.g. destination already exists
                print(f'cannot move {filename}: {e}', file=sys.stderr)
# TODO load ai_helper.py
# NOTE(review): presumably a model exported with a custom label function
# (such as is_cat) needs that function defined here before load_learner can
# unpickle it -- confirm before re-enabling.
#def is_cat(x): return x[0].isupper()

if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()