100XZX001 committed on
Commit
8846f87
·
verified ·
1 Parent(s): 04362f9

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +128 -22
environment.py CHANGED
@@ -1,16 +1,127 @@
1
- from typing import Tuple, Dict, Any
2
  from models import Observation, Action, Reward, State
3
  from grader import grade_comment, grade_question, grade_fix
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class CodeReviewEnv:
6
  def __init__(self, task: str = "easy"):
7
  self.task = task
 
8
  self.reset()
9
 
10
  def set_task(self, task: str):
11
  if task not in ["easy", "medium", "hard", "harder", "hardest"]:
12
  raise ValueError(f"Unknown task: {task}")
13
  self.task = task
 
14
 
15
  def reset(self) -> Observation:
16
  if self.task is None:
@@ -76,6 +187,10 @@ class CodeReviewEnv:
76
  reward = 0.2 # dense bonus for writing
77
  quality_score = grade_comment(self.agent_comment, self.expected_keywords, self.expert_comment)
78
  reward += quality_score
 
 
 
 
79
  self.done = True
80
 
81
  elif action.action_type == "ask_question":
@@ -83,11 +198,11 @@ class CodeReviewEnv:
83
  reward = -0.1
84
  else:
85
  q_score = grade_question(action.question)
86
- reward = 0.1 + q_score # small bonus + quality
87
- # Simulate a helpful answer
88
- answer = self._answer_question(action.question)
89
  self.comments.append(f"Agent: {action.question}")
90
- self.comments.append(f"Env: {answer}")
91
  self.step_count += 1
92
  # Episode continues, not done
93
 
@@ -95,11 +210,14 @@ class CodeReviewEnv:
95
  if not action.fix_code:
96
  reward = -0.2
97
  else:
98
- # We'll use a simple keyword check for demonstration
99
- # In a full version, you'd run unit tests
100
- fix_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
101
- reward = 0.3 + fix_score
102
- self.test_results = f"Fix evaluated with score {fix_score:.2f}"
 
 
 
103
  self.done = True
104
 
105
  elif action.action_type == "skip":
@@ -116,18 +234,6 @@ class CodeReviewEnv:
116
  obs = self._get_observation()
117
  return obs, Reward(value=reward), self.done, info
118
 
119
- def _answer_question(self, question: str) -> str:
120
- # Simple rule‑based answers – you can expand
121
- q = question.lower()
122
- if "what" in q and "purpose" in q:
123
- return "The purpose of this function is to retrieve a user by ID from a dictionary."
124
- elif "expected" in q:
125
- return "The function should return the user object if the ID exists, otherwise raise a KeyError."
126
- elif "how" in q and "fix" in q:
127
- return "You might consider adding a check for missing keys or using a safer dictionary method like `get`."
128
- else:
129
- return "I'm not sure. Could you be more specific?"
130
-
131
  def _get_observation(self) -> Observation:
132
  return Observation(
133
  pr_title=self.pr_title,
 
1
+ from typing import Tuple, Dict, Any, List, Optional
2
  from models import Observation, Action, Reward, State
3
  from grader import grade_comment, grade_question, grade_fix
4
+ import sys
5
+ import io
6
+ import contextlib
7
 
8
+ # ------------------------- Simulated CI / Unit tests -------------------------
9
def run_unit_tests(fix_code: str, task: str) -> float:
    """
    Run the simulated CI check for ``task`` against a submitted fix.

    The fix is executed at module level of a fresh namespace; a task-specific
    ``test()`` function is then defined in that *same* namespace and called,
    so the check can see the functions the fix defines.

    Args:
        fix_code: Python source submitted by the agent.
        task: Task name. "easy", "medium", "hard" and "harder" select their
            own check; any other value falls through to the "hardest" check.

    Returns:
        1.0 if the check returns a truthy value, otherwise 0.0. Any error
        raised while compiling or running the fix or the check yields 0.0.
    """
    # WARNING(security): exec() runs the submitted code in-process with full
    # interpreter privileges. Redirecting stdout/stderr is NOT a sandbox;
    # untrusted submissions need real isolation (e.g. a separate process).

    # repr() yields a valid string literal even when the fix spans several
    # lines or contains quotes. Pasting the raw text between double quotes
    # turned the generated check into a SyntaxError for any realistic fix.
    fix_literal = repr(fix_code)

    if task == "easy":
        # The fixed function is expected to raise KeyError for a missing key.
        check_source = """
def test():
    try:
        users = {"alice": "Alice"}
        result = get_user("bob")
        return False  # should not get here if key missing
    except KeyError:
        return True  # expected: KeyError
    except Exception:
        return False
"""
    elif task == "medium":
        # Only verifies that the fix compiles and runs without error.
        check_source = f"""
def test():
    items = [1, 2, 3]
    try:
        exec(compile({fix_literal}, "<string>", "exec"))
        return True
    except Exception:
        return False
"""
    elif task == "hard":
        # An empty list must produce 0 instead of a ZeroDivisionError.
        check_source = """
def test():
    try:
        result = calculate_average([])
        return result == 0  # expect 0 or some default
    except ZeroDivisionError:
        return False
"""
    elif task == "harder":
        # Textual check: the fix has to mention a lock.
        check_source = f"""
def test():
    if "lock" in {fix_literal}.lower():
        return True
    return False
"""
    else:  # hardest
        # Textual check: the fix has to mention a consistent lock order.
        check_source = f"""
def test():
    fix_text = {fix_literal}.lower()
    if "same order" in fix_text or "lock order" in fix_text:
        return True
    return False
"""

    # A single dict serves as globals for both the fix and the check. With
    # separate globals/locals dicts, test() could not resolve the functions
    # defined by the fix and every behavioural check failed with NameError.
    namespace = {}
    captured = io.StringIO()
    try:
        # Keep anything the submission prints out of the real stdout/stderr.
        with contextlib.redirect_stdout(captured), contextlib.redirect_stderr(captured):
            exec(fix_code, namespace)
            exec(check_source, namespace)
            test_fn = namespace.get("test")
            if not callable(test_fn):
                return 0.0
            passed = test_fn()
    except (Exception, SystemExit):
        # A broken fix (or one that calls sys.exit) scores zero rather than
        # taking the environment down.
        return 0.0
    return 1.0 if passed else 0.0
90
+
91
+ # ------------------------- Simulated PR Author -------------------------
92
class SimulatedAuthor:
    """Responds to the agent's questions and comments as if they were the PR author."""

    def __init__(self, task: str):
        # Kept for reference; respond() does not currently branch on the task.
        self.task = task

    def respond(self, agent_comment: str = "", agent_question: Optional[str] = None) -> str:
        """
        Produce the author's canned reply.

        Args:
            agent_comment: Review comment left by the agent. Defaults to an
                empty string so callers that only ask a question can pass
                ``agent_question=...`` alone without a TypeError.
            agent_question: Question asked by the agent. When truthy it takes
                precedence over ``agent_comment``.

        Returns:
            A fixed reply chosen by keyword matching on the lower-cased text.
        """
        if agent_question:
            q = agent_question.lower()
            if "what" in q and "purpose" in q:
                return "The purpose is to retrieve a user safely."
            if "expected" in q:
                return "It should return the user or raise KeyError."
            return "Could you be more specific?"

        # Generic response to a comment; tolerate None as "no comment".
        if "good" in (agent_comment or "").lower():
            return "Thanks for the feedback!"
        return "I'll consider your suggestion."
112
+
113
+ # ------------------------- Main Environment -------------------------
114
  class CodeReviewEnv:
115
  def __init__(self, task: str = "easy"):
116
  self.task = task
117
+ self.author = None
118
  self.reset()
119
 
120
  def set_task(self, task: str):
121
  if task not in ["easy", "medium", "hard", "harder", "hardest"]:
122
  raise ValueError(f"Unknown task: {task}")
123
  self.task = task
124
+ self.author = SimulatedAuthor(task)
125
 
126
  def reset(self) -> Observation:
127
  if self.task is None:
 
187
  reward = 0.2 # dense bonus for writing
188
  quality_score = grade_comment(self.agent_comment, self.expected_keywords, self.expert_comment)
189
  reward += quality_score
190
+ # Simulate author response
191
+ author_response = self.author.respond(self.agent_comment)
192
+ self.comments.append(f"Agent: {self.agent_comment}")
193
+ self.comments.append(f"Author: {author_response}")
194
  self.done = True
195
 
196
  elif action.action_type == "ask_question":
 
198
  reward = -0.1
199
  else:
200
  q_score = grade_question(action.question)
201
+ reward = 0.1 + q_score
202
+ # Get answer from simulated author
203
+ answer = self.author.respond(agent_question=action.question)
204
  self.comments.append(f"Agent: {action.question}")
205
+ self.comments.append(f"Author: {answer}")
206
  self.step_count += 1
207
  # Episode continues, not done
208
 
 
210
  if not action.fix_code:
211
  reward = -0.2
212
  else:
213
+ # Run CI tests
214
+ test_score = run_unit_tests(action.fix_code, self.task)
215
+ # Also keyword match for partial credit
216
+ kw_score = grade_fix(action.fix_code, self.expected_fix_keywords, None)
217
+ # Combined score: 70% tests, 30% keywords
218
+ combined_score = 0.7 * test_score + 0.3 * kw_score
219
+ reward = 0.3 + combined_score
220
+ self.test_results = f"CI tests passed: {test_score:.0%}, Keywords: {kw_score:.0%}"
221
  self.done = True
222
 
223
  elif action.action_type == "skip":
 
234
  obs = self._get_observation()
235
  return obs, Reward(value=reward), self.done, info
236
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  def _get_observation(self) -> Observation:
238
  return Observation(
239
  pr_title=self.pr_title,