Nitishkumar-ai committed on
Commit
8f4e44a
·
1 Parent(s): 6398066

Feat (Phase 1 & 2): Extract scanner module and add CLI interface

Browse files
commitguard_env/cli.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import subprocess
4
+ import sys
5
+ from dataclasses import asdict
6
+ from pathlib import Path
7
+
8
+ from .scanner import CommitGuardScanner
9
+
10
+
11
+ def cmd_scan(args):
12
+ diff_text = ""
13
+ if getattr(args, "diff", None):
14
+ diff_text = Path(args.diff).read_text(encoding="utf-8")
15
+ elif getattr(args, "staged", False):
16
+ diff_text = subprocess.check_output(["git", "diff", "--staged"], text=True)
17
+ elif getattr(args, "commit", None):
18
+ diff_text = subprocess.check_output(["git", "show", args.commit], text=True)
19
+ elif getattr(args, "pr", None):
20
+ diff_text = subprocess.check_output(["gh", "pr", "diff", args.pr], text=True)
21
+ else:
22
+ print("Must specify one of --diff, --staged, --commit, or --pr")
23
+ sys.exit(1)
24
+
25
+ if not diff_text.strip():
26
+ print("No diff found to scan.")
27
+ sys.exit(0)
28
+
29
+ print(f"Loading model ({args.model})...", file=sys.stderr)
30
+ scanner = CommitGuardScanner(model_path=args.model, is_lora=args.is_lora, base_model=args.base_model)
31
+
32
+ print(f"Scanning diff ({len(diff_text)} chars)...", file=sys.stderr)
33
+ result = scanner.scan(diff_text)
34
+
35
+ if args.format == "json":
36
+ print(json.dumps(asdict(result), indent=2))
37
+ elif args.format == "text":
38
+ status = "VULNERABLE ⚠️" if result.is_vulnerable else "SAFE ✅"
39
+ print(f"\nVerdict: {status}")
40
+ if result.is_vulnerable:
41
+ print(f"CWE: {result.cwe}")
42
+ print(f"Exploit Sketch:\n {result.exploit_sketch}")
43
+ if result.parse_error:
44
+ print(f"\nParser Warning: {result.parse_error}")
45
+ elif args.format == "sarif":
46
+ # Minimal SARIF output stub
47
+ print("SARIF format not fully implemented yet.", file=sys.stderr)
48
+ print(json.dumps(asdict(result)))
49
+
50
+ if args.fail_on_vulnerable and result.is_vulnerable:
51
+ sys.exit(1)
52
+
53
+
54
+ def cmd_server(args):
55
+ from .server import main as server_main
56
+ server_main()
57
+
58
+
59
+ def cmd_eval(args):
60
+ # This is a bit hacky to reuse the script without modifying sys.path everywhere
61
+ # A cleaner approach would be moving evaluate.py into commitguard_env
62
+ REPO_ROOT = Path(__file__).resolve().parent.parent
63
+ eval_script = REPO_ROOT / "scripts" / "evaluate.py"
64
+
65
+ cmd = [sys.executable, str(eval_script)]
66
+ cmd.extend(args.eval_args)
67
+ subprocess.run(cmd, check=True)
68
+
69
+
70
+ def main():
71
+ parser = argparse.ArgumentParser(description="CommitGuard AI-paced security review")
72
+ subparsers = parser.add_subparsers(dest="command", required=True)
73
+
74
+ # 'scan' subcommand
75
+ scan_parser = subparsers.add_parser("scan", help="Scan a code diff for vulnerabilities")
76
+
77
+ source_group = scan_parser.add_mutually_exclusive_group(required=True)
78
+ source_group.add_argument("--diff", type=str, help="Path to a diff file")
79
+ source_group.add_argument("--staged", action="store_true", help="Scan git staged changes")
80
+ source_group.add_argument("--commit", type=str, help="Scan a specific git commit (e.g., HEAD)")
81
+ source_group.add_argument("--pr", type=str, help="Scan a GitHub PR URL or ID (requires gh cli)")
82
+
83
+ scan_parser.add_argument("--model", type=str, default="inmodel-labs/commitguard-llama-3b", help="Model path or HF ID")
84
+ scan_parser.add_argument("--base-model", type=str, default=None, help="Base model if using LoRA")
85
+ scan_parser.add_argument("--is-lora", action="store_true", help="Whether the model is a LoRA adapter")
86
+ scan_parser.add_argument("--format", choices=["text", "json", "sarif"], default="text", help="Output format")
87
+ scan_parser.add_argument("--fail-on-vulnerable", action="store_true", help="Exit with code 1 if vulnerable")
88
+
89
+ # 'server' subcommand
90
+ server_parser = subparsers.add_parser("server", help="Start the OpenEnv environment server")
91
+ # server_main takes PORT from environment
92
+
93
+ # 'eval' subcommand
94
+ eval_parser = subparsers.add_parser("eval", help="Run the evaluation harness")
95
+ eval_parser.add_argument("eval_args", nargs=argparse.REMAINDER, help="Arguments passed to evaluate.py")
96
+
97
+ args = parser.parse_args()
98
+
99
+ if args.command == "scan":
100
+ cmd_scan(args)
101
+ elif args.command == "server":
102
+ cmd_server(args)
103
+ elif args.command == "eval":
104
+ cmd_eval(args)
105
+
106
+ if __name__ == "__main__":
107
+ main()
commitguard_env/inference.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import sys
from pathlib import Path
from typing import Any

# Add project root to path for imports to find agent_prompt if run directly
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

# Prefer the canonical prompt from the repo-level agent_prompt module; fall
# back to an inlined copy so this module still works when installed standalone.
# NOTE(review): the two copies can drift — confirm they are kept in sync.
try:
    from agent_prompt import SYSTEM_PROMPT
except ImportError:
    # Fallback if not found
    SYSTEM_PROMPT = """You are a senior security researcher and pentester. Your task is to analyze code commits (diffs) to determine if they introduce exploitable vulnerabilities.

You operate in a multi-step environment (up to 5 steps). You can request more context, analyze your thoughts, or issue a final verdict.

### Action Format
You MUST respond with exactly ONE action per turn, wrapped in XML tags:

1. **Request Context:** Use this if you need to see the full content of a file listed in 'available_files'.
<action>
<action_type>request_context</action_type>
<file_path>filename.c</file_path>
</action>

2. **Analyze:** Use this for your internal Chain-of-Thought reasoning. Be detailed.
<action>
<action_type>analyze</action_type>
<reasoning>Your detailed step-by-step security analysis here...</reasoning>
</action>

3. **Verdict:** Use this to terminate the episode with your final judgment.
<action>
<action_type>verdict</action_type>
<is_vulnerable>true/false</is_vulnerable>
<vuln_type>CWE-XX (e.g., CWE-89)</vuln_type>
<exploit_sketch>Brief description of how this could be exploited...</exploit_sketch>
</action>

### Rules & Constraints
- If the code is safe, set is_vulnerable to false and vuln_type to NONE.
- Be specific in exploit_sketch: name the attack vector (e.g., buffer overflow via unchecked memcpy).
- Common CWE types: CWE-89 (SQLi), CWE-79 (XSS), CWE-78 (Command Inj), CWE-22 (Path Traversal), CWE-119 (Buffer Overflow), CWE-476 (Null Dereference), CWE-190 (Integer Overflow).
- You have a maximum of 5 steps per episode.
- Context requests have a small cost; be efficient.
- Verifiable rewards (RLVR) are based on the accuracy of your final verdict and the presence of correct exploit keywords.
"""
50
+
51
+
52
+ def format_prompt(diff: str, available_files: list[str] = None) -> str:
53
+ """Format the diff into the expected model prompt."""
54
+ files_str = ", ".join(available_files) if available_files else "None"
55
+
56
+ user_prompt = f"""### Input Diff
57
+ {diff}
58
+
59
+ ### Environment Info
60
+ - Available Files: {files_str}
61
+ - Current Step: 0/5
62
+
63
+ Please provide your next action in XML format:"""
64
+
65
+ return (
66
+ f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
67
+ f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
68
+ f"{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
69
+ )
70
+
71
+ def load_model(model_path: str, is_lora: bool = False, base_model: str = None) -> tuple[Any, Any]:
72
+ """
73
+ Load the LLM and tokenizer for inference.
74
+ """
75
+ import torch
76
+
77
+ if is_lora:
78
+ if not base_model:
79
+ raise ValueError("base_model is required if is_lora=True")
80
+ from unsloth import FastLanguageModel
81
+ from peft import PeftModel
82
+
83
+ model, tokenizer = FastLanguageModel.from_pretrained(
84
+ model_name=base_model,
85
+ max_seq_length=2048,
86
+ load_in_4bit=True,
87
+ )
88
+ model = PeftModel.from_pretrained(model, model_path)
89
+ FastLanguageModel.for_inference(model)
90
+ else:
91
+ from transformers import AutoModelForCausalLM, AutoTokenizer
92
+
93
+ device_map = "auto" if torch.cuda.is_available() else None
94
+ model = AutoModelForCausalLM.from_pretrained(
95
+ model_path,
96
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
97
+ device_map=device_map
98
+ )
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
100
+
101
+ return model, tokenizer
102
+
103
+ def generate(model: Any, tokenizer: Any, prompt: str, max_new_tokens: int = 256) -> str:
104
+ import torch
105
+ device = "cuda" if torch.cuda.is_available() else "cpu"
106
+
107
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
108
+
109
+ with torch.no_grad():
110
+ output = model.generate(
111
+ **inputs,
112
+ max_new_tokens=max_new_tokens,
113
+ temperature=0.1,
114
+ do_sample=False,
115
+ )
116
+
117
+ response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
118
+ return response
commitguard_env/models.py CHANGED
@@ -59,3 +59,12 @@ class DevignSample:
59
  target_file: Optional[str] = None
60
  files: Optional[dict[str, str]] = None
61
 
 
 
 
 
 
 
 
 
 
 
59
  target_file: Optional[str] = None
60
  files: Optional[dict[str, str]] = None
61
 
62
+
63
@dataclass(frozen=True, slots=True)
class ScanResult:
    """Immutable result of scanning a single diff for vulnerabilities."""

    # Final verdict: True when the model judged the diff exploitable.
    is_vulnerable: bool
    # CWE identifier reported by the model (e.g. "CWE-89"); may be None.
    cwe: Optional[str]
    # Short description of how the flaw could be exploited, when vulnerable.
    exploit_sketch: Optional[str]
    # Raw model response text, kept for debugging and auditing.
    raw_response: str
    # Set when the model output could not be parsed into a clean action.
    parse_error: Optional[str] = None
+ parse_error: Optional[str] = None
70
+
commitguard_env/scanner.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from .inference import format_prompt, generate, load_model
6
+ from .models import ScanResult
7
+ from .parse_action import parse_action
8
+
9
+
10
+ class CommitGuardScanner:
11
+ """
12
+ Scanner for CommitGuard vulnerabilities.
13
+ Keeps the model in memory to allow fast scanning of multiple diffs.
14
+ """
15
+
16
+ def __init__(self, model_path: str = "inmodel-labs/commitguard-llama-3b", is_lora: bool = False, base_model: str = None) -> None:
17
+ self.model_path = model_path
18
+ self.is_lora = is_lora
19
+ self.base_model = base_model
20
+ self.model: Any = None
21
+ self.tokenizer: Any = None
22
+
23
+ def load(self) -> None:
24
+ """Load the model and tokenizer into memory."""
25
+ if self.model is None or self.tokenizer is None:
26
+ self.model, self.tokenizer = load_model(self.model_path, self.is_lora, self.base_model)
27
+
28
+ def scan(self, diff: str, available_files: list[str] = None) -> ScanResult:
29
+ """
30
+ Scan a given diff for vulnerabilities.
31
+ """
32
+ self.load()
33
+
34
+ prompt = format_prompt(diff, available_files)
35
+ response = generate(self.model, self.tokenizer, prompt)
36
+ action = parse_action(response)
37
+
38
+ # Map to ScanResult
39
+ return ScanResult(
40
+ is_vulnerable=action.is_vulnerable if action.is_vulnerable is not None else False,
41
+ cwe=action.vuln_type,
42
+ exploit_sketch=action.exploit_sketch,
43
+ raw_response=response,
44
+ parse_error=action.parse_error
45
+ )
46
+
47
+
48
+ def scan(diff: str, model_path: str = "inmodel-labs/commitguard-llama-3b", is_lora: bool = False, base_model: str = None) -> ScanResult:
49
+ """
50
+ Convenience method to scan a single diff. Loads the model, scans, and returns the result.
51
+ If scanning multiple diffs, prefer instantiating CommitGuardScanner directly to avoid reloading the model.
52
+ """
53
+ scanner = CommitGuardScanner(model_path=model_path, is_lora=is_lora, base_model=base_model)
54
+ return scanner.scan(diff)
pyproject.toml CHANGED
@@ -33,6 +33,7 @@ train = [
33
  ]
34
 
35
  [project.scripts]
 
36
  server = "commitguard_env.server:main"
37
 
38
  [tool.setuptools]
 
33
  ]
34
 
35
  [project.scripts]
36
+ commitguard = "commitguard_env.cli:main"
37
  server = "commitguard_env.server:main"
38
 
39
  [tool.setuptools]