| from typing import Dict, Any, Tuple, Optional, List, Deque |
| from collections import deque |
| import random |
| from environment.models import ( |
| ReviewAction, |
| ReviewState, |
| Observation |
| ) |
| from environment.tasks import TaskDefinitions |
| from environment.graders import TaskGrader, RewardCalculator |
|
|
|
|
class CodeReviewEnv:
    """Single-episode, gym-style environment for a code-review task.

    The agent observes a code diff plus task metadata and acts by adding
    review comments, suggesting fixes, marking comments resolved, or issuing
    a final approve / request-changes decision.  Scoring and per-step reward
    shaping are delegated to ``TaskGrader`` and ``RewardCalculator``.

    Lifecycle: ``reset()`` must be called before ``step()``; an episode ends
    when a final decision is made or ``max_steps`` is reached.
    """

    def __init__(self):
        # Episode state; stays None until reset() is called.
        self._state: Optional[ReviewState] = None
        self.grader: Optional[TaskGrader] = None
        self.reward_calculator = RewardCalculator()
        self.max_steps = 50
        self.current_task_id: Optional[str] = None
        # Rolling window of the 5 most recent actions, surfaced in observations.
        self._action_history: Deque[Dict[str, Any]] = deque(maxlen=5)
        self._seed: Optional[int] = None

    def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> Dict[str, Any]:
        """Start a new episode and return the initial observation dict.

        Args:
            task_id: Task to load; defaults to ``"bug_detection_easy_1"``.
            seed: Optional RNG seed.  NOTE(review): this seeds the *global*
                ``random`` module, which affects every other user of
                ``random`` in the process — confirm this is intended.

        Returns:
            The first observation (see ``_get_observation``).
        """
        if seed is not None:
            self._seed = seed
            random.seed(seed)

        if task_id is None:
            task_id = "bug_detection_easy_1"

        self.current_task_id = task_id
        task_data = TaskDefinitions.get_task(task_id)

        code_context = TaskDefinitions.create_code_context(task_data)
        task_metadata = TaskDefinitions.create_task_metadata(task_data)

        # Fresh, empty review state for the new episode.
        self._state = ReviewState(
            code_context=code_context,
            task_metadata=task_metadata,
            comments_made=[],
            suggestions_made=[],
            current_step=0,
            is_complete=False,
            final_decision=None,
            last_action_valid=True,
            last_error=None
        )

        self.grader = TaskGrader(task_metadata.expected_issues)
        self.reward_calculator.reset()
        self._action_history.clear()

        return self._get_observation()

    def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        Args:
            action: Keyword payload validated into a ``ReviewAction``.

        A malformed action costs -0.1 but does not end the episode; calling
        step() before reset() (or after episode completion) short-circuits.
        """
        if self._state is None:
            return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}

        if self._state.is_complete:
            return self._get_observation(), 0.0, True, {"error": "Episode already complete"}

        if self.grader is None:
            return self._get_observation(), -0.1, True, {"error": "Environment not initialized. Call reset() first."}
        grader = self.grader

        try:
            review_action = ReviewAction(**action)
        except Exception as e:
            # Invalid payload: record the error, penalize, keep episode alive.
            self._state.last_action_valid = False
            self._state.last_error = str(e)
            return self._get_observation(), -0.1, False, {"error": str(e), "last_action_valid": False}

        # Snapshot (line, normalized text) keys BEFORE processing so that
        # re-posting a previously made comment counts as a duplicate.
        # (Duplicates within a single action are intentionally not counted.)
        existing_comment_keys = {
            (c.line_number, c.content.strip().lower()) for c in self._state.comments_made
        }

        self._state.current_step += 1
        self._process_action(review_action)

        duplicate_comment_count = 0
        if review_action.action_type.value == "add_comment":
            for c in review_action.comments:
                key = (c.line_number, c.content.strip().lower())
                if key in existing_comment_keys:
                    duplicate_comment_count += 1

        self._action_history.append({
            "step": self._state.current_step,
            "action_type": review_action.action_type.value,
            "num_comments": len(review_action.comments),
            "num_suggestions": len(review_action.suggestions),
            "final_decision": review_action.final_decision,
        })

        # Infer the final decision from approve/request_changes actions that
        # did not spell one out explicitly.
        if review_action.action_type.value == "approve" and not review_action.final_decision:
            review_action.final_decision = "approved"
        elif review_action.action_type.value == "request_changes" and not review_action.final_decision:
            review_action.final_decision = "changes_requested"

        if review_action.final_decision and not self._state.is_complete:
            self._state.is_complete = True
            self._state.final_decision = review_action.final_decision

        # Truncation: force an episode end at the step budget, defaulting the
        # decision to "changes_requested" if the agent never decided.
        if self._state.current_step >= self.max_steps and not self._state.is_complete:
            self._state.is_complete = True
            if not self._state.final_decision:
                self._state.final_decision = "changes_requested"

        reward = self.reward_calculator.calculate_reward(
            review_action,
            self._state.comments_made,
            self._state.suggestions_made,
            self._state.final_decision or "changes_requested",
            grader,
            self._state.last_action_valid,
            duplicate_comment_count=duplicate_comment_count,
            steps_taken=self._state.current_step,
            max_steps=self.max_steps,
        )

        diagnostics = grader.get_diagnostics(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or "changes_requested",
            steps_taken=self._state.current_step,
            max_steps=self.max_steps,
        )

        info = {
            "step": self._state.current_step,
            "last_action_valid": self._state.last_action_valid,
            "error": self._state.last_error,
            "task_score": self.get_task_score(),
            "diagnostics": diagnostics,
            "valid_actions": self.valid_actions(),
        }

        return self._get_observation(), reward, self._state.is_complete, info

    def _process_action(self, action: ReviewAction):
        """Mutate episode state according to the action's type.

        Sets ``last_action_valid`` / ``last_error``; an out-of-range line in
        any comment or suggestion marks the whole action invalid (in-range
        items in the same action are still recorded).
        """
        if self._state is None:
            return

        self._state.last_action_valid = True
        self._state.last_error = None

        line_count = self._state.code_context.line_count

        if action.action_type.value == "add_comment":
            for comment in action.comments:
                # Line numbers are 1-based (render() enumerates from 1), so
                # reject zero/negative as well as past-the-end references.
                # BUGFIX: the lower bound was previously unchecked, so
                # line_number <= 0 was silently accepted.
                if 1 <= comment.line_number <= line_count:
                    self._state.comments_made.append(comment)
                else:
                    self._state.last_action_valid = False
                    self._state.last_error = f"Line {comment.line_number} out of range"

        elif action.action_type.value == "suggest_fix":
            for suggestion in action.suggestions:
                # Same 1-based bounds check as comments (lower bound added).
                if 1 <= suggestion.original_line <= line_count:
                    self._state.suggestions_made.append(suggestion)
                else:
                    self._state.last_action_valid = False
                    self._state.last_error = f"Line {suggestion.original_line} out of range"

        elif action.action_type.value == "mark_as_resolved":
            if not self._state.comments_made:
                self._state.last_action_valid = False
                self._state.last_error = "No comments exist to mark as resolved"
                return
            # Resolve every existing comment on any referenced line.
            for comment in action.comments:
                for existing_comment in self._state.comments_made:
                    if existing_comment.line_number == comment.line_number:
                        existing_comment.resolved = True

    def valid_actions(self) -> List[str]:
        """Return the action types currently permitted by the episode state.

        ``suggest_fix`` and ``mark_as_resolved`` only become available once
        at least one comment exists.
        """
        if self._state is None:
            return []

        actions = ["add_comment", "approve", "request_changes"]

        if self._state.comments_made:
            actions.append("suggest_fix")
            actions.append("mark_as_resolved")

        return actions

    def _get_observation(self) -> Dict[str, Any]:
        """Build the observation dict from the current state.

        Returns an empty dict before reset().  The base ``Observation`` model
        is augmented with the rolling action history, the currently valid
        actions, and the file's line count.
        """
        if self._state is None:
            return {}

        obs = Observation(
            code_diff=self._state.code_context.code_diff,
            file_context=self._state.code_context.surrounding_code,
            file_path=self._state.code_context.file_path,
            language=self._state.code_context.language,
            task_description=self._state.task_metadata.description,
            task_difficulty=self._state.task_metadata.difficulty,
            current_step=self._state.current_step,
            max_steps=self.max_steps,
            previous_comments=self._state.comments_made,
            previous_suggestions=self._state.suggestions_made,
            review_complete=self._state.is_complete,
            final_decision_made=self._state.final_decision
        ).model_dump()

        obs["action_history"] = list(self._action_history)
        obs["valid_actions"] = self.valid_actions()
        obs["line_count"] = self._state.code_context.line_count

        return obs

    def render(self):
        """Pretty-print the current episode state to stdout (human mode)."""
        if self._state is None:
            print("Environment not initialized.")
            return

        print("=" * 60)
        print(f"FILE : {self._state.code_context.file_path}")
        print(f"LANGUAGE : {self._state.code_context.language}")
        print(f"STEP : {self._state.current_step}/{self.max_steps}")
        print(f"DONE : {self._state.is_complete}")
        print(f"DECISION : {self._state.final_decision or 'pending'}")
        print(f"SCORE : {self.get_task_score():.3f}")
        print("-" * 60)
        print("CODE DIFF:")
        for i, line in enumerate(self._state.code_context.code_diff.split("\n"), start=1):
            print(f" {i}: {line}")
        print("-" * 60)

        if self._state.comments_made:
            print(f"COMMENTS ({len(self._state.comments_made)}):")
            for c in self._state.comments_made:
                print(f" Line {c.line_number} [{c.severity}]: {c.content}")

        if self._state.suggestions_made:
            print(f"SUGGESTIONS ({len(self._state.suggestions_made)}):")
            for s in self._state.suggestions_made:
                print(f" Line {s.original_line}: {s.suggested_code}")

        print(f"VALID ACTIONS: {self.valid_actions()}")
        print("=" * 60)

    def summary(self) -> Dict[str, Any]:
        """Print an end-of-episode report and return the grader diagnostics.

        Returns an empty dict if the environment was never reset.
        """
        if not self.grader or self._state is None:
            return {}

        diagnostics = self.grader.get_diagnostics(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or "changes_requested",
            steps_taken=self._state.current_step,
            max_steps=self.max_steps,
        )

        print("\n--- Episode Summary ---")
        print(f" Task : {self.current_task_id}")
        print(f" Steps taken : {self._state.current_step}/{self.max_steps}")
        print(f" Final decision : {self._state.final_decision or 'none'}")
        print(f" Score : {diagnostics['score']}")
        print(f" Precision : {diagnostics['precision']}")
        print(f" Recall : {diagnostics['recall']}")
        print(f" F1 : {diagnostics['f1']}")
        print(f" True positives : {diagnostics['true_positive_count']}")
        print(f" False positives : {diagnostics['false_positive_count']}")
        print(f" False negatives : {diagnostics['false_negative_count']}")
        print(f" FP penalty : {diagnostics['false_positive_penalty']}")
        print(f" Efficiency bonus: {diagnostics['efficiency_bonus']}")
        print("-----------------------")

        return diagnostics

    def get_task_score(self) -> float:
        """Return the grader's current scalar score (0.0 before reset())."""
        if not self.grader or self._state is None:
            return 0.0

        return self.grader.compute_score_from_state(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or "changes_requested",
            steps_taken=self._state.current_step,
            max_steps=self.max_steps,
        )

    def close(self):
        """No resources to release; present for gym-API compatibility."""
        pass

    def state(self) -> Dict[str, Any]:
        """Return a plain-dict dump of the episode state ({} before reset())."""
        if self._state:
            return self._state.model_dump()
        return {}