| """Tests for the role helpers (drift generator + repair agent).""" |
| import json |
|
|
| from forgeenv.env.diff_utils import apply_unified_diff |
| from forgeenv.primitives.breakage_primitives import RenameApiCall |
| from forgeenv.roles.drift_generator import ( |
| BaselineDriftGenerator, |
| parse_drift_output, |
| parse_drift_to_primitive, |
| ) |
| from forgeenv.roles.prompts import ( |
| DRIFT_GENERATOR_SYSTEM_PROMPT, |
| REPAIR_AGENT_SYSTEM_PROMPT, |
| render_drift_generator_prompt, |
| render_repair_agent_prompt, |
| ) |
| from forgeenv.roles.repair_agent import ( |
| BaselineRepairAgent, |
| extract_diff, |
| looks_like_diff, |
| ) |
|
|
|
|
| |
def test_prompts_are_nonempty():
    """Each system prompt names the role it is written for."""
    role_markers = (
        ("Drift Generator", DRIFT_GENERATOR_SYSTEM_PROMPT),
        ("Repair Agent", REPAIR_AGENT_SYSTEM_PROMPT),
    )
    for marker, prompt in role_markers:
        assert marker in prompt
|
|
|
|
def test_render_drift_generator_prompt_includes_inputs():
    """The rendered drift-generator prompt echoes every value passed in."""
    source_snippet = "import torch"
    primitive_name = "RenameApiCall"
    pinned_versions = {"transformers": "4.40.0"}

    rendered = render_drift_generator_prompt(
        source_snippet, primitive_name, pinned_versions
    )

    # One assertion per fragment so a failure pinpoints the missing piece.
    for fragment in ("RenameApiCall", "transformers=4.40.0", "import torch"):
        assert fragment in rendered
|
|
|
|
def test_render_repair_agent_prompt_includes_error_trace():
    """The rendered repair-agent prompt carries the error text and version pin."""
    pinned_versions = {"transformers": "4.50.0"}
    rendered = render_repair_agent_prompt(
        "broken", "AttributeError: foo", pinned_versions
    )

    assert "AttributeError" in rendered
    assert "transformers=4.50.0" in rendered
|
|
|
|
| |
def test_parse_drift_output_handles_fences():
    """A JSON spec wrapped in a Markdown ```json fence is still parsed."""
    fenced_reply = (
        "```json\n"
        '{"primitive_type": "RenameApiCall", '
        '"params": {"old_name": "a", "new_name": "b"}}\n'
        "```"
    )

    result = parse_drift_output(fenced_reply)

    assert result is not None
    assert result["primitive_type"] == "RenameApiCall"
|
|
|
|
def test_parse_drift_output_handles_prose():
    """A JSON spec surrounded by free-form prose is still extracted."""
    text = (
        "Here is my breakage idea, it's a rename:\n"
        "{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"x\", \"new_name\": \"y\"}}\n"
        "Hope this works!"
    )
    parsed = parse_drift_output(text)
    # parse_drift_output returns None when it finds no JSON; check that first
    # so a parser regression fails as an assertion, not a TypeError on None.
    assert parsed is not None
    assert parsed["primitive_type"] == "RenameApiCall"
|
|
|
|
def test_parse_drift_output_returns_none_on_garbage():
    """Inputs with no JSON object (including the empty string) yield None."""
    for unparseable in ("no JSON here at all", ""):
        assert parse_drift_output(unparseable) is None
|
|
|
|
def test_parse_drift_to_primitive_validates():
    """A well-formed DeprecateImport spec is turned into a primitive object."""
    raw_spec = (
        '{"primitive_type": "DeprecateImport", '
        '"params": {"old_module": "a", "new_module": "b"}}'
    )

    built = parse_drift_to_primitive(raw_spec)

    assert built is not None
    assert built.name == "DeprecateImport"
|
|
|
|
def test_parse_drift_to_primitive_unknown_type():
    """A spec naming an unrecognised primitive type produces None."""
    unknown_spec = '{"primitive_type": "NonExistent", "params": {}}'
    outcome = parse_drift_to_primitive(unknown_spec)
    assert outcome is None
|
|
|
|
def test_baseline_drift_generator_produces_valid_spec():
    """The baseline generator proposes a known primitive type with dict params."""
    known_primitive_types = {
        "RenameApiCall",
        "DeprecateImport",
        "ChangeArgumentSignature",
        "ModifyConfigField",
        "RestructureDatasetSchema",
        "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    }
    script = (
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )

    generator = BaselineDriftGenerator(seed=0)
    proposal = generator.propose(target_category="RenameApiCall", script=script)

    assert proposal["primitive_type"] in known_primitive_types
    assert isinstance(proposal["params"], dict)
|
|
|
|
def test_baseline_drift_generator_spec_actually_breaks_script():
    """A proposed spec round-trips through the parser and can alter the script.

    The spec from the baseline generator is serialised to JSON, parsed back
    into a primitive, and applied to the script.
    """
    gen = BaselineDriftGenerator(seed=42)
    script = (
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )
    spec = gen.propose(target_category="RenameApiCall", script=script)
    primitive = parse_drift_to_primitive(json.dumps(spec))
    # parse_drift_to_primitive returns None for specs it rejects; assert here
    # so that case fails clearly instead of as AttributeError on None.apply.
    assert primitive is not None
    broken = primitive.apply(script)

    # Only a rename whose old name occurs in the script is required to change
    # it. NOTE(review): for any other proposal this check is skipped and the
    # test passes vacuously.
    if spec["primitive_type"] == "RenameApiCall" and spec["params"].get("old_name") in script:
        assert broken != script
|
|
|
|
| |
def test_extract_diff_strips_fences():
    """A diff inside a ```diff fence is returned starting at its '---' header."""
    reply = (
        "Here's my fix:\n"
        "```diff\n"
        "--- a/x\n"
        "+++ b/x\n"
        "@@\n"
        "-foo\n"
        "+bar\n"
        "```\n"
    )

    extracted = extract_diff(reply)

    assert extracted.startswith("---")
    assert "foo" in extracted
    assert "bar" in extracted
|
|
|
|
def test_extract_diff_strips_chain_of_thought():
    """Reasoning text ahead of the diff is dropped from the extracted result."""
    preamble = (
        "Let me think... the error is X, so I should rename Y to Z.\n"
        "Here is the diff:\n"
    )
    patch_text = (
        "--- a/train.py\n"
        "+++ b/train.py\n"
        "@@ -1 +1 @@\n"
        "-import torch\n"
        "+import torch.legacy\n"
    )

    extracted = extract_diff(preamble + patch_text)

    assert extracted.startswith("---")
    assert "Let me think" not in extracted
|
|
|
|
def test_looks_like_diff_positive():
    """Text with unified-diff headers and a hunk is recognised as a diff."""
    candidate = (
        "--- a/x\n"
        "+++ b/x\n"
        "@@ -1 +1 @@\n"
        "-foo\n"
        "+bar\n"
    )
    verdict = looks_like_diff(candidate)
    assert verdict
|
|
|
|
def test_looks_like_diff_negative():
    """Plain prose is not mistaken for a diff."""
    prose = "just some prose without any diff structure"
    verdict = looks_like_diff(prose)
    assert not verdict
|
|
|
|
def test_baseline_repair_agent_oracle_path():
    """Given the original script, the agent's diff restores it exactly."""
    pristine = "import torch\nprint('hi')\n"
    damaged = "import torch.legacy\nprint('hi')\n"

    patch = BaselineRepairAgent().repair(
        damaged, breakage_spec=None, original_script=pristine
    )

    assert patch
    assert "torch.legacy" in patch
    # Applying the returned diff to the broken script must give back the original.
    assert apply_unified_diff(damaged, patch) == pristine
|
|
|
|
def test_baseline_repair_agent_inverts_breakage_spec():
    """With only the breakage spec, the agent's diff brings back the old call."""
    pristine = "from transformers import Trainer\ntrainer.train()\n"
    rename = RenameApiCall(
        old_name="trainer.train", new_name="trainer.start_training"
    )
    damaged = rename.apply(pristine)

    # No original script is passed; the agent gets only the spec.
    patch = BaselineRepairAgent().repair(damaged, breakage_spec=rename.to_spec())

    assert patch
    restored = apply_unified_diff(damaged, patch)
    assert "trainer.train()" in restored
|
|