# copyright zxix 2022
# https://creativecommons.org/licenses/by-nc-sa/4.0/
"""Scan a directory tree for checkpoint files and report suspicious pickle calls.

Usage: <script> DIR [DEBUG]

Every path under DIR whose suffix is in ``EXTENSIONS`` is loaded through
``torch.load`` with ``pickle_inspector.pickle`` as the pickle module.  The
call strings exposed on the returned object's ``calls`` attribute are then
checked against ``BAD_CALLS`` (module prefixes), ``BAD_SIGNAL`` (substrings)
and a short list of expected library prefixes.  Passing exactly one extra
argument after DIR (its value is ignored) also prints every flagged call.
"""
import sys
from pathlib import Path

import torch

import pickle_inspector

# Debug output is enabled by the mere presence of a second argument.
debug = len(sys.argv) == 3

if len(sys.argv) < 2:
    # Fail with a readable message instead of an IndexError traceback.
    raise SystemExit("error: missing directory argument (usage: <script> DIR [DEBUG])")

scan_dir = sys.argv[1]
print("checking dir: " + scan_dir)

BASE_DIR = Path(scan_dir)
EXTENSIONS = {'.pt', '.bin', '.ckpt'}
BAD_CALLS = {'os', 'shutil', 'sys', 'requests', 'net'}
BAD_SIGNAL = {'rm ', 'cat ', 'nc ', '/bin/sh '}

# Call strings starting with one of these prefixes are not reported as
# "non-standard".
SAFE_PREFIXES = ("numpy.", "_codecs.", "collections.", "torch.")

# Sets of str iterate in a hash-dependent order that can differ between runs.
# Iterating sorted copies keeps the report order stable, and also makes it
# deterministic which signal is credited when a call contains several.
_BAD_CALLS_ORDERED = sorted(BAD_CALLS)
_BAD_SIGNAL_ORDERED = sorted(BAD_SIGNAL)


def _section(title, body):
    """Format one debug-report entry: a titled banner around the call string."""
    return "\n--- " + title + " ---\n" + body + "\n---------------\n"


def _tally(calls):
    """Classify recorded call strings.

    Args:
        calls: iterable of call strings (assumed to be ``str``, as they are
            matched with string operations and concatenated into the report).

    Returns:
        A tuple ``(call_counts, signal_counts, other, total, sections)``:
        per-module hit counts, per-signal hit counts, the number of calls
        outside ``SAFE_PREFIXES``, the sum of all hits, and the list of
        formatted debug entries in discovery order.
    """
    call_counts = dict.fromkeys(_BAD_CALLS_ORDERED, 0)
    signal_counts = dict.fromkeys(_BAD_SIGNAL_ORDERED, 0)
    other = 0
    sections = []

    for c in calls:
        # At most one bad-module hit per call string.
        for call in _BAD_CALLS_ORDERED:
            if c.startswith(call + "."):
                call_counts[call] += 1
                sections.append(_section("found lib call (" + call + ")", c))
                break

        # At most one signal hit per call string (the first in sorted order).
        for signal in _BAD_SIGNAL_ORDERED:
            if signal in c:
                signal_counts[signal] += 1
                sections.append(
                    _section("found malicious signal (" + signal + ")", c))
                break

        # Independent of the two checks above: a call flagged there is also
        # counted here when it is outside the expected libraries.
        if not c.startswith(SAFE_PREFIXES):
            other += 1
            sections.append(_section("found non-standard lib call", c))

    total = sum(call_counts.values()) + sum(signal_counts.values()) + other
    return call_counts, signal_counts, other, total, sections


for path in BASE_DIR.glob(r'**/*'):
    # Skip directories (or dangling links) that merely carry a matching
    # suffix; handing them to torch.load would abort the whole scan.
    if path.suffix not in EXTENSIONS or not path.is_file():
        continue

    print("")
    print("..." + path.as_posix())

    # NOTE(review): this unpickles an untrusted file.  Safety depends entirely
    # on pickle_inspector.pickle recording references rather than executing
    # them -- presumably that is its purpose; confirm against that module.
    result = torch.load(path.as_posix(), pickle_module=pickle_inspector.pickle)

    call_counts, signal_counts, other, total, sections = _tally(result.calls)

    if total > 0:
        for call in _BAD_CALLS_ORDERED:
            print(f"library call ({call}.): {call_counts[call]}")
        for signal in _BAD_SIGNAL_ORDERED:
            print(f"malicious signal ({signal}): {signal_counts[signal]}")
        print(f"non-standard calls: {other}")
        print(f"total: {total}")
        print("")
        print("SCAN FAILED")
        if debug:
            print("".join(sections))
    else:
        print("SCAN PASSED!")