# copyright zxix 2022
# https://creativecommons.org/licenses/by-nc-sa/4.0/
import torch
import pickle_inspector
import sys
from pathlib import Path

# Command line: <script> <directory> [debug]
# Supplying exactly one extra argument after the directory turns on debug
# output (the full text of every flagged call is printed after the summary).
if len(sys.argv) < 2:
  # Without this guard a missing argument surfaces as a bare IndexError.
  sys.exit("usage: <script> <directory> [debug]")

debug = len(sys.argv) == 3

# Named scan_dir rather than dir so the builtin dir() is not shadowed.
scan_dir = sys.argv[1]
print("checking dir: " + scan_dir)

# Root of the recursive scan.
BASE_DIR = Path(scan_dir)
# File suffixes that are treated as checkpoints and inspected.
EXTENSIONS = {'.pt', '.bin', '.ckpt'}
# Module names that are flagged when a recorded call starts with "<name>.".
BAD_CALLS = {'os', 'shutil', 'sys', 'requests', 'net'}
# Substrings that are flagged when they appear anywhere in a recorded call.
BAD_SIGNAL = {'rm ', 'cat ', 'nc ', '/bin/sh '}

# Walk BASE_DIR recursively and, for every checkpoint file found, print a
# summary of the recorded calls that look dangerous or unexpected.

# Prefixes a call may start with without being reported as non-standard.
STANDARD_PREFIXES = ("numpy.", "_codecs.", "collections.", "torch.")

# Sets of strings iterate in a hash-dependent order that can change between
# runs; sorting once keeps both the first-match choice and the printed
# summary stable.
bad_calls = sorted(BAD_CALLS)
bad_signals = sorted(BAD_SIGNAL)

for path in BASE_DIR.glob(r'**/*'):
  # Compare suffixes case-insensitively so e.g. ".CKPT" is not skipped, and
  # ignore directories that merely carry a checkpoint-like suffix.
  if path.suffix.lower() not in EXTENSIONS or not path.is_file():
    continue

  print("")
  print("..." + path.as_posix())
  # NOTE(review): torch.load unpickles the file. This presumably relies on
  # pickle_inspector.pickle recording calls instead of executing them --
  # confirm before pointing the scanner at untrusted files.
  result = torch.load(path.as_posix(), pickle_module=pickle_inspector.pickle)
  result_total = 0
  result_other = 0
  result_calls = dict.fromkeys(bad_calls, 0)
  result_signals = dict.fromkeys(bad_signals, 0)
  report = []  # debug text fragments, joined only if they are printed

  for c in result.calls:
    # At most one library hit is counted per recorded call.
    for call in bad_calls:
      if c.startswith(call + "."):
        result_calls[call] += 1
        result_total += 1
        report.append("\n--- found lib call (" + call + ") ---\n")
        report.append(c)
        report.append("\n---------------\n")
        break

    # At most one signal hit is counted per recorded call; when several
    # signals occur in the same call the first in sorted order wins.
    for signal in bad_signals:
      if signal in c:
        result_signals[signal] += 1
        result_total += 1
        report.append("\n--- found malicious signal (" + signal + ") ---\n")
        report.append(c)
        report.append("\n---------------\n")
        break

    # Anything outside the allow-list is reported as well, so a call flagged
    # above can contribute to the total more than once.
    if not c.startswith(STANDARD_PREFIXES):
      result_total += 1
      result_other += 1
      report.append("\n--- found non-standard lib call ---\n")
      report.append(c)
      report.append("\n---------------\n")

  if result_total > 0:
    for call in bad_calls:
      print("library call (" + call + ".): " + str(result_calls[call]))
    for signal in bad_signals:
      print("malicious signal (" + signal + "): " + str(result_signals[signal]))
    print("non-standard calls: " + str(result_other))
    print("total: " + str(result_total))
    print("")
    print("SCAN FAILED")

    if debug:
      print("".join(report))
  else:
    print("SCAN PASSED!")