MickyMike commited on
Commit
f513a95
1 Parent(s): ee0fa10

Upload 14 files

Browse files
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import pandas as pd
4
+ from utils import *
5
+
6
+ PATH = os.getcwd()
7
+
8
+ if __name__ == "__main__":
9
+ MAX_NUM_STATEMENTS = 155
10
+
11
+ st.set_page_config(page_title="AIBugHunter")
12
+ # sidebar
13
+ st.sidebar.title("AIBugHunter Web App")
14
+ behavior = st.sidebar.selectbox(label="NAVIGATOR IS HERE:",
15
+ options=["DEMO", "Analyze my own"])
16
+ if behavior == "DEMO":
17
+ # function title
18
+ st.title("C/C++ Vulnerability Dataset Viewer")
19
+ dataset_path = PATH + "/data/test.csv"
20
+ st.dataframe(pd.read_csv(dataset_path))
21
+
22
+ with st.form("input_form_a"):
23
+ idx = st.selectbox('Select an index', (str(i) for i in range(100)))
24
+ sub = st.form_submit_button("Select")
25
+ if sub:
26
+ idx = int(idx)
27
+ df = pd.read_csv(dataset_path)
28
+ input_code = df["function"][idx]
29
+
30
+ input_code = input_code.split("\n")[:MAX_NUM_STATEMENTS]
31
+ input_code = "\n".join(input_code)
32
+ # load model
33
+ with st.spinner("Scanning security issues..."):
34
+ # do inference
35
+ out = predict_vul_lines([input_code])
36
+ func_pred = out["batch_func_pred"][0]
37
+ func_confidence = out["batch_func_pred_prob"][0]
38
+ line_pred = out["batch_statement_pred"][0]
39
+ line_confidence = out["batch_statement_pred_prob"][0]
40
+ output = None
41
+ # inference complete
42
+ st.snow()
43
+ print_code = input_code.split("\n")[:MAX_NUM_STATEMENTS]
44
+ st.markdown("### Scanning Results:")
45
+ if func_pred == 0:
46
+ st.write("<span style='color:green'>" + "No vulnerabilities detected"+ "</span>", unsafe_allow_html=True)
47
+ st.markdown("### Non-Vulnerable Function:")
48
+ else:
49
+ for i in range(len(print_code)):
50
+ c = print_code[i]
51
+ vul = line_pred[i]
52
+ if vul == 1:
53
+ st.write(f"<span style='color:red'> Vulnerable Line {i+1} </span>", unsafe_allow_html=True)
54
+ st.code(c)
55
+ st.markdown("### Vulnerable Function:")
56
+ st.code(input_code, language="cpp", line_numbers=True)
57
+
58
+ elif behavior == "Analyze my own":
59
+ # user input of project title
60
+ ## todo- limit the input to 150 lines
61
+ with st.form("input_form_b"):
62
+ input_code = st.text_area("Input a C/C++ function:", height=275)
63
+ submitted = st.form_submit_button("Analyze")
64
+ if submitted:
65
+ # load model
66
+ with st.spinner("Scanning security issues..."):
67
+ # do inference
68
+ out = predict_vul_lines([input_code])
69
+ func_pred = out["batch_func_pred"][0]
70
+ func_confidence = out["batch_func_pred_prob"][0]
71
+ line_pred = out["batch_statement_pred"][0]
72
+ line_confidence = out["batch_statement_pred_prob"][0]
73
+ output = None
74
+ # inference complete
75
+ st.snow()
76
+ print_code = input_code.split("\n")[:MAX_NUM_STATEMENTS]
77
+ st.markdown("### Scanning Results:")
78
+ if func_pred == 0:
79
+ st.write("<span style='color:green'>" + "No vulnerabilities detected"+ "</span>", unsafe_allow_html=True)
80
+ st.markdown("### Non-Vulnerable Function:")
81
+ else:
82
+ for i in range(len(print_code)):
83
+ c = print_code[i]
84
+ vul = line_pred[i]
85
+ if vul == 1:
86
+ st.write(f"<span style='color:red'> Vulnerable Line {i+1} </span>", unsafe_allow_html=True)
87
+ st.code(c)
88
+ st.markdown("### Vulnerable Function:")
89
+ st.code(input_code, language="cpp", line_numbers=True)
data/process.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ df = pd.read_csv("./processed_test.csv")
4
+
5
+
6
+ func_lab = []
7
+ stat_lab = []
8
+ cwe_id = []
9
+ func = []
10
+
11
+ df_vul = df[df["function_label"]==1][:50]
12
+ df_vul = df_vul.reset_index()
13
+
14
+ df_non_vul = df[df["function_label"]==0][:50]
15
+ df_non_vul = df_non_vul.reset_index()
16
+
17
+ for i in range(len(df_vul)):
18
+ func_lab.append(df_vul["function_label"][i])
19
+ stat_lab.append(df_vul["statement_label"][i])
20
+
21
+ id = df_vul["cwe_id"][i]
22
+ if df_vul["function_label"][i] == 0:
23
+ id = None
24
+ cwe_id.append(id)
25
+ func.append(df_vul["func_before"][i])
26
+
27
+ func_lab.append(df_non_vul["function_label"][i])
28
+ stat_lab.append(df_non_vul["statement_label"][i])
29
+
30
+ id = df_non_vul["cwe_id"][i]
31
+ if df_non_vul["function_label"][i] == 0:
32
+ id = None
33
+
34
+ cwe_id.append(id)
35
+ func.append(df_non_vul["func_before"][i])
36
+
37
+
38
+ pd.DataFrame({"function": func, "function_label": func_lab, "cwe_id": cwe_id, "statement_label": stat_lab}).to_csv("./test.csv", index=False)
39
+
data/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
models/statement_t5_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19747f298f181dc8488dcf128991acdbf1df75e140df2ca4ecd92922cb9f16d6
3
+ size 471562706
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ pickle
4
+ numpy
5
+ onnxruntime
6
+ pandas
statement_t5.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class ClassificationHead(nn.Module):
6
+ """Head for sentence-level classification tasks."""
7
+ def __init__(self, hidden_dim):
8
+ super().__init__()
9
+ self.dense = nn.Linear(hidden_dim, hidden_dim)
10
+ self.Dropout = nn.Dropout(0.1)
11
+ self.out_proj = nn.Linear(hidden_dim, 1)
12
+ self.rnn_pool = nn.GRU(input_size=768,
13
+ hidden_size=768,
14
+ num_layers=1,
15
+ batch_first=True)
16
+ self.func_dense = nn.Linear(hidden_dim, hidden_dim)
17
+ self.func_out_proj = nn.Linear(hidden_dim, 2)
18
+
19
+ def forward(self, hidden):
20
+ x = self.Dropout(hidden)
21
+ x = self.dense(x)
22
+ x = torch.tanh(x)
23
+ x = self.Dropout(x)
24
+ x = self.out_proj(x)
25
+ out, func_x = self.rnn_pool(hidden)
26
+ func_x = func_x.squeeze(0)
27
+ func_x = self.Dropout(func_x)
28
+ func_x = self.func_dense(func_x)
29
+ func_x = torch.tanh(func_x)
30
+ func_x = self.Dropout(func_x)
31
+ func_x = self.func_out_proj(func_x)
32
+ return x.squeeze(-1), func_x
33
+
34
+ class StatementT5(nn.Module):
35
+ def __init__(self, t5, tokenizer, device, hidden_dim=768):
36
+ super(StatementT5, self).__init__()
37
+ self.max_num_statement = 155
38
+ self.word_embedding = t5.shared
39
+ self.rnn_statement_embedding = nn.GRU(input_size=768,
40
+ hidden_size=768,
41
+ num_layers=1,
42
+ batch_first=True)
43
+ self.t5 = t5
44
+ self.tokenizer = tokenizer
45
+ self.device = device
46
+ # CLS head
47
+ self.classifier = ClassificationHead(hidden_dim=hidden_dim)
48
+
49
+ def forward(self, input_ids, statement_mask, labels=None, func_labels=None):
50
+ statement_mask = statement_mask[:, :self.max_num_statement]
51
+ if self.training:
52
+ embed = self.word_embedding(input_ids)
53
+ inputs_embeds = torch.randn(embed.shape[0], embed.shape[1], embed.shape[3]).to(self.device)
54
+ for i in range(len(embed)):
55
+ statement_of_tokens = embed[i]
56
+ out, statement_embed = self.rnn_statement_embedding(statement_of_tokens)
57
+ inputs_embeds[i, :, :] = statement_embed
58
+ inputs_embeds = inputs_embeds[:, :self.max_num_statement, :]
59
+ rep = self.t5(inputs_embeds=inputs_embeds, attention_mask=statement_mask).last_hidden_state
60
+ logits, func_logits = self.classifier(rep)
61
+ loss_fct = nn.CrossEntropyLoss()
62
+ statement_loss = loss_fct(logits, labels)
63
+ loss_fct_2 = nn.CrossEntropyLoss()
64
+ func_loss = loss_fct_2(func_logits, func_labels)
65
+ return statement_loss, func_loss
66
+ else:
67
+ embed = self.word_embedding(input_ids)
68
+ inputs_embeds = torch.randn(embed.shape[0], embed.shape[1], embed.shape[3]).to(self.device)
69
+ for i in range(len(embed)):
70
+ statement_of_tokens = embed[i]
71
+ out, statement_embed = self.rnn_statement_embedding(statement_of_tokens)
72
+ inputs_embeds[i, :, :] = statement_embed
73
+ inputs_embeds = inputs_embeds[:, :self.max_num_statement, :]
74
+ rep = self.t5(inputs_embeds=inputs_embeds, attention_mask=statement_mask).last_hidden_state
75
+ logits, func_logits = self.classifier(rep)
76
+ probs = torch.sigmoid(logits)
77
+ func_probs = torch.softmax(func_logits, dim=-1)
78
+ return probs, func_probs
statement_t5_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
statement_t5_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<extra_id_99>",
5
+ "lstrip": true,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<extra_id_98>",
12
+ "lstrip": true,
13
+ "normalized": true,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ {
18
+ "content": "<extra_id_97>",
19
+ "lstrip": true,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ {
25
+ "content": "<extra_id_96>",
26
+ "lstrip": true,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ {
32
+ "content": "<extra_id_95>",
33
+ "lstrip": true,
34
+ "normalized": true,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ },
38
+ {
39
+ "content": "<extra_id_94>",
40
+ "lstrip": true,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false
44
+ },
45
+ {
46
+ "content": "<extra_id_93>",
47
+ "lstrip": true,
48
+ "normalized": true,
49
+ "rstrip": false,
50
+ "single_word": false
51
+ },
52
+ {
53
+ "content": "<extra_id_92>",
54
+ "lstrip": true,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false
58
+ },
59
+ {
60
+ "content": "<extra_id_91>",
61
+ "lstrip": true,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false
65
+ },
66
+ {
67
+ "content": "<extra_id_90>",
68
+ "lstrip": true,
69
+ "normalized": true,
70
+ "rstrip": false,
71
+ "single_word": false
72
+ },
73
+ {
74
+ "content": "<extra_id_89>",
75
+ "lstrip": true,
76
+ "normalized": true,
77
+ "rstrip": false,
78
+ "single_word": false
79
+ },
80
+ {
81
+ "content": "<extra_id_88>",
82
+ "lstrip": true,
83
+ "normalized": true,
84
+ "rstrip": false,
85
+ "single_word": false
86
+ },
87
+ {
88
+ "content": "<extra_id_87>",
89
+ "lstrip": true,
90
+ "normalized": true,
91
+ "rstrip": false,
92
+ "single_word": false
93
+ },
94
+ {
95
+ "content": "<extra_id_86>",
96
+ "lstrip": true,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false
100
+ },
101
+ {
102
+ "content": "<extra_id_85>",
103
+ "lstrip": true,
104
+ "normalized": true,
105
+ "rstrip": false,
106
+ "single_word": false
107
+ },
108
+ {
109
+ "content": "<extra_id_84>",
110
+ "lstrip": true,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false
114
+ },
115
+ {
116
+ "content": "<extra_id_83>",
117
+ "lstrip": true,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false
121
+ },
122
+ {
123
+ "content": "<extra_id_82>",
124
+ "lstrip": true,
125
+ "normalized": true,
126
+ "rstrip": false,
127
+ "single_word": false
128
+ },
129
+ {
130
+ "content": "<extra_id_81>",
131
+ "lstrip": true,
132
+ "normalized": true,
133
+ "rstrip": false,
134
+ "single_word": false
135
+ },
136
+ {
137
+ "content": "<extra_id_80>",
138
+ "lstrip": true,
139
+ "normalized": true,
140
+ "rstrip": false,
141
+ "single_word": false
142
+ },
143
+ {
144
+ "content": "<extra_id_79>",
145
+ "lstrip": true,
146
+ "normalized": true,
147
+ "rstrip": false,
148
+ "single_word": false
149
+ },
150
+ {
151
+ "content": "<extra_id_78>",
152
+ "lstrip": true,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false
156
+ },
157
+ {
158
+ "content": "<extra_id_77>",
159
+ "lstrip": true,
160
+ "normalized": true,
161
+ "rstrip": false,
162
+ "single_word": false
163
+ },
164
+ {
165
+ "content": "<extra_id_76>",
166
+ "lstrip": true,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false
170
+ },
171
+ {
172
+ "content": "<extra_id_75>",
173
+ "lstrip": true,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false
177
+ },
178
+ {
179
+ "content": "<extra_id_74>",
180
+ "lstrip": true,
181
+ "normalized": true,
182
+ "rstrip": false,
183
+ "single_word": false
184
+ },
185
+ {
186
+ "content": "<extra_id_73>",
187
+ "lstrip": true,
188
+ "normalized": true,
189
+ "rstrip": false,
190
+ "single_word": false
191
+ },
192
+ {
193
+ "content": "<extra_id_72>",
194
+ "lstrip": true,
195
+ "normalized": true,
196
+ "rstrip": false,
197
+ "single_word": false
198
+ },
199
+ {
200
+ "content": "<extra_id_71>",
201
+ "lstrip": true,
202
+ "normalized": true,
203
+ "rstrip": false,
204
+ "single_word": false
205
+ },
206
+ {
207
+ "content": "<extra_id_70>",
208
+ "lstrip": true,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false
212
+ },
213
+ {
214
+ "content": "<extra_id_69>",
215
+ "lstrip": true,
216
+ "normalized": true,
217
+ "rstrip": false,
218
+ "single_word": false
219
+ },
220
+ {
221
+ "content": "<extra_id_68>",
222
+ "lstrip": true,
223
+ "normalized": true,
224
+ "rstrip": false,
225
+ "single_word": false
226
+ },
227
+ {
228
+ "content": "<extra_id_67>",
229
+ "lstrip": true,
230
+ "normalized": true,
231
+ "rstrip": false,
232
+ "single_word": false
233
+ },
234
+ {
235
+ "content": "<extra_id_66>",
236
+ "lstrip": true,
237
+ "normalized": true,
238
+ "rstrip": false,
239
+ "single_word": false
240
+ },
241
+ {
242
+ "content": "<extra_id_65>",
243
+ "lstrip": true,
244
+ "normalized": true,
245
+ "rstrip": false,
246
+ "single_word": false
247
+ },
248
+ {
249
+ "content": "<extra_id_64>",
250
+ "lstrip": true,
251
+ "normalized": true,
252
+ "rstrip": false,
253
+ "single_word": false
254
+ },
255
+ {
256
+ "content": "<extra_id_63>",
257
+ "lstrip": true,
258
+ "normalized": true,
259
+ "rstrip": false,
260
+ "single_word": false
261
+ },
262
+ {
263
+ "content": "<extra_id_62>",
264
+ "lstrip": true,
265
+ "normalized": true,
266
+ "rstrip": false,
267
+ "single_word": false
268
+ },
269
+ {
270
+ "content": "<extra_id_61>",
271
+ "lstrip": true,
272
+ "normalized": true,
273
+ "rstrip": false,
274
+ "single_word": false
275
+ },
276
+ {
277
+ "content": "<extra_id_60>",
278
+ "lstrip": true,
279
+ "normalized": true,
280
+ "rstrip": false,
281
+ "single_word": false
282
+ },
283
+ {
284
+ "content": "<extra_id_59>",
285
+ "lstrip": true,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false
289
+ },
290
+ {
291
+ "content": "<extra_id_58>",
292
+ "lstrip": true,
293
+ "normalized": true,
294
+ "rstrip": false,
295
+ "single_word": false
296
+ },
297
+ {
298
+ "content": "<extra_id_57>",
299
+ "lstrip": true,
300
+ "normalized": true,
301
+ "rstrip": false,
302
+ "single_word": false
303
+ },
304
+ {
305
+ "content": "<extra_id_56>",
306
+ "lstrip": true,
307
+ "normalized": true,
308
+ "rstrip": false,
309
+ "single_word": false
310
+ },
311
+ {
312
+ "content": "<extra_id_55>",
313
+ "lstrip": true,
314
+ "normalized": true,
315
+ "rstrip": false,
316
+ "single_word": false
317
+ },
318
+ {
319
+ "content": "<extra_id_54>",
320
+ "lstrip": true,
321
+ "normalized": true,
322
+ "rstrip": false,
323
+ "single_word": false
324
+ },
325
+ {
326
+ "content": "<extra_id_53>",
327
+ "lstrip": true,
328
+ "normalized": true,
329
+ "rstrip": false,
330
+ "single_word": false
331
+ },
332
+ {
333
+ "content": "<extra_id_52>",
334
+ "lstrip": true,
335
+ "normalized": true,
336
+ "rstrip": false,
337
+ "single_word": false
338
+ },
339
+ {
340
+ "content": "<extra_id_51>",
341
+ "lstrip": true,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false
345
+ },
346
+ {
347
+ "content": "<extra_id_50>",
348
+ "lstrip": true,
349
+ "normalized": true,
350
+ "rstrip": false,
351
+ "single_word": false
352
+ },
353
+ {
354
+ "content": "<extra_id_49>",
355
+ "lstrip": true,
356
+ "normalized": true,
357
+ "rstrip": false,
358
+ "single_word": false
359
+ },
360
+ {
361
+ "content": "<extra_id_48>",
362
+ "lstrip": true,
363
+ "normalized": true,
364
+ "rstrip": false,
365
+ "single_word": false
366
+ },
367
+ {
368
+ "content": "<extra_id_47>",
369
+ "lstrip": true,
370
+ "normalized": true,
371
+ "rstrip": false,
372
+ "single_word": false
373
+ },
374
+ {
375
+ "content": "<extra_id_46>",
376
+ "lstrip": true,
377
+ "normalized": true,
378
+ "rstrip": false,
379
+ "single_word": false
380
+ },
381
+ {
382
+ "content": "<extra_id_45>",
383
+ "lstrip": true,
384
+ "normalized": true,
385
+ "rstrip": false,
386
+ "single_word": false
387
+ },
388
+ {
389
+ "content": "<extra_id_44>",
390
+ "lstrip": true,
391
+ "normalized": true,
392
+ "rstrip": false,
393
+ "single_word": false
394
+ },
395
+ {
396
+ "content": "<extra_id_43>",
397
+ "lstrip": true,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false
401
+ },
402
+ {
403
+ "content": "<extra_id_42>",
404
+ "lstrip": true,
405
+ "normalized": true,
406
+ "rstrip": false,
407
+ "single_word": false
408
+ },
409
+ {
410
+ "content": "<extra_id_41>",
411
+ "lstrip": true,
412
+ "normalized": true,
413
+ "rstrip": false,
414
+ "single_word": false
415
+ },
416
+ {
417
+ "content": "<extra_id_40>",
418
+ "lstrip": true,
419
+ "normalized": true,
420
+ "rstrip": false,
421
+ "single_word": false
422
+ },
423
+ {
424
+ "content": "<extra_id_39>",
425
+ "lstrip": true,
426
+ "normalized": true,
427
+ "rstrip": false,
428
+ "single_word": false
429
+ },
430
+ {
431
+ "content": "<extra_id_38>",
432
+ "lstrip": true,
433
+ "normalized": true,
434
+ "rstrip": false,
435
+ "single_word": false
436
+ },
437
+ {
438
+ "content": "<extra_id_37>",
439
+ "lstrip": true,
440
+ "normalized": true,
441
+ "rstrip": false,
442
+ "single_word": false
443
+ },
444
+ {
445
+ "content": "<extra_id_36>",
446
+ "lstrip": true,
447
+ "normalized": true,
448
+ "rstrip": false,
449
+ "single_word": false
450
+ },
451
+ {
452
+ "content": "<extra_id_35>",
453
+ "lstrip": true,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false
457
+ },
458
+ {
459
+ "content": "<extra_id_34>",
460
+ "lstrip": true,
461
+ "normalized": true,
462
+ "rstrip": false,
463
+ "single_word": false
464
+ },
465
+ {
466
+ "content": "<extra_id_33>",
467
+ "lstrip": true,
468
+ "normalized": true,
469
+ "rstrip": false,
470
+ "single_word": false
471
+ },
472
+ {
473
+ "content": "<extra_id_32>",
474
+ "lstrip": true,
475
+ "normalized": true,
476
+ "rstrip": false,
477
+ "single_word": false
478
+ },
479
+ {
480
+ "content": "<extra_id_31>",
481
+ "lstrip": true,
482
+ "normalized": true,
483
+ "rstrip": false,
484
+ "single_word": false
485
+ },
486
+ {
487
+ "content": "<extra_id_30>",
488
+ "lstrip": true,
489
+ "normalized": true,
490
+ "rstrip": false,
491
+ "single_word": false
492
+ },
493
+ {
494
+ "content": "<extra_id_29>",
495
+ "lstrip": true,
496
+ "normalized": true,
497
+ "rstrip": false,
498
+ "single_word": false
499
+ },
500
+ {
501
+ "content": "<extra_id_28>",
502
+ "lstrip": true,
503
+ "normalized": true,
504
+ "rstrip": false,
505
+ "single_word": false
506
+ },
507
+ {
508
+ "content": "<extra_id_27>",
509
+ "lstrip": true,
510
+ "normalized": true,
511
+ "rstrip": false,
512
+ "single_word": false
513
+ },
514
+ {
515
+ "content": "<extra_id_26>",
516
+ "lstrip": true,
517
+ "normalized": true,
518
+ "rstrip": false,
519
+ "single_word": false
520
+ },
521
+ {
522
+ "content": "<extra_id_25>",
523
+ "lstrip": true,
524
+ "normalized": true,
525
+ "rstrip": false,
526
+ "single_word": false
527
+ },
528
+ {
529
+ "content": "<extra_id_24>",
530
+ "lstrip": true,
531
+ "normalized": true,
532
+ "rstrip": false,
533
+ "single_word": false
534
+ },
535
+ {
536
+ "content": "<extra_id_23>",
537
+ "lstrip": true,
538
+ "normalized": true,
539
+ "rstrip": false,
540
+ "single_word": false
541
+ },
542
+ {
543
+ "content": "<extra_id_22>",
544
+ "lstrip": true,
545
+ "normalized": true,
546
+ "rstrip": false,
547
+ "single_word": false
548
+ },
549
+ {
550
+ "content": "<extra_id_21>",
551
+ "lstrip": true,
552
+ "normalized": true,
553
+ "rstrip": false,
554
+ "single_word": false
555
+ },
556
+ {
557
+ "content": "<extra_id_20>",
558
+ "lstrip": true,
559
+ "normalized": true,
560
+ "rstrip": false,
561
+ "single_word": false
562
+ },
563
+ {
564
+ "content": "<extra_id_19>",
565
+ "lstrip": true,
566
+ "normalized": true,
567
+ "rstrip": false,
568
+ "single_word": false
569
+ },
570
+ {
571
+ "content": "<extra_id_18>",
572
+ "lstrip": true,
573
+ "normalized": true,
574
+ "rstrip": false,
575
+ "single_word": false
576
+ },
577
+ {
578
+ "content": "<extra_id_17>",
579
+ "lstrip": true,
580
+ "normalized": true,
581
+ "rstrip": false,
582
+ "single_word": false
583
+ },
584
+ {
585
+ "content": "<extra_id_16>",
586
+ "lstrip": true,
587
+ "normalized": true,
588
+ "rstrip": false,
589
+ "single_word": false
590
+ },
591
+ {
592
+ "content": "<extra_id_15>",
593
+ "lstrip": true,
594
+ "normalized": true,
595
+ "rstrip": false,
596
+ "single_word": false
597
+ },
598
+ {
599
+ "content": "<extra_id_14>",
600
+ "lstrip": true,
601
+ "normalized": true,
602
+ "rstrip": false,
603
+ "single_word": false
604
+ },
605
+ {
606
+ "content": "<extra_id_13>",
607
+ "lstrip": true,
608
+ "normalized": true,
609
+ "rstrip": false,
610
+ "single_word": false
611
+ },
612
+ {
613
+ "content": "<extra_id_12>",
614
+ "lstrip": true,
615
+ "normalized": true,
616
+ "rstrip": false,
617
+ "single_word": false
618
+ },
619
+ {
620
+ "content": "<extra_id_11>",
621
+ "lstrip": true,
622
+ "normalized": true,
623
+ "rstrip": false,
624
+ "single_word": false
625
+ },
626
+ {
627
+ "content": "<extra_id_10>",
628
+ "lstrip": true,
629
+ "normalized": true,
630
+ "rstrip": false,
631
+ "single_word": false
632
+ },
633
+ {
634
+ "content": "<extra_id_9>",
635
+ "lstrip": true,
636
+ "normalized": true,
637
+ "rstrip": false,
638
+ "single_word": false
639
+ },
640
+ {
641
+ "content": "<extra_id_8>",
642
+ "lstrip": true,
643
+ "normalized": true,
644
+ "rstrip": false,
645
+ "single_word": false
646
+ },
647
+ {
648
+ "content": "<extra_id_7>",
649
+ "lstrip": true,
650
+ "normalized": true,
651
+ "rstrip": false,
652
+ "single_word": false
653
+ },
654
+ {
655
+ "content": "<extra_id_6>",
656
+ "lstrip": true,
657
+ "normalized": true,
658
+ "rstrip": false,
659
+ "single_word": false
660
+ },
661
+ {
662
+ "content": "<extra_id_5>",
663
+ "lstrip": true,
664
+ "normalized": true,
665
+ "rstrip": false,
666
+ "single_word": false
667
+ },
668
+ {
669
+ "content": "<extra_id_4>",
670
+ "lstrip": true,
671
+ "normalized": true,
672
+ "rstrip": false,
673
+ "single_word": false
674
+ },
675
+ {
676
+ "content": "<extra_id_3>",
677
+ "lstrip": true,
678
+ "normalized": true,
679
+ "rstrip": false,
680
+ "single_word": false
681
+ },
682
+ {
683
+ "content": "<extra_id_2>",
684
+ "lstrip": true,
685
+ "normalized": true,
686
+ "rstrip": false,
687
+ "single_word": false
688
+ },
689
+ {
690
+ "content": "<extra_id_1>",
691
+ "lstrip": true,
692
+ "normalized": true,
693
+ "rstrip": false,
694
+ "single_word": false
695
+ },
696
+ {
697
+ "content": "<extra_id_0>",
698
+ "lstrip": true,
699
+ "normalized": true,
700
+ "rstrip": false,
701
+ "single_word": false
702
+ }
703
+ ],
704
+ "bos_token": {
705
+ "content": "<s>",
706
+ "lstrip": false,
707
+ "normalized": true,
708
+ "rstrip": false,
709
+ "single_word": false
710
+ },
711
+ "cls_token": {
712
+ "content": "<s>",
713
+ "lstrip": false,
714
+ "normalized": true,
715
+ "rstrip": false,
716
+ "single_word": false
717
+ },
718
+ "eos_token": {
719
+ "content": "</s>",
720
+ "lstrip": false,
721
+ "normalized": true,
722
+ "rstrip": false,
723
+ "single_word": false
724
+ },
725
+ "mask_token": {
726
+ "content": "<mask>",
727
+ "lstrip": true,
728
+ "normalized": true,
729
+ "rstrip": false,
730
+ "single_word": false
731
+ },
732
+ "pad_token": {
733
+ "content": "<pad>",
734
+ "lstrip": false,
735
+ "normalized": true,
736
+ "rstrip": false,
737
+ "single_word": false
738
+ },
739
+ "sep_token": {
740
+ "content": "</s>",
741
+ "lstrip": false,
742
+ "normalized": true,
743
+ "rstrip": false,
744
+ "single_word": false
745
+ },
746
+ "unk_token": {
747
+ "content": "<unk>",
748
+ "lstrip": false,
749
+ "normalized": true,
750
+ "rstrip": false,
751
+ "single_word": false
752
+ }
753
+ }
statement_t5_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "cls_token": {
12
+ "__type": "AddedToken",
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "eos_token": {
20
+ "__type": "AddedToken",
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "errors": "replace",
28
+ "mask_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<mask>",
31
+ "lstrip": true,
32
+ "normalized": true,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ },
36
+ "model_max_length": 512,
37
+ "name_or_path": "Salesforce/codet5-base",
38
+ "pad_token": {
39
+ "__type": "AddedToken",
40
+ "content": "<pad>",
41
+ "lstrip": false,
42
+ "normalized": true,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ },
46
+ "sep_token": {
47
+ "__type": "AddedToken",
48
+ "content": "</s>",
49
+ "lstrip": false,
50
+ "normalized": true,
51
+ "rstrip": false,
52
+ "single_word": false
53
+ },
54
+ "special_tokens_map_file": "/home/michael/.cache/huggingface/transformers/5941df5e4315c5ab63b7b2ac791fb0bf0f209744a055c06b43b5274849137cdd.b9905d0575bde443a20834122b6e2d48e853b2e36444ce98ddeb43c38097eb3f",
55
+ "tokenizer_class": "RobertaTokenizer",
56
+ "unk_token": {
57
+ "__type": "AddedToken",
58
+ "content": "<unk>",
59
+ "lstrip": false,
60
+ "normalized": true,
61
+ "rstrip": false,
62
+ "single_word": false
63
+ }
64
+ }
statement_t5_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
t5_config.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "Salesforce/codet5-base",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "bos_token_id": 1,
7
+ "d_ff": 3072,
8
+ "d_kv": 64,
9
+ "d_model": 768,
10
+ "decoder_start_token_id": 0,
11
+ "dense_act_fn": "relu",
12
+ "dropout_rate": 0.1,
13
+ "eos_token_id": 2,
14
+ "feed_forward_proj": "relu",
15
+ "gradient_checkpointing": false,
16
+ "id2label": {
17
+ "0": "LABEL_0"
18
+ },
19
+ "initializer_factor": 1.0,
20
+ "is_encoder_decoder": true,
21
+ "is_gated_act": false,
22
+ "label2id": {
23
+ "LABEL_0": 0
24
+ },
25
+ "layer_norm_epsilon": 1e-06,
26
+ "model_type": "t5",
27
+ "n_positions": 512,
28
+ "num_decoder_layers": 12,
29
+ "num_heads": 12,
30
+ "num_layers": 12,
31
+ "output_past": true,
32
+ "pad_token_id": 0,
33
+ "relative_attention_max_distance": 128,
34
+ "relative_attention_num_buckets": 32,
35
+ "task_specific_params": {
36
+ "summarization": {
37
+ "early_stopping": true,
38
+ "length_penalty": 2.0,
39
+ "max_length": 200,
40
+ "min_length": 30,
41
+ "no_repeat_ngram_size": 3,
42
+ "num_beams": 4,
43
+ "prefix": "summarize: "
44
+ },
45
+ "translation_en_to_de": {
46
+ "early_stopping": true,
47
+ "max_length": 300,
48
+ "num_beams": 4,
49
+ "prefix": "translate English to German: "
50
+ },
51
+ "translation_en_to_fr": {
52
+ "early_stopping": true,
53
+ "max_length": 300,
54
+ "num_beams": 4,
55
+ "prefix": "translate English to French: "
56
+ },
57
+ "translation_en_to_ro": {
58
+ "early_stopping": true,
59
+ "max_length": 300,
60
+ "num_beams": 4,
61
+ "prefix": "translate English to Romanian: "
62
+ }
63
+ },
64
+ "torch_dtype": "float32",
65
+ "transformers_version": "4.27.3",
66
+ "use_cache": true,
67
+ "vocab_size": 32100
68
+ }
utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaTokenizer, T5Config, T5EncoderModel
2
+ from statement_t5 import StatementT5
3
+ import torch
4
+ import pickle
5
+ import numpy as np
6
+ import onnxruntime
7
+
8
+ def to_numpy(tensor):
9
+ """ get np input for onnx runtime model """
10
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
11
+
12
+ def predict_vul_lines(code: list, gpu: bool = False) -> dict:
13
+ """Generate statement-level and function-level vulnerability prediction probabilities.
14
+ Parameters
15
+ ----------
16
+ code : :obj:`list`
17
+ A list of String functions.
18
+ gpu : bool
19
+ Defines if CUDA inference is enabled
20
+ Returns
21
+ -------
22
+ :obj:`dict`
23
+ A dictionary with two keys, "batch_vul_pred", "batch_vul_pred_prob", and "batch_line_scores"
24
+ "batch_func_pred" stores a list of function-level vulnerability prediction: [0, 1, ...] where 0 means non-vulnerable and 1 means vulnerable
25
+ "batch_func_pred_prob" stores a list of function-level vulnerability prediction probabilities [0.89, 0.75, ...] corresponding to "batch_func_pred"
26
+ "batch_statement_pred" stores a list of statement-level vulnerability prediction: [0, 1, ...] where 0 means non-vulnerable and 1 means vulnerable
27
+ "batch_statement_pred_prob" stores a list of statement-level vulnerability prediction probabilities [0.89, 0.75, ...] corresponding to "batch_statement_pred"
28
+ """
29
+ MAX_STATEMENTS = 155
30
+ MAX_STATEMENT_LENGTH = 20
31
+ DEVICE = 'cuda' if gpu else 'cpu'
32
+ # load tokenizer
33
+ tokenizer = RobertaTokenizer.from_pretrained("./statement_t5_tokenizer")
34
+ # load model
35
+ config = T5Config.from_pretrained("./t5_config.json")
36
+ model = T5EncoderModel(config=config)
37
+ model = StatementT5(model, tokenizer, device=DEVICE)
38
+ output_dir = "./models/statement_t5_model.bin"
39
+ model.load_state_dict(torch.load(output_dir, map_location=DEVICE))
40
+ model.to(DEVICE)
41
+ model.eval()
42
+ input_ids, statement_mask = statement_tokenization(code, MAX_STATEMENTS, MAX_STATEMENT_LENGTH, tokenizer)
43
+ with torch.no_grad():
44
+ statement_probs, func_probs = model(input_ids=input_ids, statement_mask=statement_mask)
45
+ func_preds = torch.argmax(func_probs, dim=-1)
46
+ statement_preds = torch.where(statement_probs>0.5, 1, 0)
47
+ return {"batch_func_pred": func_preds, "batch_func_pred_prob": func_probs,
48
+ "batch_statement_pred": statement_preds, "batch_statement_pred_prob": statement_probs}
49
+
50
+ def statement_tokenization(code: list, max_statements: int, max_statement_length: int, tokenizer):
51
+ batch_input_ids = []
52
+ batch_statement_mask = []
53
+ for c in code:
54
+ source = c.split("\n")
55
+ source = [statement for statement in source if statement != ""]
56
+
57
+ source = source[:max_statements]
58
+ padding_statement = [tokenizer.pad_token_id for _ in range(20)]
59
+
60
+ input_ids = []
61
+ for stat in source:
62
+ ids_ = tokenizer.encode(str(stat),
63
+ truncation=True,
64
+ max_length=max_statement_length,
65
+ padding='max_length',
66
+ add_special_tokens=False)
67
+ input_ids.append(ids_)
68
+ if len(input_ids) < max_statements:
69
+ for _ in range(max_statements-len(input_ids)):
70
+ input_ids.append(padding_statement)
71
+ statement_mask = []
72
+ for statement in input_ids:
73
+ if statement == padding_statement:
74
+ statement_mask.append(0)
75
+ else:
76
+ statement_mask.append(1)
77
+ batch_input_ids.append(input_ids)
78
+ batch_statement_mask.append(statement_mask)
79
+ return torch.tensor(batch_input_ids), torch.tensor(batch_statement_mask)
80
+
81
+ def predict_cweid(code: list, gpu: bool = False) -> dict:
82
+ """Generate CWE-IDs and CWE Abstract Types Predictions.
83
+ Parameters
84
+ ----------
85
+ code : :obj:`list`
86
+ A list of String functions.
87
+ gpu : bool
88
+ Defines if CUDA inference is enabled
89
+ Returns
90
+ -------
91
+ :obj:`dict`
92
+ A dictionary with four keys, "cwe_id", "cwe_id_prob", "cwe_type", "cwe_type_prob"
93
+ "cwe_id" stores a list of CWE-ID predictions: [CWE-787, CWE-119, ...]
94
+ "cwe_id_prob" stores a list of confidence scores of CWE-ID predictions [0.9, 0.7, ...]
95
+ "cwe_type" stores a list of CWE abstract types predictions: ["Base", "Class", ...]
96
+ "cwe_type_prob" stores a list of confidence scores of CWE abstract types predictions [0.9, 0.7, ...]
97
+ """
98
+ provider = ["CUDAExecutionProvider", "CPUExecutionProvider"] if gpu else ["CPUExecutionProvider"]
99
+ with open("./inference-common/label_map.pkl", "rb") as f:
100
+ cwe_id_map, cwe_type_map = pickle.load(f)
101
+ # load tokenizer
102
+ tokenizer = RobertaTokenizer.from_pretrained("./inference-common/tokenizer")
103
+ tokenizer.add_tokens(["<cls_type>"])
104
+ tokenizer.cls_type_token = "<cls_type>"
105
+ model_input = []
106
+ for c in code:
107
+ code_tokens = tokenizer.tokenize(str(c))[:512 - 3]
108
+ source_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.cls_type_token] + [tokenizer.sep_token]
109
+ input_ids = tokenizer.convert_tokens_to_ids(source_tokens)
110
+ padding_length = 512 - len(input_ids)
111
+ input_ids += [tokenizer.pad_token_id] * padding_length
112
+ model_input.append(input_ids)
113
+ device = "cuda" if gpu else "cpu"
114
+ model_input = torch.tensor(model_input, device=device)
115
+ # onnx runtime session
116
+ ort_session = onnxruntime.InferenceSession("./models/cwe_model.onnx", providers=provider)
117
+ # compute ONNX Runtime output prediction
118
+ ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model_input)}
119
+ cwe_id_prob, cwe_type_prob = ort_session.run(None, ort_inputs)
120
+ # batch_cwe_id_pred (1D list with shape of [batch size]): [pred_1, pred_2, ..., pred_n]
121
+ batch_cwe_id = np.argmax(cwe_id_prob, axis=-1).tolist()
122
+ # map predicted idx back to CWE-ID
123
+ batch_cwe_id_pred = [cwe_id_map[str(idx)] for idx in batch_cwe_id]
124
+ # batch_cwe_id_pred_prob (1D list with shape of [batch_size]): [prob_1, prob_2, ..., prob_n]
125
+ batch_cwe_id_pred_prob = []
126
+ for i in range(len(cwe_id_prob)):
127
+ batch_cwe_id_pred_prob.append(cwe_id_prob[i][batch_cwe_id[i]].item())
128
+ # batch_cwe_type_pred (1D list with shape of [batch size]): [pred_1, pred_2, ..., pred_n]
129
+ batch_cwe_type = np.argmax(cwe_type_prob, axis=-1).tolist()
130
+ # map predicted idx back to CWE-Type
131
+ batch_cwe_type_pred = [cwe_type_map[str(idx)] for idx in batch_cwe_type]
132
+ # batch_cwe_type_pred_prob (1D list with shape of [batch_size]): [prob_1, prob_2, ..., prob_n]
133
+ batch_cwe_type_pred_prob = []
134
+ for i in range(len(cwe_type_prob)):
135
+ batch_cwe_type_pred_prob.append(cwe_type_prob[i][batch_cwe_type[i]].item())
136
+ return {"cwe_id": batch_cwe_id_pred,
137
+ "cwe_id_prob": batch_cwe_id_pred_prob,
138
+ "cwe_type": batch_cwe_type_pred,
139
+ "cwe_type_prob": batch_cwe_type_pred_prob}
140
+
141
+ def predict_sev(code: list, gpu: bool = False) -> dict:
142
+ """Generate CVSS severity score predictions.
143
+ Parameters
144
+ ----------
145
+ code : :obj:`list`
146
+ A list of String functions.
147
+ gpu : bool
148
+ Defines if CUDA inference is enabled
149
+ Returns
150
+ -------
151
+ :obj:`dict`
152
+ A dictionary with two keys, "batch_sev_score", "batch_sev_class"
153
+ "batch_sev_score" stores a list of severity score prediction: [1.0, 5.0, 9.0 ...]
154
+ "batch_sev_class" stores a list of severity class based on predicted severity score ["Medium", "Critical"...]
155
+ """
156
+ provider = ["CUDAExecutionProvider", "CPUExecutionProvider"] if gpu else ["CPUExecutionProvider"]
157
+ # load tokenizer
158
+ tokenizer = RobertaTokenizer.from_pretrained("./inference-common/tokenizer")
159
+ model_input = tokenizer(code, truncation=True, max_length=512, padding='max_length',
160
+ return_tensors="pt").input_ids
161
+ # onnx runtime session
162
+ ort_session = onnxruntime.InferenceSession("./models/sev_model.onnx", providers=provider)
163
+ # compute ONNX Runtime output prediction
164
+ ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model_input)}
165
+ cvss_score = ort_session.run(None, ort_inputs)
166
+ batch_sev_score = list(cvss_score[0].flatten().tolist())
167
+ batch_sev_class = []
168
+ for i in range(len(batch_sev_score)):
169
+ if batch_sev_score[i] == 0:
170
+ batch_sev_class.append("None")
171
+ elif batch_sev_score[i] < 4:
172
+ batch_sev_class.append("Low")
173
+ elif batch_sev_score[i] < 7:
174
+ batch_sev_class.append("Medium")
175
+ elif batch_sev_score[i] < 9:
176
+ batch_sev_class.append("High")
177
+ else:
178
+ batch_sev_class.append("Critical")
179
+ return {"batch_sev_score": batch_sev_score, "batch_sev_class": batch_sev_class}
180
+
181
+ def predict(code: list):
182
+ vul_preds = predict_vul_lines(code)
183
+ cwe_preds = predict_cweid(code)
184
+ sev_preds = predict_sev(code)
185
+
186
+ if __name__ == "__main__":
187
+ import pandas as pd
188
+ df = pd.read_csv("./data/processed_test.csv")
189
+ funcs = df["func_before"].tolist()
190
+ for code in funcs:
191
+ out = predict_vul_lines([code])
192
+ print(out["batch_func_pred"][0])