ajaxwin commited on
Commit
c719864
·
1 Parent(s): cf983b8

Task1 actions reviewed

Browse files

Task1 actions separated

tasks/task1/actions.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Actions for Task 1: Targeted Vulnerability Detection.
2
+ Actions & rewards:
3
+ list_functions -0.05 (broad overview of contract)
4
+ get_function_code -0.10 (wrong function) / +0.05 (correct function)
5
+ get_function_summary -0.05 (wrong function) / +0.03 (correct function)
6
+ get_file_metadata -0.04 (general contract info)
7
+ """
8
+
9
+ from typing import Any, Dict, Tuple
10
+ from env.schemas import Reward
11
+ from data.data_loader import (
12
+ list_function_names,
13
+ get_function_by_name,
14
+ list_state_variable_names,
15
+ get_state_variable_by_name,
16
+ )
17
+
18
+
19
+ def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
20
+ """Handle LIST_FUNCTIONS action."""
21
+ if ctx._is_repeated(qkey):
22
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
23
+ names = list_function_names(ctx._contract)
24
+ return (
25
+ f"Functions in {ctx._contract['contract_name']}: {', '.join(names)}",
26
+ Reward(value=-0.05, reason="list_functions cost", partial=True),
27
+ )
28
+
29
+
30
+ def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
31
+ """Handle GET_FUNCTION_CODE action."""
32
+ fn_name = params.get("function_name", "")
33
+ if ctx._is_repeated(qkey):
34
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
35
+ fn = get_function_by_name(ctx._contract, fn_name)
36
+ if fn is None:
37
+ return (
38
+ f"Function '{fn_name}' not found. Available: {list_function_names(ctx._contract)}",
39
+ Reward(value=-0.10, reason="Wrong/unknown function name", partial=True),
40
+ )
41
+ is_target = fn["name"].lower() == ctx._target_fn["name"].lower()
42
+ code = fn.get("code", "// no code available")
43
+ reward_val = 0.05 if is_target else -0.10
44
+ reason = "Fetched target function code (+)" if is_target else "Fetched non-target function (-)"
45
+ return (
46
+ f"// {fn['name']}\n{code}",
47
+ Reward(value=reward_val, reason=reason, partial=True),
48
+ )
49
+
50
+
51
+ def get_function_summary(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
52
+ """Handle GET_FUNCTION_SUMMARY action."""
53
+ fn_name = params.get("function_name", "")
54
+ if ctx._is_repeated(qkey):
55
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
56
+ fn = get_function_by_name(ctx._contract, fn_name)
57
+ if fn is None:
58
+ return (
59
+ f"Function '{fn_name}' not found.",
60
+ Reward(value=-0.05, reason="Wrong function name", partial=True),
61
+ )
62
+ is_target = fn["name"].lower() == ctx._target_fn["name"].lower()
63
+ comment = fn.get("comment", "No summary available.")
64
+ reward_val = 0.03 if is_target else -0.05
65
+ reason = "Fetched target function summary (+)" if is_target else "Fetched non-target summary (-)"
66
+ return (
67
+ f"Summary of '{fn['name']}': {comment}",
68
+ Reward(value=reward_val, reason=reason, partial=True),
69
+ )
70
+
71
+
72
+ def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
73
+ """Handle GET_FILE_METADATA action."""
74
+ if ctx._is_repeated(qkey):
75
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
76
+ meta = ctx._contract.get("metadata", {})
77
+ result = (
78
+ f"Contract: {ctx._contract['contract_name']} | "
79
+ f"Solidity: {meta.get('solidity_version', 'N/A')} | "
80
+ f"Description: {meta.get('description', 'N/A')}"
81
+ )
82
+ return result, Reward(value=-0.04, reason="get_file_metadata cost", partial=True)
83
+
84
+
85
+ def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
86
+ """Handle GET_STATE_VARIABLE action."""
87
+ var_name = params.get("variable_name", "")
88
+ if ctx._is_repeated(qkey):
89
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
90
+ if not var_name:
91
+ names = list_state_variable_names(ctx._contract)
92
+ return (
93
+ f"State variables: {', '.join(names)}",
94
+ Reward(value=-0.05, reason="Listed state variables", partial=True),
95
+ )
96
+ sv = get_state_variable_by_name(ctx._contract, var_name)
97
+ if sv is None:
98
+ return (
99
+ f"Variable '{var_name}' not found.",
100
+ Reward(value=-0.05, reason="Unknown state variable", partial=True),
101
+ )
102
+ return (
103
+ f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
104
+ Reward(value=-0.05, reason="get_state_variable cost", partial=True),
105
+ )
106
+
107
+
108
+ def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
109
+ """Handle GET_CALL_GRAPH action."""
110
+ if ctx._is_repeated(qkey):
111
+ return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
112
+ cg = ctx._contract.get("call_graph", {})
113
+ cg_str = "; ".join(f"{fn} → [{', '.join(callees)}]" for fn, callees in cg.items())
114
+ return (
115
+ f"Call graph: {cg_str}",
116
+ Reward(value=-0.08, reason="get_call_graph cost", partial=True),
117
+ )
118
+
119
+
120
+ def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
121
+ """Handle SUBMIT action."""
122
+ fn_name = params.get("function_name", "")
123
+ vuln_type = params.get("vulnerability_type", "")
124
+ if not fn_name or not vuln_type:
125
+ return (
126
+ "Submit requires 'function_name' and 'vulnerability_type' in params.",
127
+ Reward(value=-0.5, reason="Malformed submission", partial=True),
128
+ )
129
+ score = ctx._grader.grade_submission(fn_name, vuln_type)
130
+ reward_val = ctx._grader.reward_for_score(score)
131
+ ctx._done = True
132
+
133
+ if score == 1.0:
134
+ msg = (
135
+ f"✅ CORRECT! '{fn_name}' is the vulnerable function. "
136
+ f"Vulnerability type '{vuln_type}' matches. Score: 1.0"
137
+ )
138
+ elif score == 0.5:
139
+ msg = (
140
+ f"⚠️ PARTIAL. '{fn_name}' is the right function, but the vulnerability type "
141
+ f"'{vuln_type}' was not precise. Score: 0.5"
142
+ )
143
+ else:
144
+ correct = ctx._grader.get_canonical_answer()
145
+ msg = (
146
+ f"❌ INCORRECT. '{fn_name}' is not the target vulnerable function. "
147
+ f"Correct answer: {correct['function']} ({correct['vulnerability']}). Score: 0.0"
148
+ )
149
+ return msg, Reward(
150
+ value=reward_val,
151
+ reason=f"Submission score={score:.1f}",
152
+ partial=False,
153
+ )
154
+
155
+
156
+ def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple[str, Reward]:
157
+ """Fallback for unknown actions."""
158
+ return (
159
+ f"Unknown action type: {action_type}",
160
+ Reward(value=-0.10, reason="Unknown action", partial=True),
161
+ )
tasks/task1/environment.py CHANGED
@@ -10,16 +10,16 @@ Episode flow:
10
  4. When the agent submits, the Grader scores the answer and the episode ends.
11
 
12
  Reward shaping:
13
- list_functions : -0.05
14
- get_function_code : -0.10 (wrong function) / +0.05 (correct function)
15
- get_function_summary : -0.05 (wrong function) / +0.03 (correct function)
16
- get_file_metadata : -0.04
17
- get_state_variable : -0.05
18
- get_call_graph : -0.08
19
- submit (score=1.0) : +5.0
20
- submit (score=0.5) : +1.0
21
- submit (score=0.0) : -1.5
22
- repeated query : -0.40
23
  """
24
 
25
  from __future__ import annotations
@@ -46,6 +46,7 @@ from env.schemas import (
46
  StepResult,
47
  )
48
  from tasks.task1.grader import Task1Grader
 
49
 
50
  TASK_ID = "task1_vuln_detection"
51
 
@@ -188,142 +189,19 @@ class Task1Environment(BaseEnv):
188
  params = action.params
189
  qkey = self._query_key(at, params)
190
 
191
- # ---- list_functions ----------------------------------------
192
- if at == ActionType.LIST_FUNCTIONS:
193
- if self._is_repeated(qkey):
194
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
195
- names = list_function_names(self._contract)
196
- return (
197
- f"Functions in {self._contract['contract_name']}: {', '.join(names)}",
198
- Reward(value=-0.05, reason="list_functions cost", partial=True),
199
- )
200
-
201
- # ---- get_function_code -------------------------------------
202
- if at == ActionType.GET_FUNCTION_CODE:
203
- fn_name = params.get("function_name", "")
204
- if self._is_repeated(qkey):
205
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
206
- fn = get_function_by_name(self._contract, fn_name)
207
- if fn is None:
208
- return (
209
- f"Function '{fn_name}' not found. Available: {list_function_names(self._contract)}",
210
- Reward(value=-0.10, reason="Wrong/unknown function name", partial=True),
211
- )
212
- is_target = fn["name"].lower() == self._target_fn["name"].lower()
213
- code = fn.get("code", "// no code available")
214
- reward_val = 0.05 if is_target else -0.10
215
- reason = "Fetched target function code (+)" if is_target else "Fetched non-target function (-)"
216
- return (
217
- f"// {fn['name']}\n{code}",
218
- Reward(value=reward_val, reason=reason, partial=True),
219
- )
220
-
221
- # ---- get_function_summary ----------------------------------
222
- if at == ActionType.GET_FUNCTION_SUMMARY:
223
- fn_name = params.get("function_name", "")
224
- if self._is_repeated(qkey):
225
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
226
- fn = get_function_by_name(self._contract, fn_name)
227
- if fn is None:
228
- return (
229
- f"Function '{fn_name}' not found.",
230
- Reward(value=-0.05, reason="Wrong function name", partial=True),
231
- )
232
- is_target = fn["name"].lower() == self._target_fn["name"].lower()
233
- comment = fn.get("comment", "No summary available.")
234
- reward_val = 0.03 if is_target else -0.05
235
- reason = "Fetched target function summary (+)" if is_target else "Fetched non-target summary (-)"
236
- return (
237
- f"Summary of '{fn['name']}': {comment}",
238
- Reward(value=reward_val, reason=reason, partial=True),
239
- )
240
-
241
- # ---- get_file_metadata -------------------------------------
242
- if at == ActionType.GET_FILE_METADATA:
243
- if self._is_repeated(qkey):
244
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
245
- meta = self._contract.get("metadata", {})
246
- result = (
247
- f"Contract: {self._contract['contract_name']} | "
248
- f"File: {self._contract.get('file_name', 'N/A')} | "
249
- f"Solidity: {meta.get('solidity_version', 'N/A')} | "
250
- f"License: {meta.get('license', 'N/A')} | "
251
- f"Author: {meta.get('author', 'N/A')} | "
252
- f"Description: {meta.get('description', 'N/A')}"
253
- )
254
- return result, Reward(value=-0.04, reason="get_file_metadata cost", partial=True)
255
-
256
- # ---- get_state_variable ------------------------------------
257
- if at == ActionType.GET_STATE_VARIABLE:
258
- var_name = params.get("variable_name", "")
259
- if self._is_repeated(qkey):
260
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
261
- if not var_name:
262
- # Return list of all state variables
263
- names = list_state_variable_names(self._contract)
264
- return (
265
- f"State variables: {', '.join(names)}",
266
- Reward(value=-0.05, reason="Listed state variables", partial=True),
267
- )
268
- sv = get_state_variable_by_name(self._contract, var_name)
269
- if sv is None:
270
- return (
271
- f"Variable '{var_name}' not found.",
272
- Reward(value=-0.05, reason="Unknown state variable", partial=True),
273
- )
274
- return (
275
- f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
276
- Reward(value=-0.05, reason="get_state_variable cost", partial=True),
277
- )
278
-
279
- # ---- get_call_graph ----------------------------------------
280
- if at == ActionType.GET_CALL_GRAPH:
281
- if self._is_repeated(qkey):
282
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
283
- cg = self._contract.get("call_graph", {})
284
- cg_str = "; ".join(f"{fn} → [{', '.join(callees)}]" for fn, callees in cg.items())
285
- return (
286
- f"Call graph: {cg_str}",
287
- Reward(value=-0.08, reason="get_call_graph cost", partial=True),
288
- )
289
-
290
- # ---- submit ------------------------------------------------
291
- if at == ActionType.SUBMIT:
292
- fn_name = params.get("function_name", "")
293
- vuln_type = params.get("vulnerability_type", "")
294
- if not fn_name or not vuln_type:
295
- return (
296
- "Submit requires 'function_name' and 'vulnerability_type' in params.",
297
- Reward(value=-0.5, reason="Malformed submission", partial=True),
298
- )
299
- score = self._grader.grade_submission(fn_name, vuln_type) # type: ignore
300
- reward_val = self._grader.reward_for_score(score) # type: ignore
301
- self._done = True
302
-
303
- if score == 1.0:
304
- msg = (
305
- f"✅ CORRECT! '{fn_name}' is the vulnerable function. "
306
- f"Vulnerability type '{vuln_type}' matches. Score: 1.0"
307
- )
308
- elif score == 0.5:
309
- msg = (
310
- f"⚠️ PARTIAL. '{fn_name}' is the right function, but the vulnerability type "
311
- f"'{vuln_type}' was not precise. Score: 0.5"
312
- )
313
- else:
314
- correct = self._grader.get_canonical_answer() # type: ignore
315
- msg = (
316
- f"❌ INCORRECT. '{fn_name}' is not the target vulnerable function. "
317
- f"Correct answer: {correct['function']} ({correct['vulnerability']}). Score: 0.0"
318
- )
319
- return msg, Reward(
320
- value=reward_val,
321
- reason=f"Submission score={score:.1f}",
322
- partial=False,
323
- )
324
-
325
- # ---- unknown action ----------------------------------------
326
- return (
327
- f"Unknown action type: {at}",
328
- Reward(value=-0.10, reason="Unknown action", partial=True),
329
- )
 
10
  4. When the agent submits, the Grader scores the answer and the episode ends.
11
 
12
  Reward shaping:
13
+ list_functions : -0.05
14
+ get_function_code : -0.10 (wrong function) / +0.05 (correct function)
15
+ get_function_summary : -0.05 (wrong function) / +0.03 (correct function)
16
+ get_file_metadata : -0.04
17
+ get_state_variable : -0.05
18
+ get_call_graph : -0.08
19
+ correct submit (score=1.0) : +5.0
20
+ partially correct submit (score=0.5) : +1.0
21
+ wrong submit (score=0.0) : -1.5
22
+ repeated query : -0.40
23
  """
24
 
25
  from __future__ import annotations
 
46
  StepResult,
47
  )
48
  from tasks.task1.grader import Task1Grader
49
+ from tasks.task1 import actions
50
 
51
  TASK_ID = "task1_vuln_detection"
52
 
 
189
  params = action.params
190
  qkey = self._query_key(at, params)
191
 
192
+ # Mapping from ActionType to handler function
193
+ handlers = {
194
+ ActionType.LIST_FUNCTIONS: actions.list_functions,
195
+ ActionType.GET_FUNCTION_CODE: actions.get_function_code,
196
+ ActionType.GET_FUNCTION_SUMMARY: actions.get_function_summary,
197
+ ActionType.GET_FILE_METADATA: actions.get_file_metadata,
198
+ ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
199
+ ActionType.GET_CALL_GRAPH: actions.get_call_graph,
200
+ ActionType.SUBMIT: actions.submit,
201
+ }
202
+
203
+ handler = handlers.get(at)
204
+ if handler is None:
205
+ return actions.unknown_action(self, qkey, params, at)
206
+
207
+ return handler(self, qkey, params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tasks/task2/environment.py CHANGED
@@ -27,7 +27,6 @@ from __future__ import annotations
27
 
28
  import random
29
  from typing import Any, Dict, List, Optional, Set
30
- import actions
31
 
32
  from data.data_loader import load_contracts, sample_property_episode
33
  from env.base_env import BaseEnv
@@ -41,6 +40,7 @@ from env.schemas import (
41
  StepResult,
42
  )
43
  from tasks.task2.grader import Task2Grader
 
44
 
45
  TASK_ID = "task2_property_discovery"
46
  MAX_STEPS = 15
 
27
 
28
  import random
29
  from typing import Any, Dict, List, Optional, Set
 
30
 
31
  from data.data_loader import load_contracts, sample_property_episode
32
  from env.base_env import BaseEnv
 
40
  StepResult,
41
  )
42
  from tasks.task2.grader import Task2Grader
43
+ from tasks.task2 import actions
44
 
45
  TASK_ID = "task2_property_discovery"
46
  MAX_STEPS = 15