|
|
| import ast
|
| import random
|
| from dataclasses import dataclass, field
|
| from typing import Tuple, Optional, List, Dict
|
|
|
|
|
|
|
|
|
class ASTBugInjector(ast.NodeTransformer):
    """Inject a single synthetic bug of a given kind into a Python AST.

    The transformer walks the tree depth-first (each visitor inspects a node
    before its children) and rewrites the *first* node matching ``bug_type``.
    After that, ``modified`` is ``True`` and every visitor is a pass-through,
    so at most one bug is injected per tree.

    Supported ``bug_type`` values and the rewrite each performs:

    * ``null_check``             -- ``if cond: stmt`` -> ``stmt`` (single-statement body only)
    * ``division_by_zero_empty`` -- remove an ``if not name:`` guard, keeping its ``else`` body
    * ``simple_typo``            -- rename the first ``users`` name to ``usres``
    * ``string_index``           -- ``x[k]`` -> ``x[k + 1]`` for an integer literal ``k``
    * ``default_value``          -- ``obj.get(key)`` -> ``obj[key]``
    * ``abs_usage``              -- ``abs(x)`` -> ``x``
    * ``round_error``            -- ``round(x, n)`` -> ``round(x)``; ``round(x)`` -> ``int(x)``
    * ``empty_return``           -- insert ``return None`` at the top of the first function
    * ``off_by_one``             -- ``range(a, ...)`` -> ``range(1, a - 1)``
    * ``loop_skip``              -- ``range(n)`` -> ``range(n - 1)``
    * ``wrong_operator`` / ``sign_error`` -- swap the first ``+`` / ``-``
    * ``float_precision``        -- first ``/`` -> ``//``
    * ``swap_args``              -- swap the first two positional parameters
    * ``uninitialised_var``      -- replace the first assignment with ``pass``

    Nodes created by a rewrite carry no source locations; callers should run
    ``ast.fix_missing_locations`` before compiling the result.
    """

    def __init__(self, bug_type: str):
        super().__init__()
        # Which rewrite to perform (see the class docstring).
        self.bug_type = bug_type
        # Set after the first (and only) mutation; every visitor checks it.
        self.modified = False

    @staticmethod
    def _minus_one(expr: ast.expr) -> ast.BinOp:
        """Build the expression ``expr - 1``."""
        return ast.BinOp(left=expr, op=ast.Sub(), right=ast.Constant(value=1))

    @staticmethod
    def _range_call(*args: ast.expr) -> ast.Call:
        """Build a ``range(*args)`` call node."""
        return ast.Call(
            func=ast.Name(id="range", ctx=ast.Load()),
            args=list(args),
            keywords=[],
        )

    def visit_If(self, node: ast.If):
        """Handle ``null_check`` and ``division_by_zero_empty``."""
        if self.bug_type == "null_check" and not self.modified:
            if len(node.body) == 1:
                self.modified = True
                # Hoist the guarded statement out of its condition.
                return node.body[0]

        if self.bug_type == "division_by_zero_empty" and not self.modified:
            if (isinstance(node.test, ast.UnaryOp)
                    and isinstance(node.test.op, ast.Not)
                    and isinstance(node.test.operand, ast.Name)):
                self.modified = True
                # Drop only the emptiness guard: the ``else`` statements (if any)
                # are spliced into the enclosing block.  An empty list simply
                # deletes the ``if`` statement.
                return node.orelse
        return self.generic_visit(node)

    def visit_Name(self, node: ast.Name):
        """Handle ``simple_typo``: misspell the first ``users`` reference."""
        if self.bug_type == "simple_typo" and not self.modified:
            if node.id == "users":
                self.modified = True
                return ast.Name(id="usres", ctx=node.ctx)
        return self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript):
        """Handle ``string_index``: shift an integer-literal index by one."""
        if self.bug_type == "string_index" and not self.modified:
            # Since Python 3.9 the index expression is stored directly in
            # ``node.slice``; ``ast.Index`` wrappers no longer appear in parsed trees.
            index = node.slice
            if (isinstance(index, ast.Constant)
                    and isinstance(index.value, int)
                    and not isinstance(index.value, bool)):
                self.modified = True
                node.slice = ast.Constant(value=index.value + 1)
        return self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Handle ``default_value``, ``abs_usage`` and ``round_error``."""
        if self.bug_type == "default_value" and not self.modified:
            if (isinstance(node.func, ast.Attribute)
                    and node.func.attr == "get"
                    and len(node.args) == 1):
                self.modified = True
                # ``obj.get(key)`` -> ``obj[key]``.  ``ast.Call`` has no ``ctx``;
                # the replacement is always read, so it gets ``Load``.
                return ast.Subscript(
                    value=node.func.value,
                    slice=node.args[0],
                    ctx=ast.Load(),
                )

        if self.bug_type == "abs_usage" and not self.modified:
            if (isinstance(node.func, ast.Name)
                    and node.func.id == "abs"
                    and len(node.args) == 1):
                self.modified = True
                return node.args[0]

        if self.bug_type == "round_error" and not self.modified:
            if (isinstance(node.func, ast.Name)
                    and node.func.id == "round"
                    and node.args):
                self.modified = True
                if len(node.args) > 1 or node.keywords:
                    # round(x, n) -> round(x): requested precision is lost.
                    return ast.Call(func=node.func, args=[node.args[0]], keywords=[])
                # round(x) -> int(x): truncation instead of rounding.
                return ast.Call(
                    func=ast.Name(id="int", ctx=ast.Load()),
                    args=[node.args[0]],
                    keywords=[],
                )
        return self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Handle ``empty_return``: make the first function return immediately."""
        if self.bug_type == "empty_return" and not self.modified:
            self.modified = True
            node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
        return self.generic_visit(node)

    def visit_For(self, node: ast.For):
        """Handle ``off_by_one`` and ``loop_skip`` on ``for ... in range(...)`` loops."""
        if self.bug_type in ("off_by_one", "loop_skip") and not self.modified:
            it = node.iter
            is_range = (isinstance(it, ast.Call)
                        and isinstance(it.func, ast.Name)
                        and it.func.id == "range"
                        and len(it.args) > 0)
            if is_range:
                if self.bug_type == "off_by_one":
                    # range(n) -> range(1, n - 1).  NOTE(review): for
                    # range(start, stop[, step]) this uses *start* as the new
                    # stop and drops the remaining arguments.
                    node.iter = self._range_call(
                        ast.Constant(value=1), self._minus_one(it.args[0])
                    )
                    self.modified = True
                elif len(it.args) == 1:
                    # loop_skip: range(n) -> range(n - 1), skipping the last pass.
                    node.iter = self._range_call(self._minus_one(it.args[0]))
                    self.modified = True
        return self.generic_visit(node)

    def visit_BinOp(self, node: ast.BinOp):
        """Handle ``wrong_operator``/``sign_error`` (+ <-> -) and ``float_precision`` (/ -> //)."""
        if not self.modified:
            if self.bug_type in ("wrong_operator", "sign_error"):
                if isinstance(node.op, ast.Add):
                    node.op = ast.Sub()
                    self.modified = True
                elif isinstance(node.op, ast.Sub):
                    node.op = ast.Add()
                    self.modified = True
            elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
                node.op = ast.FloorDiv()
                self.modified = True
        return self.generic_visit(node)

    def visit_arguments(self, node: ast.arguments):
        """Handle ``swap_args``: swap the first two positional parameters."""
        if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
            self.modified = True
            node.args[0], node.args[1] = node.args[1], node.args[0]
        return self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign):
        """Handle ``uninitialised_var``: replace the first assignment with ``pass``."""
        if self.bug_type == "uninitialised_var" and not self.modified:
            self.modified = True
            return ast.Pass()
        return self.generic_visit(node)
|
|
|
|
|
|
|
|
|
# Catalogue of injectable bugs, grouped by difficulty level.
#
# Each entry is either
#   {"type": "ast", "bug_type": <ASTBugInjector mode>}  -- mutate caller-supplied code, or
#   {"type": "template", "buggy": <src>, "oracle": <src>} -- fixed buggy/fixed pair.
# Every template snippet is a complete, compilable Python module.
BUG_DB = {
    "easy": {
        "null_check": {"type": "ast", "bug_type": "null_check"},
        "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
        "string_index": {"type": "ast", "bug_type": "string_index"},
        "default_value": {"type": "ast", "bug_type": "default_value"},
        "empty_return": {"type": "ast", "bug_type": "empty_return"},
    },
    "medium": {
        "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
        "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
        "sign_error": {"type": "ast", "bug_type": "sign_error"},
        "swap_args": {"type": "ast", "bug_type": "swap_args"},
        "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
    },
    "hard": {
        "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
        # Reuses the "division_by_zero_empty" transformer under a separate id.
        "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"},
        "float_precision": {"type": "ast", "bug_type": "float_precision"},
        "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
        "round_error": {"type": "ast", "bug_type": "round_error"},
    },
    "harder": {
        "missing_lock": {
            "type": "template",
            "buggy": (
                "counter = 0\n"
                "def increment():\n"
                "    global counter\n"
                "    counter += 1"
            ),
            "oracle": (
                "counter = 0\n"
                "import threading\n"
                "lock = threading.Lock()\n"
                "def increment():\n"
                "    global counter\n"
                "    with lock:\n"
                "        counter += 1"
            ),
        },
        "double_lock": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def do_work():\n"
                "    lock.acquire()\n"
                "    lock.acquire()\n"
                "    print('working')\n"
                "    lock.release()"
            ),
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def do_work():\n"
                "    with lock:\n"
                "        print('working')"
            ),
        },
        "global_nonatomic": {
            "type": "template",
            "buggy": (
                "count = 0\n"
                "def add():\n"
                "    global count\n"
                "    count = count + 1"
            ),
            # ``count += 1`` is the same non-atomic read-modify-write as
            # ``count = count + 1``; only a lock actually fixes the race.
            "oracle": (
                "import threading\n"
                "count = 0\n"
                "lock = threading.Lock()\n"
                "def add():\n"
                "    global count\n"
                "    with lock:\n"
                "        count += 1"
            ),
        },
        "thread_safe_list": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "items = []\n"
                "def append_item(item):\n"
                "    items.append(item)"
            ),
            "oracle": (
                "import threading\n"
                "items = []\n"
                "lock = threading.Lock()\n"
                "def append_item(item):\n"
                "    with lock:\n"
                "        items.append(item)"
            ),
        },
        "volatile_read": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "stop = False\n"
                "def worker():\n"
                "    while not stop:\n"
                "        pass"
            ),
            "oracle": (
                "import threading\n"
                "stop = False\n"
                "lock = threading.Lock()\n"
                "def worker():\n"
                "    while True:\n"
                "        with lock:\n"
                "            if stop:\n"
                "                break"
            ),
        },
    },
    "hardest": {
        "deadlock_order": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock1 = threading.Lock()\n"
                "lock2 = threading.Lock()\n"
                "def thread1():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass\n"
                "def thread2():\n"
                "    with lock2:\n"
                "        with lock1:\n"
                "            pass"
            ),
            "oracle": (
                "import threading\n"
                "lock1 = threading.Lock()\n"
                "lock2 = threading.Lock()\n"
                "def thread1():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass\n"
                "def thread2():\n"
                "    with lock1:\n"
                "        with lock2:\n"
                "            pass"
            ),
        },
        "nested_lock_timeout": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def work():\n"
                "    lock.acquire()\n"
                "    # critical section\n"
                "    lock.release()"
            ),
            # The ``try`` body needs a real statement; a lone comment is a SyntaxError.
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def work():\n"
                "    if lock.acquire(timeout=1):\n"
                "        try:\n"
                "            # critical section\n"
                "            pass\n"
                "        finally:\n"
                "            lock.release()"
            ),
        },
        "fork_join": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "def worker():\n"
                "    pass\n"
                "t = threading.Thread(target=worker)\n"
                "t.start()"
            ),
            "oracle": (
                "import threading\n"
                "def worker():\n"
                "    pass\n"
                "t = threading.Thread(target=worker)\n"
                "t.start()\n"
                "t.join()"
            ),
        },
        "mutex_release": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def thread_A():\n"
                "    lock.acquire()\n"
                "    lock.release()\n"
                "def thread_B():\n"
                "    lock.release()"
            ),
            "oracle": (
                "import threading\n"
                "lock = threading.Lock()\n"
                "def thread_A():\n"
                "    with lock:\n"
                "        pass\n"
                "def thread_B():\n"
                "    with lock:\n"
                "        pass"
            ),
        },
        "race_on_init": {
            "type": "template",
            "buggy": (
                "import threading\n"
                "items = []\n"
                "def init():\n"
                "    global items\n"
                "    items = [1,2,3]\n"
                "t = threading.Thread(target=init)\n"
                "t.start()\n"
                "print(items)"
            ),
            "oracle": (
                "import threading\n"
                "items = []\n"
                "def init():\n"
                "    global items\n"
                "    items = [1,2,3]\n"
                "t = threading.Thread(target=init)\n"
                "t.start()\n"
                "t.join()\n"
                "print(items)"
            ),
        },
    },
}
|
|
|
|
|
|
|
|
|
# Difficulty level -> list of bug ids (BUG_DB keys) that may be drawn for it.
TASK_BUG_MAP = {level: list(bugs) for level, bugs in BUG_DB.items()}

# Flat lookup of template-based bugs: bug id -> (buggy_source, oracle_source).
# Built with a comprehension so no loop variables leak into the module namespace.
TEMPLATE_BUGS = {
    bug_id: (spec["buggy"], spec["oracle"])
    for bugs in BUG_DB.values()
    for bug_id, spec in bugs.items()
    if spec["type"] == "template"
}
|
|
|
|
|
|
|
|
|
@dataclass
class RedTeam:
    """Seeded bug injector that picks a bug appropriate to a difficulty level.

    Attributes:
        task: Difficulty key into ``BUG_DB`` (e.g. ``"easy"``, ``"hardest"``).
            Unknown tasks fall back to the ``"null_check"`` AST bug.
        seed: Seed for the private RNG so injections are reproducible.
        noise_prob: Probability of appending a ``# TODO: refactor later``
            comment to the returned buggy code.
    """

    task: str
    seed: Optional[int] = 42
    noise_prob: float = 0.2
    # Private RNG, created in __post_init__ from ``seed``.
    _random: random.Random = field(init=False)

    def __post_init__(self):
        self._random = random.Random(self.seed)

    def _injector_bug_type(self, bug_id: str) -> str:
        """Return the ``ASTBugInjector`` mode configured for *bug_id*.

        A BUG_DB entry's ``"bug_type"`` may differ from its key (for example
        ``division_by_zero_zero`` reuses ``division_by_zero_empty``), so the
        key must not be handed to the injector directly.  Falls back to
        *bug_id* when the task or bug has no BUG_DB entry.
        """
        spec = BUG_DB.get(self.task, {}).get(bug_id, {})
        return spec.get("bug_type", bug_id)

    def _add_noise(self, code: str) -> str:
        """With probability ``noise_prob``, append a distractor comment to *code*."""
        if self._random.random() < self.noise_prob:
            code += "\n# TODO: refactor later"
        return code

    def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
        """Inject one bug chosen at random for ``self.task``.

        Args:
            original_code: Python source to mutate.  Ignored for template
                bugs, which return a fixed buggy/oracle pair.

        Returns:
            ``(buggy_code, bug_type, description, oracle_fix)``.  ``bug_type``
            is the chosen BUG_DB key, ``"no_op"`` when nothing matching could
            be mutated, or ``"parse_error"`` when *original_code* does not parse.
        """
        bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
        bug_id = self._random.choice(bug_list)

        if bug_id in TEMPLATE_BUGS:
            buggy_code, oracle_code = TEMPLATE_BUGS[bug_id]
            description = f"Template bug: {bug_id}"
            return self._add_noise(buggy_code), bug_id, description, oracle_code

        try:
            tree = ast.parse(original_code)
        except SyntaxError:
            return original_code, "parse_error", "Syntax error in original code", original_code

        injector = ASTBugInjector(self._injector_bug_type(bug_id))
        modified_tree = injector.visit(tree)
        # Rewrites create nodes without source positions.
        ast.fix_missing_locations(modified_tree)

        if injector.modified:
            buggy_code = ast.unparse(modified_tree)
            bug_type = bug_id
            description = f"AST bug: {bug_id}"
        else:
            buggy_code = original_code
            bug_type = "no_op"
            description = "No suitable code structure found for injection"

        # The untouched input is the reference fix for every AST bug.
        return self._add_noise(buggy_code), bug_type, description, original_code
|
|
|