VD10 committed on
Commit
60cec59
·
verified ·
1 Parent(s): 0fc6b71

Upload patchjudge/feature_extractor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. patchjudge/feature_extractor.py +444 -0
patchjudge/feature_extractor.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Feature extractor for PatchJudge.
2
+
3
+ Extracts structured features from code patches using:
4
+ - Unified diff parsing
5
+ - AST analysis (Python)
6
+ - Keyword/entity extraction for issue-patch alignment
7
+ - Code quality signal detection
8
+ """
9
+
10
+ import ast
11
+ import re
12
+ import textwrap
13
+ from collections import Counter
14
+ from typing import Optional
15
+
16
+ from patchjudge.models import PatchExample, PatchFeatures
17
+
18
+
19
class FeatureExtractor:
    """Extracts structured features from a patch for LLM evaluation.

    All analysis is heuristic and text-based: unified-diff parsing plus
    regex scans of added/removed lines and keyword matching between the
    issue text and the patch.  No external tools are invoked (despite the
    module importing :mod:`ast`, no AST parsing is performed here).
    """

    # Core/infrastructure files that are risky to modify.
    CORE_FILE_PATTERNS = [
        r'__init__\.py$',
        r'settings\.py$',
        r'conf\.py$',
        r'config\.py$',
        r'setup\.py$',
        r'setup\.cfg$',
        r'manage\.py$',
        r'urls\.py$',
        r'wsgi\.py$',
        r'asgi\.py$',
        r'migrations/',
        r'base\.py$',
    ]

    # Patterns that indicate hardcoded values.
    HARDCODE_PATTERNS = [
        r'return\s+["\'].*["\']',   # return "hardcoded string"
        r'return\s+\d+\b',          # return 42
        r'==\s*["\'][^"\']{20,}',   # comparing with long hardcoded string
        r'if\s+.*==\s*\d{3,}',      # comparing with specific large numbers
    ]

    # Debug statement patterns (Python and JavaScript).
    DEBUG_PATTERNS = [
        r'\bprint\s*\(',
        r'\bpdb\b',
        r'\bbreakpoint\s*\(',
        r'\bIPython\b',
        r'\bipdb\b',
        r'\bconsole\.log\b',
        r'\bdebugger\b',
    ]

    # TODO/FIXME marker patterns.
    TODO_PATTERNS = [
        r'\bTODO\b',
        r'\bFIXME\b',
        r'\bHACK\b',
        r'\bXXX\b',
        r'\bTEMP\b',
        r'\bWORKAROUND\b',
    ]

    # Pre-compiled once: captures the target ("b/") path of each per-file
    # header in a git unified diff.  Shared by all file-list helpers.
    _FILE_HEADER_RE = re.compile(r'^diff --git a/.+ b/(.+)$', re.MULTILINE)

    def extract(self, example: "PatchExample") -> "PatchFeatures":
        """Extract all features from *example*'s agent patch.

        Returns a populated ``PatchFeatures`` instance.  Code-structure
        features are only computed when the patch touches ``.py`` files.
        """
        patch = example.agent_patch
        features = PatchFeatures()

        # --- Diff statistics ---
        added_lines, removed_lines = self._parse_diff_lines(patch)
        # Parse the changed-file list once and reuse it below (the original
        # code parsed the diff headers twice).
        changed_files = self._get_changed_files(patch)
        features.num_files_changed = len(changed_files)
        features.num_lines_added = len(added_lines)
        features.num_lines_removed = len(removed_lines)
        features.num_hunks = self._count_hunks(patch)

        # --- Code structure (regex heuristics, Python files only) ---
        if self._is_python_patch(patch):
            features.added_functions = self._extract_added_functions(added_lines)
            features.modified_functions = self._extract_modified_functions(
                patch, example.repo_context
            )
            features.has_error_handling = self._check_error_handling(added_lines)
            features.has_edge_case_handling = self._check_edge_cases(added_lines)
            features.cyclomatic_complexity_delta = self._estimate_complexity_delta(
                added_lines, removed_lines
            )
            features.nesting_depth_max = self._estimate_max_nesting(added_lines)

        # --- Issue-patch alignment ---
        issue_keywords = self._extract_issue_keywords(example.problem_statement)
        patch_text = '\n'.join(added_lines + removed_lines)
        addressed = self._match_keywords(issue_keywords, patch_text)
        features.issue_keywords_addressed = addressed
        features.issue_components_mentioned = self._extract_components(
            example.problem_statement
        )
        if issue_keywords:  # guard against ZeroDivisionError
            features.keyword_coverage_ratio = len(addressed) / len(issue_keywords)

        # --- Code quality signals (added lines only) ---
        added_text = '\n'.join(added_lines)
        features.has_todos = self._check_patterns(added_text, self.TODO_PATTERNS)
        features.has_hardcoded_values = self._check_patterns(added_text, self.HARDCODE_PATTERNS)
        features.has_debug_statements = self._check_patterns(added_text, self.DEBUG_PATTERNS)
        features.style_violations = self._check_style_basic(added_lines)
        features.follows_project_style = len(features.style_violations) == 0

        # --- Risk signals ---
        features.modifies_core_files = self._check_core_files(changed_files)
        features.change_scope = self._assess_scope(
            features.num_files_changed,
            features.num_lines_added + features.num_lines_removed
        )
        new_imports = self._extract_new_imports(added_lines)
        features.has_imports_added = len(new_imports) > 0
        features.new_imports = new_imports
        # NOTE(review): substring match also hits names like "contest.py";
        # kept as-is since the feature is intentionally a coarse signal.
        features.touches_tests = any(
            'test' in f.lower() for f in changed_files
        )

        return features

    # =========================================================================
    # Diff parsing
    # =========================================================================

    def _parse_diff_lines(self, diff: str) -> tuple[list[str], list[str]]:
        """Parse a unified diff into (added, removed) lines, prefixes stripped.

        File-header lines (``+++`` / ``---``) are excluded.
        """
        added: list[str] = []
        removed: list[str] = []
        for line in diff.split('\n'):
            if line.startswith('+') and not line.startswith('+++'):
                added.append(line[1:])
            elif line.startswith('-') and not line.startswith('---'):
                removed.append(line[1:])
        return added, removed

    def _count_files_changed(self, diff: str) -> int:
        """Count distinct files changed in the diff."""
        # Delegates to _get_changed_files so both methods share one parser.
        return len(self._get_changed_files(diff))

    def _count_hunks(self, diff: str) -> int:
        """Count number of hunks (``@@`` markers) in the diff."""
        return len(re.findall(r'^@@\s', diff, re.MULTILINE))

    def _get_changed_files(self, diff: str) -> list[str]:
        """Return the deduplicated list of changed file paths (``b/`` side)."""
        return list({m.group(1) for m in self._FILE_HEADER_RE.finditer(diff)})

    def _is_python_patch(self, diff: str) -> bool:
        """Return True if the patch modifies at least one ``.py`` file."""
        return any(f.endswith('.py') for f in self._get_changed_files(diff))

    # =========================================================================
    # Code structure analysis
    # =========================================================================

    def _extract_added_functions(self, added_lines: list[str]) -> list[str]:
        """Find function/method and class definitions in added lines.

        Classes are reported as ``class:<Name>``; order of appearance is
        preserved and duplicates are kept.
        """
        functions = []
        for line in added_lines:
            match = re.match(r'\s*def\s+(\w+)\s*\(', line)
            if match:
                functions.append(match.group(1))
            # Also check for class definitions
            match = re.match(r'\s*class\s+(\w+)', line)
            if match:
                functions.append(f"class:{match.group(1)}")
        return functions

    def _extract_modified_functions(
        self, diff: str, repo_context: dict
    ) -> list[str]:
        """Find functions that were modified (existed before, changed now).

        Relies on git hunk headers, which usually carry the enclosing
        ``def``/``class`` line as trailing context.  *repo_context* is
        accepted for interface stability but currently unused.
        """
        modified = []
        for match in re.finditer(
            r'^@@\s+.*\s+@@\s*(.*)$', diff, re.MULTILINE
        ):
            context = match.group(1).strip()
            func_match = re.match(r'def\s+(\w+)', context)
            if func_match:
                modified.append(func_match.group(1))
            class_match = re.match(r'class\s+(\w+)', context)
            if class_match:
                modified.append(f"class:{class_match.group(1)}")
        # Deduplicate; note set() does not preserve order.
        return list(set(modified))

    def _check_error_handling(self, added_lines: list[str]) -> bool:
        """Return True if added code shows any error/exception handling."""
        text = '\n'.join(added_lines)
        patterns = [
            r'\btry\s*:',
            r'\bexcept\b',
            r'\braise\b',
            r'\bValueError\b',
            r'\bTypeError\b',
            r'\bKeyError\b',
            r'\bAssertionError\b',
            r'\bRuntimeError\b',
            r'\bif\s+.*\bis\s+None\b',
            r'\bif\s+not\b',
        ]
        return any(re.search(p, text) for p in patterns)

    def _check_edge_cases(self, added_lines: list[str]) -> bool:
        """Return True if at least two distinct edge-case idioms appear."""
        text = '\n'.join(added_lines)
        patterns = [
            r'\bif\s+len\(',          # Length checks
            r'\bif\s+not\s+\w+\s*:',  # Empty checks
            r'\bif\s+\w+\s+is\s+None',  # None checks
            r'\bif\s+.*<=?\s*0',      # Zero/negative checks
            r'\bif\s+isinstance\(',   # Type checks
            r'\bif\s+hasattr\(',      # Attribute checks
            r'\bor\s+\[\]',           # Default empty list
            r'\bor\s+\{\}',           # Default empty dict
            r'\bor\s+""',             # Default empty string
            r'\.get\(',               # Dict .get() with default
        ]
        # Require two distinct pattern families to reduce false positives.
        return sum(1 for p in patterns if re.search(p, text)) >= 2

    def _estimate_complexity_delta(
        self, added_lines: list[str], removed_lines: list[str]
    ) -> int:
        """Estimate the change in cyclomatic complexity.

        Counts lines containing branching/boolean keywords; each line
        contributes at most 1 regardless of how many keywords it holds.
        """
        complexity_keywords = [
            'if', 'elif', 'else', 'for', 'while', 'try', 'except',
            'and', 'or', 'with', 'assert'
        ]

        def count_complexity(lines):
            count = 0
            for line in lines:
                stripped = line.strip()
                for kw in complexity_keywords:
                    if re.search(rf'\b{kw}\b', stripped):
                        count += 1
                        break  # Count each line only once
            return count

        return count_complexity(added_lines) - count_complexity(removed_lines)

    def _estimate_max_nesting(self, added_lines: list[str]) -> int:
        """Estimate maximum nesting depth in added code.

        Assumes 4-space indentation; tab-indented code will under-count.
        """
        max_depth = 0
        for line in added_lines:
            if line.strip():
                stripped = line.lstrip()
                indent = len(line) - len(stripped)
                max_depth = max(max_depth, indent // 4)
        return max_depth

    # =========================================================================
    # Issue-patch alignment
    # =========================================================================

    def _extract_issue_keywords(self, problem_statement: str) -> list[str]:
        """Extract meaningful identifier-like keywords from the issue text.

        Strips code blocks, inline code and URLs first, then harvests
        CamelCase/snake_case/CONSTANT identifiers, error-class names and
        mentioned function/class names.  Returns at most 30 keywords
        (deduplicated, order unspecified).
        """
        # Remove code blocks, inline code and URLs before keyword mining.
        text = re.sub(r'```[\s\S]*?```', '', problem_statement)
        text = re.sub(r'`[^`]+`', '', text)
        text = re.sub(r'https?://\S+', '', text)

        # Potential identifiers (CamelCase, snake_case, CONSTANTS).
        identifiers = re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', text)
        identifiers += re.findall(r'\b\w+_\w+\b', text)
        identifiers += re.findall(r'\b[A-Z]{2,}\b', text)

        # Error message keywords.
        errors = re.findall(r'\b\w*(?:Error|Exception|Warning|Failure)\b', text)

        # Method/function/class names mentioned.
        methods = re.findall(r'\.(\w+)\(', text)
        methods += re.findall(r'def\s+(\w+)', text)
        methods += re.findall(r'class\s+(\w+)', text)

        keywords = list(set(identifiers + errors + methods))
        # Filter out common English words and very short tokens.
        stopwords = {
            'the', 'is', 'in', 'it', 'to', 'and', 'or', 'not', 'this',
            'that', 'with', 'for', 'are', 'was', 'has', 'have', 'had',
            'when', 'would', 'should', 'could', 'will', 'can', 'may',
            'one', 'two', 'use', 'used', 'using', 'see', 'also',
        }
        keywords = [k for k in keywords if k.lower() not in stopwords and len(k) > 2]

        return keywords[:30]  # Cap at 30 keywords

    def _match_keywords(self, keywords: list[str], patch_text: str) -> list[str]:
        """Return the issue keywords that appear (case-insensitively) in the patch."""
        patch_lower = patch_text.lower()
        return [k for k in keywords if k.lower() in patch_lower]

    def _extract_components(self, problem_statement: str) -> list[str]:
        """Extract software component names (paths, modules, methods) from the issue."""
        # File paths
        files = re.findall(r'[\w/]+\.py\b', problem_statement)
        # Module references in import statements
        modules = re.findall(r'(?:from|import)\s+([\w.]+)', problem_statement)
        # class.method call patterns
        class_methods = re.findall(r'(\w+\.\w+)\(', problem_statement)
        # Deduplicated (order unspecified), capped at 20.
        return list(set(files + modules + class_methods))[:20]

    # =========================================================================
    # Code quality signals
    # =========================================================================

    def _check_patterns(self, text: str, patterns: list[str]) -> bool:
        """Return True if any pattern matches *text* (case-insensitive)."""
        return any(re.search(p, text, re.IGNORECASE) for p in patterns)

    def _check_style_basic(self, added_lines: list[str]) -> list[str]:
        """Basic style checks without external tools.

        Returns at most one violation per violation type, each tagged with
        the index of the first offending line (``type:index``).
        """
        violations = []

        for i, line in enumerate(added_lines):
            # Line too long (PEP 8: 79 chars; we allow a generous 120)
            if len(line) > 120:
                violations.append(f"line_too_long:{i}")

            # Trailing whitespace
            if line != line.rstrip():
                violations.append(f"trailing_whitespace:{i}")

            # Mixed tabs and spaces — inspect the *indentation* only; the
            # original whole-line check false-positived on tabs appearing
            # inside string literals or comments.
            indent = line[:len(line) - len(line.lstrip())]
            if '\t' in indent and ' ' in indent:
                violations.append(f"mixed_indentation:{i}")

            # Multiple statements on one line (crude: skips for/import lines)
            if ';' in line and 'for' not in line and 'import' not in line:
                violations.append(f"multiple_statements:{i}")

        # Keep only the first occurrence of each violation type.
        types_seen = set()
        unique = []
        for v in violations:
            vtype = v.split(':')[0]
            if vtype not in types_seen:
                types_seen.add(vtype)
                unique.append(v)

        return unique

    def _check_core_files(self, changed_files: list[str]) -> bool:
        """Return True if any changed file matches a core/infrastructure pattern."""
        return any(
            re.search(pattern, f)
            for f in changed_files
            for pattern in self.CORE_FILE_PATTERNS
        )

    def _assess_scope(self, num_files: int, total_lines: int) -> str:
        """Classify change size as 'minimal', 'moderate' or 'extensive'."""
        if num_files <= 1 and total_lines <= 20:
            return "minimal"
        elif num_files <= 3 and total_lines <= 100:
            return "moderate"
        else:
            return "extensive"

    def _extract_new_imports(self, added_lines: list[str]) -> list[str]:
        """Extract newly added import statements (stripped, in order)."""
        imports = []
        for line in added_lines:
            stripped = line.strip()
            if stripped.startswith('import ') or stripped.startswith('from '):
                imports.append(stripped)
        return imports
390
+
391
+
392
def extract_features_batch(
    examples: list[PatchExample],
    show_progress: bool = True,
) -> list[tuple[PatchExample, PatchFeatures]]:
    """Extract features for a batch of examples.

    Returns a list of ``(example, features)`` pairs in input order.  When
    *show_progress* is True, a progress line is printed every 50 examples
    and a summary line at the end.
    """
    extractor = FeatureExtractor()
    results: list[tuple[PatchExample, PatchFeatures]] = []
    total = len(examples)

    for index, example in enumerate(examples, start=1):
        if show_progress and index % 50 == 0:
            print(f"  Extracted features for {index}/{total} examples")
        results.append((example, extractor.extract(example)))

    if show_progress:
        print(f"  Done: {len(results)} examples processed")

    return results
411
+
412
+
413
if __name__ == "__main__":
    # Manual smoke test: build a dataset via the project loader and dump the
    # extracted features for a handful of examples.
    # Fix: removed the unused `import json` that was never referenced.
    from patchjudge.data_loader import SWEBenchLoader

    loader = SWEBenchLoader()
    examples = loader.build_dataset(sources=["coderforge"])

    extractor = FeatureExtractor()

    # Extract features for first 5 examples
    for ex in examples[:5]:
        features = extractor.extract(ex)
        print(f"\n{'='*60}")
        print(f"Instance: {ex.instance_id}")
        print(f"Agent: {ex.agent_name}")
        print(f"Test passed: {ex.test_passed}")
        print(f"Files changed: {features.num_files_changed}")
        print(f"Lines +{features.num_lines_added}/-{features.num_lines_removed}")
        print(f"Hunks: {features.num_hunks}")
        print(f"Scope: {features.change_scope}")
        print(f"Error handling: {features.has_error_handling}")
        print(f"Edge cases: {features.has_edge_case_handling}")
        print(f"TODOs: {features.has_todos}")
        print(f"Debug stmts: {features.has_debug_statements}")
        print(f"Hardcoded: {features.has_hardcoded_values}")
        print(f"Core files: {features.modifies_core_files}")
        print(f"New imports: {features.new_imports}")
        print(f"Issue keywords addressed: {features.issue_keywords_addressed[:5]}")
        print(f"Keyword coverage: {features.keyword_coverage_ratio:.2f}")
        print(f"Style violations: {features.style_violations}")
        print(f"Added functions: {features.added_functions}")
        print(f"Modified functions: {features.modified_functions}")