File size: 1,846 Bytes
fbf7e95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse

import numpy as np
import onnxruntime
import pandas as pd

from marcai.utils import load_config


def sigmoid(x):
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise.

    The result is evaluated as ``exp(-log(1 + exp(-x)))`` with
    ``np.logaddexp`` so that large-magnitude negative inputs saturate to 0
    without an intermediate ``exp`` overflowing. The naive formula triggers
    a NumPy overflow warning (or a ``FloatingPointError`` when overflow
    errors are set to raise) for such inputs.

    Args:
        x: A scalar or NumPy array of logits.

    Returns:
        The logistic of ``x``, in the closed interval [0, 1], with the same
        shape as ``x``.
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)), computed
    # without ever forming exp(-x) on its own.
    return np.exp(-np.logaddexp(0, -1 * x))


def predict_onnx(model_onnx_path, data):
    """Score ``data`` with the ONNX model located at ``model_onnx_path``.

    A new ``onnxruntime.InferenceSession`` is created on every call.
    ``data`` is converted to a float32 NumPy array through its ``to_numpy``
    method and bound to the model's first declared input.

    Args:
        model_onnx_path: Model location handed to ``InferenceSession``.
        data: Object exposing ``to_numpy(dtype=...)`` (a pandas DataFrame in
            this script) holding the model's input features.

    Returns:
        ``sigmoid`` applied to the ``np.array`` built from the list of
        outputs the session returns.
    """
    session = onnxruntime.InferenceSession(model_onnx_path)

    features = data.to_numpy(dtype=np.float32)

    # Bind the feature matrix to the first input the model declares.
    first_input = session.get_inputs()[0]
    feed = {first_input.name: features}

    # Passing None as the output list asks the session for all outputs.
    raw_outputs = session.run(None, feed)

    return sigmoid(np.array(raw_outputs))


def main():
    """Command-line entry point: append model predictions to a CSV file.

    Reads the CSV given by ``--input`` in chunks of ``--chunksize`` rows,
    scores each chunk with ``model.onnx`` from ``--model-dir`` using the
    columns listed under ``model.features`` in that directory's
    ``config.yaml``, and writes every chunk plus a ``prediction`` column to
    ``--output``.

    The ONNX inference session is built once and reused for all chunks, so
    the model file is not reloaded for every chunk of the input.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input", help="Path to preprocessed data file", required=True
    )
    parser.add_argument("-o", "--output", help="Output path", required=True)
    parser.add_argument(
        "-m",
        "--model-dir",
        help="Directory containing model ONNX and YAML files",
        required=True,
    )
    parser.add_argument(
        "--chunksize",
        help="Chunk size for reading and predicting",
        default=1024,
        type=int,
    )

    args = parser.parse_args()

    config_path = f"{args.model_dir}/config.yaml"
    model_onnx = f"{args.model_dir}/model.onnx"

    config = load_config(config_path)
    # Loop-invariant: the model's input columns never change between chunks.
    features = config["model"]["features"]

    # Load data lazily, one chunk at a time
    data = pd.read_csv(args.input, chunksize=args.chunksize)

    # Create the session once; constructing it per chunk would re-read and
    # re-initialise the model for every chunk.
    ort_session = onnxruntime.InferenceSession(model_onnx)
    input_name = ort_session.get_inputs()[0].name

    for index, chunk in enumerate(data):
        # Limit columns to model input features
        x = chunk[features].to_numpy(dtype=np.float32)

        ort_outs = np.array(ort_session.run(None, {input_name: x}))
        prediction = sigmoid(ort_outs)

        # Add prediction to chunk. squeeze() drops the singleton axes --
        # presumably the model emits one score per row; verify for models
        # with several outputs.
        chunk["prediction"] = prediction.squeeze()

        # The first chunk creates/overwrites the file with a header row;
        # later chunks are appended without repeating the header.
        is_first = index == 0
        chunk.to_csv(
            args.output,
            mode="w" if is_first else "a",
            header=is_first,
            index=False,
        )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()