udyan2 committed on
Commit bf6680f
1 Parent(s): 9ea24fa

Upload run_inference.py

Files changed (1)
  1. run_inference.py +189 -0
run_inference.py ADDED
# Run inference benchmarks

import argparse
import logging
import os
import pathlib
import time

import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

from utils.process_data import read_and_preprocess_data, REVERSE_MAPPING


def inference(predict_fn, batch, n_runs) -> float:
    """Run inference using the provided `predict_fn`.

    Args:
        predict_fn: prediction function to use
        batch: data batch from a data loader
        n_runs: number of benchmark runs to time
    Returns:
        float: average prediction time per batch
    """
    times = []
    predictions = []
    with torch.no_grad():
        # Run two extra warm-up iterations before the timed runs
        for _ in range(2 + n_runs):
            start = time.time()
            res = predict_fn(batch)
            end = time.time()
            predictions.append(res)
            times.append(end - start)

    # Average only the timed runs, skipping the two warm-up iterations
    avg_time = np.mean(times[2:])
    return avg_time


def main(flags) -> None:
    """Set up the model for inference and perform benchmarking.

    Args:
        flags: benchmarking flags parsed from the command line
    """
    # Log to the console if no logfile is given, otherwise to the file
    if flags.logfile == "":
        logging.basicConfig(level=logging.DEBUG)
    else:
        path = pathlib.Path(flags.logfile)
        path.parent.mkdir(parents=True, exist_ok=True)
        logging.basicConfig(filename=flags.logfile, level=logging.DEBUG)
    logger = logging.getLogger()

    if not os.path.exists(flags.saved_model_dir):
        logger.error("Saved model %s not found!", flags.saved_model_dir)
        return

    # Load the tokenizer and dataset into memory
    tokenizer = AutoTokenizer.from_pretrained(flags.saved_model_dir)

    try:
        test_dataset = read_and_preprocess_data(
            flags.input_file,
            tokenizer,
            max_length=flags.seq_length,
            include_label=False
        )
        # A batch size of -1 means "use all entries in the input as one batch"
        batch_size = len(test_dataset) if flags.batch_size == -1 else flags.batch_size
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
    except FileNotFoundError as exc:
        logger.error("Please follow the instructions to download the data.")
        logger.error(exc, exc_info=True)
        return

    # Load the model into memory; an INC model would need special loading
    model = AutoModelForSequenceClassification.from_pretrained(flags.saved_model_dir)

    # Trace an example batch so the model can be JIT-compiled for faster execution
    batch = next(iter(test_loader))
    token_ids = batch['input_ids']
    mask = batch['attention_mask']

    jit_inputs = (token_ids, mask)

    logger.info("Using stock model")

    model.eval()
    model = torch.jit.trace(model, jit_inputs, check_trace=False, strict=False)
    model = torch.jit.freeze(model)

    def predict(batch) -> torch.Tensor:
        """Predict the output for the given batch using the traced model.

        Args:
            batch: data batch from the data loader, as produced by the
                transformers tokenizer
        Returns:
            torch.Tensor: predicted quantities
        """
        return model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )
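
    # Tracing with strict=False above lets the traced module keep the model's
    # dict-style output, which is why predict(...)['logits'] works below.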

    if flags.benchmark_mode:
        logger.info("Running experiment n = %d, b = %d, l = %d",
                    flags.n_runs, flags.batch_size, flags.seq_length)

        average_time = inference(predict, batch, flags.n_runs)
        logger.info('Avg time per batch : %.3f s', average_time)
    else:
        predictions = []
        index = 0
        for batch in test_loader:
            pred_probs = torch.softmax(
                predict(batch)['logits'], dim=1
            ).detach().numpy()
            for i in range(len(pred_probs)):
                # Keep the five most likely classes, mapped back to label names
                probs = {
                    REVERSE_MAPPING[x]: pred_probs[i, x]
                    for x in np.argsort(pred_probs[i, :])[::-1][:5]
                }
                predictions.append(
                    {'id': index, 'prognosis': probs}
                )
                index += 1
        print({"predictions": predictions})


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--saved_model_dir',
        required=True,
        help="saved pretrained model to benchmark",
        type=str
    )

    parser.add_argument(
        '--input_file',
        required=True,
        help="input to make predictions on",
        type=str
    )

    parser.add_argument(
        '--batch_size',
        default=-1,
        type=int,
        help="batch size to use. if -1, uses all entries in input."
    )

    parser.add_argument(
        '--benchmark_mode',
        default=False,
        help="benchmark instead of getting predictions.",
        action="store_true"
    )

    parser.add_argument(
        '--seq_length',
        default=512,
        help="sequence length to use. defaults to 512.",
        type=int
    )

    parser.add_argument(
        '--logfile',
        help="logfile to use.",
        default="",
        type=str
    )

    parser.add_argument(
        '--n_runs',
        default=100,
        help="number of trials to time. defaults to 100.",
        type=int
    )

    FLAGS = parser.parse_args()

    main(FLAGS)
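
A minimal sketch of how main() could be driven programmatically for a quick
smoke test, assuming run_inference.py and its utils package are importable;
the model and data paths below are hypothetical placeholders:

# Smoke test for run_inference.main (sketch; paths are hypothetical).
# Equivalent CLI: python run_inference.py --saved_model_dir output/saved_model \
#     --input_file data/test.csv --batch_size 32 --benchmark_mode --n_runs 10
import argparse

from run_inference import main

flags = argparse.Namespace(
    saved_model_dir="output/saved_model",  # hypothetical checkpoint directory
    input_file="data/test.csv",            # hypothetical input file
    batch_size=32,
    benchmark_mode=True,   # time batches instead of printing predictions
    seq_length=512,
    logfile="",            # "" logs to the console rather than a file
    n_runs=10,
)
main(flags)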