ajaxwin
refactor: Update ActionType to include costs and modified grader for task 1
5235476
"""Actions for Task 1: Targeted Vulnerability Detection.
"""
from typing import Any, Dict, Tuple
from env.schemas import ActionType, Reward
from data.data_loader import (
list_function_names,
get_function_by_name,
list_state_variable_names,
get_state_variable_by_name,
)
def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Answer a LIST_FUNCTIONS query with the contract's function names.

    A query whose key was already seen short-circuits with the REPEATED
    cost. Otherwise the function names of ``ctx._contract`` are returned as
    one comma-separated line, charged at the LIST_FUNCTIONS cost.
    """
    if not ctx._is_repeated(qkey):
        contract = ctx._contract
        joined = ", ".join(list_function_names(contract))
        message = f"Functions in {contract['contract_name']}: {joined}"
        return message, Reward(
            value=ActionType.LIST_FUNCTIONS.cost,
            reason="list_functions cost",
            partial=True,
        )
    return "Repeated query.", Reward(
        value=ActionType.REPEATED.cost,
        reason="Repeated query",
        partial=True,
    )
def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Answer a GET_FUNCTION_CODE query for ``params["function_name"]``.

    Repeated queries are charged the REPEATED cost. An unknown name returns
    a "not found" message that includes the list of available function
    names. A known name returns the function's code prefixed by a
    ``// <name>`` header line. Both of those cases are charged the
    GET_FUNCTION_CODE cost.
    """
    requested = params.get("function_name", "")
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(
            value=ActionType.REPEATED.cost,
            reason="Repeated query",
            partial=True,
        )

    entry = get_function_by_name(ctx._contract, requested)
    cost = ActionType.GET_FUNCTION_CODE.cost
    if entry is not None:
        body = entry.get("code", "// no code available")
        return f"// {entry['name']}\n{body}", Reward(
            value=cost,
            reason="Fetched code",
            partial=True,
        )

    # The available names are interpolated as a Python list repr on purpose.
    available = list_function_names(ctx._contract)
    return f"Function '{requested}' not found. Available: {available}", Reward(
        value=cost,
        reason="Wrong/unknown function name",
        partial=True,
    )
def get_function_summary(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Answer a GET_FUNCTION_SUMMARY query for ``params["function_name"]``.

    Repeated queries are charged the REPEATED cost. Otherwise the function's
    ``comment`` field is returned, falling back to a placeholder when it is
    absent. An unknown name yields a short "not found" message. Both
    non-repeated outcomes are charged the GET_FUNCTION_SUMMARY cost.
    """
    requested = params.get("function_name", "")
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(
            value=ActionType.REPEATED.cost,
            reason="Repeated query",
            partial=True,
        )

    entry = get_function_by_name(ctx._contract, requested)
    cost = ActionType.GET_FUNCTION_SUMMARY.cost
    if entry is None:
        return f"Function '{requested}' not found.", Reward(
            value=cost,
            reason="Wrong function name",
            partial=True,
        )

    summary = entry.get("comment", "No summary available.")
    return f"Summary of '{entry['name']}': {summary}", Reward(
        value=cost,
        reason="Fetched summary",
        partial=True,
    )
def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Answer a GET_FILE_METADATA query with a one-line contract description.

    The line is built from the contract name plus the ``solidity_version``
    and ``description`` entries of the contract's ``metadata`` dict. Missing
    entries are shown as ``N/A``. Repeated queries are charged the REPEATED
    cost instead.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(
            value=ActionType.REPEATED.cost,
            reason="Repeated query",
            partial=True,
        )

    contract = ctx._contract
    info = contract.get("metadata", {})
    fields = [
        f"Contract: {contract['contract_name']}",
        f"Solidity: {info.get('solidity_version', 'N/A')}",
        f"Description: {info.get('description', 'N/A')}",
    ]
    summary_line = " | ".join(fields)
    return summary_line, Reward(
        value=ActionType.GET_FILE_METADATA.cost,
        reason="get_file_metadata cost",
        partial=True,
    )
def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Answer a GET_STATE_VARIABLE query.

    Repeated queries are charged the REPEATED cost. With no (or a falsy)
    ``variable_name`` the names of all state variables are listed.
    Otherwise the named variable is described as
    ``<type> <visibility> <name>: <description>``, or a "not found" message
    is returned. Every non-repeated outcome is charged the
    GET_STATE_VARIABLE cost.
    """
    wanted = params.get("variable_name", "")
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(
            value=ActionType.REPEATED.cost,
            reason="Repeated query",
            partial=True,
        )

    contract = ctx._contract
    if not wanted:
        listing = ", ".join(list_state_variable_names(contract))
        return f"State variables: {listing}", Reward(
            value=ActionType.GET_STATE_VARIABLE.cost,
            reason="Listed state variables",
            partial=True,
        )

    var = get_state_variable_by_name(contract, wanted)
    if var is None:
        return f"Variable '{wanted}' not found.", Reward(
            value=ActionType.GET_STATE_VARIABLE.cost,
            reason="Unknown state variable",
            partial=True,
        )

    described = f"{var['type']} {var['visibility']} {var['name']}: {var.get('description', '')}"
    return described, Reward(
        value=ActionType.GET_STATE_VARIABLE.cost,
        reason="get_state_variable cost",
        partial=True,
    )
def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Answer a GET_CALL_GRAPH query.

    The contract's ``call_graph`` mapping (caller -> iterable of callees) is
    rendered as ``caller → [a, b]`` entries separated by ``; ``. A missing
    graph renders as an empty string. Repeated queries are charged the
    REPEATED cost instead.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(
            value=ActionType.REPEATED.cost,
            reason="Repeated query",
            partial=True,
        )

    graph = ctx._contract.get("call_graph", {})
    edges = []
    for caller, callees in graph.items():
        edges.append(f"{caller} → [{', '.join(callees)}]")
    rendered = "; ".join(edges)
    return f"Call graph: {rendered}", Reward(
        value=ActionType.GET_CALL_GRAPH.cost,
        reason="get_call_graph cost",
        partial=True,
    )
def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle SUBMIT action for Task 1.

    Only one graded submission is accepted per episode. Once ``ctx._done``
    is set, further calls are charged the RESUBMIT cost. A submission that
    is missing either field (or supplies a non-string / None value) is
    rejected with a zero reward. It does not consume the single submission.

    Expected params
    ---------------
    function_name : str – name of the vulnerable function
    vulnerability_type: str – short description of the vulnerability

    Returns
    -------
    (observation, reward)
        On a well-formed submission, the observation reveals the grader's
        canonical answer. The reward value is the score returned by
        ``ctx._grader.grade``.
    """
    if ctx._done:
        return (
            "Only ONE submission is allowed.",
            Reward(value=ActionType.RESUBMIT.cost,
                   reason="Second submit_function attempt",
                   partial=False),
        )

    # Treat explicit None / non-string values as missing instead of crashing
    # on `.strip()`; such submissions fall through to the malformed branch.
    raw_fn = params.get("function_name")
    raw_vuln = params.get("vulnerability_type")
    fn_name = raw_fn.strip() if isinstance(raw_fn, str) else ""
    vuln_type = raw_vuln.strip() if isinstance(raw_vuln, str) else ""
    if not fn_name or not vuln_type:
        return (
            "submit_function requires both 'function_name' and "
            "'vulnerability_type' in params.",
            Reward(value=0.0, reason="Malformed submission", partial=False),
        )

    # Mark the episode finished before grading so no second attempt is
    # accepted.
    ctx._done = True
    score = ctx._grader.grade(fn_name, vuln_type, ctx._step_count, ctx._cummulative_cost)

    # `get_canonical_answer` is presumably a method. Call it so the
    # observation shows the answer rather than a bound-method repr. The
    # callable() check keeps this working if it is exposed as a property.
    canonical = ctx._grader.get_canonical_answer
    if callable(canonical):
        canonical = canonical()
    return (
        f"Correct Answer: {canonical}",
        Reward(
            value=score,
            reason=f"submit_function score={score:.1f}",
            partial=False,
        ),
    )
def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple[str, Reward]:
    """Respond to an action type that no handler recognises.

    The offending ``action_type`` is echoed back and the UNKNOWN cost is
    charged. ``ctx``, ``qkey`` and ``params`` are accepted only to match the
    common handler signature and are not used.
    """
    message = f"Unknown action type: {action_type}"
    penalty = Reward(value=ActionType.UNKNOWN.cost, reason="Unknown action", partial=True)
    return message, penalty