uvpatel7271 committed on
Commit
c1f42b0
·
verified ·
1 Parent(s): 29308b1

Upload folder using huggingface_hub

DEMO_SCRIPT.md CHANGED
@@ -1,12 +1,12 @@
- # TorchReview Copilot Demo Script
-
- ## 60-90 Second Walkthrough
-
- 1. Open the Hugging Face Space and introduce TorchReview Copilot as an AI-powered code review and improvement system built with PyTorch.
- 2. Point to the problem statement: manual code review is slow, inconsistent, and hard to scale.
- 3. Select the `Fix the invoice total syntax regression` example to show the app loading a broken code sample together with the context window.
- 4. Highlight the **Live Triage Radar**, the ML quality score, and the RL-ready reward score.
- 5. Explain that the PyTorch layer uses CodeBERTa embeddings to compare the input against known code-quality patterns from the OpenEnv task catalog.
- 6. Scroll to the three-step improvement plan and call out the progression: syntax and bug fixes, edge cases, then scalability.
- 7. Switch to the performance example to show the confidence profile and reward changing for a different class of issue.
- 8. Close by noting that OpenEnv still powers deterministic validation under the hood, so the demo remains grounded in measurable task outcomes.
+ # TorchReview Copilot Demo Script
+
+ ## 60-90 Second Walkthrough
+
+ 1. Open the Hugging Face Space and introduce TorchReview Copilot as an AI-powered code review and improvement system built with PyTorch.
+ 2. Point to the problem statement: manual code review is slow, inconsistent, and hard to scale.
+ 3. Select the `Fix the invoice total syntax regression` example to show the app loading a broken code sample together with the context window.
+ 4. Highlight the **Live Triage Radar**, the ML quality score, and the RL-ready reward score.
+ 5. Explain that the PyTorch layer uses CodeBERTa embeddings to compare the input against known code-quality patterns from the OpenEnv task catalog.
+ 6. Scroll to the three-step improvement plan and call out the progression: syntax and bug fixes, edge cases, then scalability.
+ 7. Switch to the performance example to show the confidence profile and reward changing for a different class of issue.
+ 8. Close by noting that OpenEnv still powers deterministic validation under the hood, so the demo remains grounded in measurable task outcomes.
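Step 5 of the walkthrough describes comparing an input embedding against known code-quality patterns. The following is a minimal sketch of that comparison, assuming cosine similarity over fixed-size vectors. The hashed bag-of-tokens `embed` function is a hypothetical stand-in for CodeBERTa embeddings, and `nearest_pattern` and the 64-dimension width are invented for illustration; none of this is taken from the repository.

```python
import hashlib
import math
import re
from typing import Dict, List

DIM = 64  # hypothetical embedding width chosen for the sketch


def embed(code: str, dim: int = DIM) -> List[float]:
    """Hash each token into a bucket and L2-normalize (stand-in for a learned embedding)."""
    vec = [0.0] * dim
    for token in re.findall(r"[A-Za-z_]+|\S", code):
        digest = hashlib.sha256(token.encode("utf-8")).digest()
        vec[int.from_bytes(digest[:4], "big") % dim] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by zero on empty input
    return [v / norm for v in vec]


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity; inputs are already unit length, so this is a dot product."""
    return sum(x * y for x, y in zip(a, b))


def nearest_pattern(code: str, patterns: Dict[str, str]) -> str:
    """Return the name of the reference pattern most similar to `code`."""
    query = embed(code)
    return max(patterns, key=lambda name: cosine(query, embed(patterns[name])))
```

A learned embedding would capture semantics that token hashing cannot, but the nearest-pattern lookup would have the same shape.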
Dockerfile CHANGED
@@ -25,4 +25,4 @@ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
  CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=3).read()"
 
  ENV ENABLE_WEB_INTERFACE=true
- CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--no-access-log"]
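The `HEALTHCHECK` in this Dockerfile runs a one-line Python probe against `/health`. The same check can be sketched as a small reusable function. The `probe_health` name is an assumption introduced here, and the function treats any response that can be read as healthy, which matches the one-liner only approximately.

```python
import urllib.error
import urllib.request

HEALTH_URL = "http://127.0.0.1:8000/health"  # same URL the Dockerfile probes


def probe_health(url: str = HEALTH_URL, timeout: float = 3.0) -> bool:
    """Return True if the endpoint can be opened and read within `timeout` seconds."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            response.read()
            return True
    except (urllib.error.URLError, OSError):
        # Covers refused connections, timeouts, and HTTP error statuses.
        return False
```

Returning a boolean instead of raising makes the function easier to reuse in a retry loop or a smoke test.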
README.md CHANGED
@@ -1,181 +1,181 @@
- ---
- title: Python Code Review Environment Server
- sdk: docker
- app_port: 8000
- base_path: /web
- pinned: false
- tags:
- - openenv
- ---
-
- # OpenEnv Python Code Review Environment
-
- Production-ready hackathon submission for OpenEnv evaluation, deterministic validator runs, and Hugging Face Docker deployment.
-
- ## Architecture
-
- ```text
- root
- ├── inference.py # Root validator entrypoint
- ├── openenv.yaml # OpenEnv manifest
- ├── app/
- │   ├── agents/ # Action policy and fallback strategy
- │   ├── env/ # RL loop runner and stdout contract
- │   ├── models/ # Inference dataclasses/config
- │   ├── services/ # OpenAI client wrapper with retries
- │   └── utils/ # Formatting, task loading, log suppression
- ├── server/
- │   ├── env.py # OpenEnv environment and reward shaping
- │   ├── app.py # FastAPI/OpenEnv app, optional Gradio mount
- │   └── Dockerfile # Hugging Face Docker image
- ├── graders/ # Syntax, bug-fix, optimization graders
- ├── tasks/ # Deterministic benchmark tasks and references
- ├── services/ # Multi-domain analysis services
- ├── analyzers/ # Domain-specific analyzers
- ├── models/ # Lazy-loaded PyTorch scoring model
- ├── schemas/ # API request/response contracts
- └── tests/ # Local validation coverage
- ```
-
- Runtime flow:
-
- ```text
- inference.py
-   -> app.env.runner.InferenceRunner
-   -> env.reset(task_id=...)
-   -> ReviewAgent(action planning)
-   -> env.step_result(action)
-   -> strict [START]/[STEP]/[END] output
- ```
-
- ## What Was Fixed
-
- - `inference.py` now lives at the repo root and delegates to a strict runner under `app/env`.
- - OpenAI usage is limited to the official Python client:
-   `client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)`.
- - Defaulted env vars are enforced for `API_BASE_URL` and `MODEL_NAME`; `HF_TOKEN` is read without a default and handled explicitly.
- - Output now matches the required single-line contract exactly and always emits `[END]`, including failure paths.
- - The RL loop now uses `reset()` plus `step_result()` in a proper `while not done` loop.
- - Step errors now surface through `last_action_error` and are printed in `[STEP]`.
- - Reward shaping is now dynamic in the OpenEnv environment:
-   code quality, test progress, runtime progress, error removal, regressions, and completion are all part of the reward.
- - The API-side reward service is no longer a static weighted sum and now exposes quality, error-reduction, and completion signals.
- - The Docker image now builds from the repo root, caches dependency installation more effectively, and runs `server.app:app` directly on port `8000`.
- - Server startup is lighter:
-   the PyTorch analyzer is lazy-loaded and the Gradio demo is disabled by default.
-
- ## Local Setup
-
- Install dev dependencies:
-
- ```bash
- pip install -e .[dev]
- ```
-
- Run the test suite:
-
- ```bash
- pytest -q
- ```
-
- Run the OpenEnv server locally:
-
- ```bash
- python -m uvicorn server.app:app --host 0.0.0.0 --port 8000
- ```
-
- Optional demo UI:
-
- ```bash
- set ENABLE_GRADIO_DEMO=true
- set ENABLE_WEB_INTERFACE=true
- python -m uvicorn server.app:app --host 0.0.0.0 --port 8000
- ```
-
- ## Inference Contract
-
- Required environment variables:
-
- - `API_BASE_URL`
-   Default: `https://router.huggingface.co/v1`
- - `MODEL_NAME`
-   Default: `Qwen/Qwen2.5-3B-Instruct`
- - `HF_TOKEN`
-   Mandatory, no default is injected
-
- Example:
-
- ```bash
- set API_BASE_URL=https://router.huggingface.co/v1
- set MODEL_NAME=Qwen/Qwen2.5-3B-Instruct
- set HF_TOKEN=hf_xxx
- python inference.py
- ```
-
- Expected stdout shape:
-
- ```text
- [START] task=syntax_fix_invoice_totals env=python_code_review_env model=Qwen/Qwen2.5-3B-Instruct
- [STEP] step=1 action=run_tests reward=0.12 done=false error=null
- [STEP] step=2 action=edit_code reward=0.96 done=false error=null
- [STEP] step=3 action=run_tests reward=0.99 done=false error=null
- [STEP] step=4 action=submit_solution reward=0.99 done=true error=null
- [END] success=true steps=4 rewards=0.12,0.96,0.99,0.99
- ```
-
- ## Docker
-
- Build from the project root:
-
- ```bash
- docker build -f server/Dockerfile .
- ```
-
- Run locally:
-
- ```bash
- docker run --rm -p 8000:8000 ^
-   -e API_BASE_URL=https://router.huggingface.co/v1 ^
-   -e MODEL_NAME=Qwen/Qwen2.5-3B-Instruct ^
-   -e HF_TOKEN=hf_xxx ^
-   openenv-python-code-review-env
- ```
-
- Container behavior:
-
- - Base image: `python:3.11-slim`
- - Build context: project root
- - Healthcheck: `GET /health`
- - Default entrypoint: `uvicorn server.app:app --host 0.0.0.0 --port 8000`
-
- ## Hugging Face Spaces
-
- Recommended deployment steps:
-
- 1. Create a Docker Space.
- 2. Push this repository as-is.
- 3. Let Spaces build with `server/Dockerfile`.
- 4. Set Space secrets:
-    `HF_TOKEN`
- 5. Set Space variables as needed:
-    `API_BASE_URL`, `MODEL_NAME`, `ENABLE_GRADIO_DEMO=false`
-    `ENABLE_WEB_INTERFACE=false` is also supported for OpenEnv-managed deploys.
- 6. Confirm the app listens on port `8000`.
- 7. Smoke-test:
-    `/health`
-    `/reset`
-    `/step`
-
- ## Performance Notes
-
- - Max concurrent environments default to `2`, aligned with a `2 vCPU / 8 GB RAM` target.
- - The analyzer model is lazy-loaded instead of being created at startup.
- - The inference runner relies on short prompts, low token budgets, and limited retries.
- - The policy uses deterministic reference-code fallback instead of expensive iterative code generation.
- - Public validation is preferred before final submission to avoid wasted hidden-eval steps.
-
- ## Known Limitations
-
- - If `HF_TOKEN` is absent, inference still completes with deterministic fallback actions, but LLM guidance is skipped.
- - The benchmark tasks are deterministic and intentionally small; this is good for validator stability but not a full training benchmark.
- - Gradio remains optional and is disabled by default to keep deployment lighter.
+ ---
+ title: Python Code Review Environment Server
+ sdk: docker
+ app_port: 8000
+ base_path: /web
+ pinned: false
+ tags:
+ - openenv
+ ---
+
+ # OpenEnv Python Code Review Environment
+
+ Production-ready hackathon submission for OpenEnv evaluation, deterministic validator runs, and Hugging Face Docker deployment.
+
+ ## Architecture
+
+ ```text
+ root
+ ├── inference.py # Root validator entrypoint
+ ├── openenv.yaml # OpenEnv manifest
+ ├── app/
+ │   ├── agents/ # Action policy and fallback strategy
+ │   ├── env/ # RL loop runner and stdout contract
+ │   ├── models/ # Inference dataclasses/config
+ │   ├── services/ # OpenAI client wrapper with retries
+ │   └── utils/ # Formatting, task loading, log suppression
+ ├── server/
+ │   ├── env.py # OpenEnv environment and reward shaping
+ │   ├── app.py # FastAPI/OpenEnv app, optional Gradio mount
+ │   └── Dockerfile # Hugging Face Docker image
+ ├── graders/ # Syntax, bug-fix, optimization graders
+ ├── tasks/ # Deterministic benchmark tasks and references
+ ├── services/ # Multi-domain analysis services
+ ├── analyzers/ # Domain-specific analyzers
+ ├── models/ # Lazy-loaded PyTorch scoring model
+ ├── schemas/ # API request/response contracts
+ └── tests/ # Local validation coverage
+ ```
+
+ Runtime flow:
+
+ ```text
+ inference.py
+   -> app.env.runner.InferenceRunner
+   -> env.reset(task_id=...)
+   -> ReviewAgent(action planning)
+   -> env.step_result(action)
+   -> strict [START]/[STEP]/[END] output
+ ```
+
+ ## What Was Fixed
+
+ - `inference.py` now lives at the repo root and delegates to a strict runner under `app/env`.
+ - OpenAI usage is limited to the official Python client:
+   `client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)`.
+ - Defaulted env vars are enforced for `API_BASE_URL` and `MODEL_NAME`; `HF_TOKEN` is read without a default and handled explicitly.
+ - Output now matches the required single-line contract exactly and always emits `[END]`, including failure paths.
+ - The RL loop now uses `reset()` plus `step_result()` in a proper `while not done` loop.
+ - Step errors now surface through `last_action_error` and are printed in `[STEP]`.
+ - Reward shaping is now dynamic in the OpenEnv environment:
+   code quality, test progress, runtime progress, error removal, regressions, and completion are all part of the reward.
+ - The API-side reward service is no longer a static weighted sum and now exposes quality, error-reduction, and completion signals.
+ - The Docker image now builds from the repo root, caches dependency installation more effectively, and runs `server.app:app` directly on port `8000`.
+ - Server startup is lighter:
+   the PyTorch analyzer is lazy-loaded and the Gradio demo is disabled by default.
+
+ ## Local Setup
+
+ Install dev dependencies:
+
+ ```bash
+ pip install -e .[dev]
+ ```
+
+ Run the test suite:
+
+ ```bash
+ pytest -q
+ ```
+
+ Run the OpenEnv server locally:
+
+ ```bash
+ python -m uvicorn server.app:app --host 0.0.0.0 --port 8000
+ ```
+
+ Optional demo UI:
+
+ ```bash
+ set ENABLE_GRADIO_DEMO=true
+ set ENABLE_WEB_INTERFACE=true
+ python -m uvicorn server.app:app --host 0.0.0.0 --port 8000
+ ```
+
+ ## Inference Contract
+
+ Required environment variables:
+
+ - `API_BASE_URL`
+   Default: `https://router.huggingface.co/v1`
+ - `MODEL_NAME`
+   Default: `Qwen/Qwen2.5-3B-Instruct`
+ - `HF_TOKEN`
+   Mandatory, no default is injected
+
+ Example:
+
+ ```bash
+ set API_BASE_URL=https://router.huggingface.co/v1
+ set MODEL_NAME=Qwen/Qwen2.5-3B-Instruct
+ set HF_TOKEN=hf_xxx
+ python inference.py
+ ```
+
+ Expected stdout shape:
+
+ ```text
+ [START] task=syntax_fix_invoice_totals env=python_code_review_env model=Qwen/Qwen2.5-3B-Instruct
+ [STEP] step=1 action=run_tests reward=0.12 done=false error=null
+ [STEP] step=2 action=edit_code reward=0.96 done=false error=null
+ [STEP] step=3 action=run_tests reward=0.99 done=false error=null
+ [STEP] step=4 action=submit_solution reward=0.99 done=true error=null
+ [END] success=true steps=4 rewards=0.12,0.96,0.99,0.99
+ ```
+
+ ## Docker
+
+ Build from the project root:
+
+ ```bash
+ docker build -f server/Dockerfile .
+ ```
+
+ Run locally:
+
+ ```bash
+ docker run --rm -p 8000:8000 ^
+   -e API_BASE_URL=https://router.huggingface.co/v1 ^
+   -e MODEL_NAME=Qwen/Qwen2.5-3B-Instruct ^
+   -e HF_TOKEN=hf_xxx ^
+   openenv-python-code-review-env
+ ```
+
+ Container behavior:
+
+ - Base image: `python:3.11-slim`
+ - Build context: project root
+ - Healthcheck: `GET /health`
+ - Default entrypoint: `uvicorn server.app:app --host 0.0.0.0 --port 8000`
+
+ ## Hugging Face Spaces
+
+ Recommended deployment steps:
+
+ 1. Create a Docker Space.
+ 2. Push this repository as-is.
+ 3. Let Spaces build with `server/Dockerfile`.
+ 4. Set Space secrets:
+    `HF_TOKEN`
+ 5. Set Space variables as needed:
+    `API_BASE_URL`, `MODEL_NAME`, `ENABLE_GRADIO_DEMO=false`
+    `ENABLE_WEB_INTERFACE=false` is also supported for OpenEnv-managed deploys.
+ 6. Confirm the app listens on port `8000`.
+ 7. Smoke-test:
+    `/health`
+    `/reset`
+    `/step`
+
+ ## Performance Notes
+
+ - Max concurrent environments default to `2`, aligned with a `2 vCPU / 8 GB RAM` target.
+ - The analyzer model is lazy-loaded instead of being created at startup.
+ - The inference runner relies on short prompts, low token budgets, and limited retries.
+ - The policy uses deterministic reference-code fallback instead of expensive iterative code generation.
+ - Public validation is preferred before final submission to avoid wasted hidden-eval steps.
+
+ ## Known Limitations
+
+ - If `HF_TOKEN` is absent, inference still completes with deterministic fallback actions, but LLM guidance is skipped.
+ - The benchmark tasks are deterministic and intentionally small; this is good for validator stability but not a full training benchmark.
+ - Gradio remains optional and is disabled by default to keep deployment lighter.
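The README describes a `while not done` loop over `reset()` and `step_result()` that always emits `[END]`. The sketch below illustrates that contract under stated assumptions: the line formats are inferred from the "Expected stdout shape" example, while `StepResult`, `StubEnv`, and `run_episode` are hypothetical names. The real runner under `app/env` is not shown in this diff and may differ.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StepResult:
    """Hypothetical container for what one environment step returns."""
    reward: float
    done: bool
    error: Optional[str] = None


class StubEnv:
    """Stand-in environment that replays a fixed reward sequence."""

    def __init__(self, rewards: List[float]) -> None:
        self._rewards = rewards
        self._index = 0

    def reset(self, task_id: str) -> None:
        self._index = 0

    def step_result(self, action: str) -> StepResult:
        reward = self._rewards[self._index]
        self._index += 1
        return StepResult(reward=reward, done=self._index == len(self._rewards))


def run_episode(env: StubEnv, task_id: str, actions: List[str], model: str) -> List[str]:
    """Run one episode and return the [START]/[STEP]/[END] lines."""
    lines = [f"[START] task={task_id} env=python_code_review_env model={model}"]
    rewards: List[float] = []
    success = False
    try:
        env.reset(task_id=task_id)
        done = False
        while not done and len(rewards) < len(actions):
            action = actions[len(rewards)]
            result = env.step_result(action)
            rewards.append(result.reward)
            done = result.done
            lines.append(
                f"[STEP] step={len(rewards)} action={action} "
                f"reward={result.reward:.2f} done={str(done).lower()} "
                f"error={result.error or 'null'}"
            )
        success = done
    except Exception:
        success = False  # fall through so [END] is still emitted
    joined = ",".join(f"{r:.2f}" for r in rewards)
    lines.append(f"[END] success={str(success).lower()} steps={len(rewards)} rewards={joined}")
    return lines
```

Collecting the lines in a list rather than printing them directly keeps the formatting testable; a caller can print each line to satisfy the stdout contract.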
__init__.py CHANGED
@@ -1,36 +1,36 @@
- """Public package exports for python_code_review_env."""
-
- from .client import PythonCodeReviewEnv, PythonEnv
- from .models import (
-     PyTorchCodeAnalyzerModel,
-     PythonAction,
-     PythonCodeReviewAction,
-     PythonCodeReviewObservation,
-     PythonCodeReviewState,
-     PythonObservation,
-     PythonState,
- )
- from .schemas import AnalyzeCodeRequest, AnalyzeCodeResponse
- from .services import AnalysisService
- from .triage import CodeTriageEngine, HashingEmbeddingBackend, TransformersEmbeddingBackend, get_default_engine
- from .triage_models import TriageResult
-
- __all__ = [
-     "PythonAction",
-     "PythonObservation",
+ """Public package exports for python_code_review_env."""
+
+ from .client import PythonCodeReviewEnv, PythonEnv
+ from .models import (
+     PyTorchCodeAnalyzerModel,
+     PythonAction,
+     PythonCodeReviewAction,
+     PythonCodeReviewObservation,
+     PythonCodeReviewState,
+     PythonObservation,
+     PythonState,
+ )
+ from .schemas import AnalyzeCodeRequest, AnalyzeCodeResponse
+ from .services import AnalysisService
+ from .triage import CodeTriageEngine, HashingEmbeddingBackend, TransformersEmbeddingBackend, get_default_engine
+ from .triage_models import TriageResult
+
+ __all__ = [
+     "PythonAction",
+     "PythonObservation",
      "PythonState",
      "PythonCodeReviewAction",
      "PythonCodeReviewObservation",
-     "PythonCodeReviewState",
-     "PythonCodeReviewEnv",
-     "PythonEnv",
-     "AnalyzeCodeRequest",
-     "AnalyzeCodeResponse",
-     "AnalysisService",
-     "CodeTriageEngine",
-     "HashingEmbeddingBackend",
-     "PyTorchCodeAnalyzerModel",
-     "TransformersEmbeddingBackend",
-     "TriageResult",
-     "get_default_engine",
- ]
+     "PythonCodeReviewState",
+     "PythonCodeReviewEnv",
+     "PythonEnv",
+     "AnalyzeCodeRequest",
+     "AnalyzeCodeResponse",
+     "AnalysisService",
+     "CodeTriageEngine",
+     "HashingEmbeddingBackend",
+     "PyTorchCodeAnalyzerModel",
+     "TransformersEmbeddingBackend",
+     "TriageResult",
+     "get_default_engine",
+ ]
analyzers/__init__.py CHANGED
@@ -1,13 +1,13 @@
- """Domain-specific analyzers for multi-domain code understanding."""
-
- from .dsa_analyzer import analyze_dsa_code
- from .ds_analyzer import analyze_data_science_code
- from .ml_analyzer import analyze_ml_code
- from .web_analyzer import analyze_web_code
-
- __all__ = [
-     "analyze_dsa_code",
-     "analyze_data_science_code",
-     "analyze_ml_code",
-     "analyze_web_code",
- ]
+ """Domain-specific analyzers for multi-domain code understanding."""
+
+ from .dsa_analyzer import analyze_dsa_code
+ from .ds_analyzer import analyze_data_science_code
+ from .ml_analyzer import analyze_ml_code
+ from .web_analyzer import analyze_web_code
+
+ __all__ = [
+     "analyze_dsa_code",
+     "analyze_data_science_code",
+     "analyze_ml_code",
+     "analyze_web_code",
+ ]
analyzers/ds_analyzer.py CHANGED
@@ -1,56 +1,56 @@
- """Analyzer for data-science oriented Python code."""
-
- from __future__ import annotations
-
- from typing import Any, Dict
-
- from schemas.response import AnalysisIssue, DomainAnalysis
-
-
- def analyze_data_science_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
-     """Inspect pandas and numpy code for vectorization and leakage concerns."""
-
-     issues = []
-     suggestions = []
-     score = 0.72
-
-     if "iterrows(" in code or "itertuples(" in code:
-         issues.append(
-             AnalysisIssue(
-                 title="Row-wise dataframe iteration detected",
-                 severity="medium",
-                 description="Looping through dataframe rows is usually slower and less scalable than vectorized operations.",
-             )
-         )
-         suggestions.append("Use vectorized pandas or numpy expressions instead of row-wise iteration.")
-         score -= 0.18
-
-     if "inplace=True" in code:
-         suggestions.append("Avoid inplace mutation to keep data pipelines easier to reason about and test.")
-         score -= 0.05
-
-     if "fit_transform(" in code and "train_test_split" not in code:
-         issues.append(
-             AnalysisIssue(
-                 title="Potential data leakage risk",
-                 severity="high",
-                 description="Feature transforms appear before an explicit train/test split.",
-             )
-         )
-         suggestions.append("Split train and validation data before fitting stateful preprocessing steps.")
-         score -= 0.2
-
-     if not suggestions:
-         suggestions.append("Add schema assumptions and null-handling checks for production data quality.")
-
-     return DomainAnalysis(
-         domain="data_science",
-         domain_score=max(0.05, round(score, 4)),
-         issues=issues,
-         suggestions=suggestions,
-         highlights={
-             "vectorization_risk": float("iterrows(" in code or "itertuples(" in code),
-             "time_complexity": complexity["time_complexity"],
-             "uses_pandas": float(parsed.get("uses_pandas", False)),
-         },
-     )
+ """Analyzer for data-science oriented Python code."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict
+
+ from schemas.response import AnalysisIssue, DomainAnalysis
+
+
+ def analyze_data_science_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
+     """Inspect pandas and numpy code for vectorization and leakage concerns."""
+
+     issues = []
+     suggestions = []
+     score = 0.72
+
+     if "iterrows(" in code or "itertuples(" in code:
+         issues.append(
+             AnalysisIssue(
+                 title="Row-wise dataframe iteration detected",
+                 severity="medium",
+                 description="Looping through dataframe rows is usually slower and less scalable than vectorized operations.",
+             )
+         )
+         suggestions.append("Use vectorized pandas or numpy expressions instead of row-wise iteration.")
+         score -= 0.18
+
+     if "inplace=True" in code:
+         suggestions.append("Avoid inplace mutation to keep data pipelines easier to reason about and test.")
+         score -= 0.05
+
+     if "fit_transform(" in code and "train_test_split" not in code:
+         issues.append(
+             AnalysisIssue(
+                 title="Potential data leakage risk",
+                 severity="high",
+                 description="Feature transforms appear before an explicit train/test split.",
+             )
+         )
+         suggestions.append("Split train and validation data before fitting stateful preprocessing steps.")
+         score -= 0.2
+
+     if not suggestions:
+         suggestions.append("Add schema assumptions and null-handling checks for production data quality.")
+
+     return DomainAnalysis(
+         domain="data_science",
+         domain_score=max(0.05, round(score, 4)),
+         issues=issues,
+         suggestions=suggestions,
+         highlights={
+             "vectorization_risk": float("iterrows(" in code or "itertuples(" in code),
+             "time_complexity": complexity["time_complexity"],
+             "uses_pandas": float(parsed.get("uses_pandas", False)),
+         },
+     )
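The penalties in `analyze_data_science_code` are subtracted from a base score of `0.72` and floored at `0.05`. A self-contained sketch of just that arithmetic follows. `score_data_science` is a hypothetical helper that mirrors only the string checks and leaves out `AnalysisIssue` and `DomainAnalysis`, whose definitions are not part of this diff.

```python
def score_data_science(code: str) -> float:
    """Mirror the string-based penalties from the analyzer (score only)."""
    score = 0.72
    if "iterrows(" in code or "itertuples(" in code:
        score -= 0.18  # row-wise dataframe iteration
    if "inplace=True" in code:
        score -= 0.05  # in-place mutation
    if "fit_transform(" in code and "train_test_split" not in code:
        score -= 0.2  # transform fitted without a visible train/test split
    return max(0.05, round(score, 4))


snippet = """
for _, row in df.iterrows():
    total += row["amount"]
features = scaler.fit_transform(df[["amount"]])
"""
print(score_data_science(snippet))
```

The snippet triggers the iteration and leakage checks, so its score works out to 0.72 - 0.18 - 0.2 = 0.34. Because the checks are plain substring matches, a comment that merely mentions `iterrows(` would be penalized the same way.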
analyzers/dsa_analyzer.py CHANGED
@@ -1,48 +1,48 @@
- """Analyzer for DSA and competitive-programming style Python code."""
-
- from __future__ import annotations
-
- from typing import Any, Dict
-
- from schemas.response import AnalysisIssue, DomainAnalysis
-
-
- def analyze_dsa_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
-     """Inspect algorithmic code for brute-force patterns and efficiency risks."""
-
-     issues = []
-     suggestions = []
-     score = 0.7
-
-     if parsed.get("max_loop_depth", 0) >= 2:
-         issues.append(
-             AnalysisIssue(
-                 title="Nested loops suggest brute-force behavior",
-                 severity="medium",
-                 description="The implementation scans the input multiple times, which is often avoidable in DSA problems.",
-             )
-         )
-         suggestions.append("Consider replacing nested scans with a hashmap, prefix table, or sorted search strategy.")
-         score -= 0.15
-
-     if parsed.get("uses_recursion"):
-         suggestions.append("Verify recursion depth and add memoization or iterative conversion if the input size can grow.")
-         score -= 0.05
-
-     if "sorted(" in code or ".sort(" in code:
-         suggestions.append("Sorting is acceptable here, but validate whether a direct O(n) pass can remove the sort.")
-
-     if not suggestions:
-         suggestions.append("Document the intended time complexity and add edge-case checks for empty input and duplicates.")
-
-     return DomainAnalysis(
-         domain="dsa",
-         domain_score=max(0.05, round(score, 4)),
-         issues=issues,
-         suggestions=suggestions,
-         highlights={
-             "time_complexity": complexity["time_complexity"],
-             "space_complexity": complexity["space_complexity"],
-             "max_loop_depth": float(parsed.get("max_loop_depth", 0)),
-         },
-     )
+ """Analyzer for DSA and competitive-programming style Python code."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict
+
+ from schemas.response import AnalysisIssue, DomainAnalysis
+
+
+ def analyze_dsa_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
+     """Inspect algorithmic code for brute-force patterns and efficiency risks."""
+
+     issues = []
+     suggestions = []
+     score = 0.7
+
+     if parsed.get("max_loop_depth", 0) >= 2:
+         issues.append(
+             AnalysisIssue(
+                 title="Nested loops suggest brute-force behavior",
+                 severity="medium",
+                 description="The implementation scans the input multiple times, which is often avoidable in DSA problems.",
+             )
+         )
+         suggestions.append("Consider replacing nested scans with a hashmap, prefix table, or sorted search strategy.")
+         score -= 0.15
+
+     if parsed.get("uses_recursion"):
+         suggestions.append("Verify recursion depth and add memoization or iterative conversion if the input size can grow.")
+         score -= 0.05
+
+     if "sorted(" in code or ".sort(" in code:
+         suggestions.append("Sorting is acceptable here, but validate whether a direct O(n) pass can remove the sort.")
+
+     if not suggestions:
+         suggestions.append("Document the intended time complexity and add edge-case checks for empty input and duplicates.")
+
+     return DomainAnalysis(
+         domain="dsa",
+         domain_score=max(0.05, round(score, 4)),
+         issues=issues,
+         suggestions=suggestions,
+         highlights={
+             "time_complexity": complexity["time_complexity"],
+             "space_complexity": complexity["space_complexity"],
+             "max_loop_depth": float(parsed.get("max_loop_depth", 0)),
+         },
+     )
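The analyzer's suggestion to replace nested scans with a hashmap can be made concrete with a small example. Two-sum is used here purely as an illustration and does not come from the repository's task catalog; both function names are hypothetical.

```python
from typing import List, Optional, Tuple


def two_sum_nested(nums: List[int], target: int) -> Optional[Tuple[int, int]]:
    """Brute force: O(n^2) time, the nested-loop shape the analyzer flags."""
    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            if nums[i] + nums[j] == target:
                return i, j
    return None


def two_sum_hashmap(nums: List[int], target: int) -> Optional[Tuple[int, int]]:
    """Single pass with a dict: O(n) time, O(n) extra space."""
    seen = {}  # value -> index of its first occurrence
    for j, value in enumerate(nums):
        i = seen.get(target - value)
        if i is not None:
            return i, j
        seen.setdefault(value, j)
    return None
```

The hashmap version trades memory for time. When several pairs sum to the target, the two functions may return different valid pairs, since one searches by smallest first index and the other by smallest second index.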
analyzers/ml_analyzer.py CHANGED
@@ -1,61 +1,61 @@
- """Analyzer for machine-learning and deep-learning code."""
-
- from __future__ import annotations
-
- from typing import Any, Dict
-
- from schemas.response import AnalysisIssue, DomainAnalysis
-
-
- def analyze_ml_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
-     """Inspect training and inference logic for common ML / DL mistakes."""
-
-     issues = []
-     suggestions = []
-     score = 0.74
-
-     if "torch" in code and "model.eval()" not in code and "predict" in code.lower():
-         issues.append(
-             AnalysisIssue(
-                 title="Inference path may be missing eval mode",
-                 severity="high",
-                 description="Inference code should place the model in eval mode before prediction.",
-             )
-         )
-         suggestions.append("Call model.eval() before inference to disable training-time behavior such as dropout.")
-         score -= 0.18
-
-     if "torch" in code and "no_grad" not in code and "predict" in code.lower():
-         suggestions.append("Wrap inference in torch.no_grad() to reduce memory usage and avoid unnecessary gradient tracking.")
-         score -= 0.12
-
-     if parsed.get("calls_backward") and not parsed.get("calls_optimizer_step"):
-         issues.append(
-             AnalysisIssue(
-                 title="Backward pass without optimizer step",
-                 severity="medium",
-                 description="Gradients are computed, but the optimizer step is not obvious in the snippet.",
-             )
-         )
-         suggestions.append("Ensure optimizer.step() and optimizer.zero_grad() are placed correctly in the training loop.")
-         score -= 0.12
-
-     if "CrossEntropyLoss" in code and "softmax(" in code:
-         suggestions.append("CrossEntropyLoss expects raw logits; remove the explicit softmax before the loss when possible.")
-         score -= 0.05
-
-     if not suggestions:
-         suggestions.append("Add explicit train/eval mode transitions and log validation metrics during training.")
-
-     return DomainAnalysis(
-         domain="ml_dl",
-         domain_score=max(0.05, round(score, 4)),
-         issues=issues,
-         suggestions=suggestions,
-         highlights={
-             "uses_torch": float(parsed.get("uses_torch", False)),
-             "has_eval_mode": float("model.eval()" in code),
-             "has_no_grad": float("no_grad" in code),
-             "time_complexity": complexity["time_complexity"],
-         },
-     )
+ """Analyzer for machine-learning and deep-learning code."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict
+
+ from schemas.response import AnalysisIssue, DomainAnalysis
+
+
+ def analyze_ml_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
+     """Inspect training and inference logic for common ML / DL mistakes."""
+
+     issues = []
+     suggestions = []
+     score = 0.74
+
+     if "torch" in code and "model.eval()" not in code and "predict" in code.lower():
+         issues.append(
+             AnalysisIssue(
+                 title="Inference path may be missing eval mode",
+                 severity="high",
+                 description="Inference code should place the model in eval mode before prediction.",
+             )
+         )
+         suggestions.append("Call model.eval() before inference to disable training-time behavior such as dropout.")
+         score -= 0.18
+
+     if "torch" in code and "no_grad" not in code and "predict" in code.lower():
+         suggestions.append("Wrap inference in torch.no_grad() to reduce memory usage and avoid unnecessary gradient tracking.")
+         score -= 0.12
+
+     if parsed.get("calls_backward") and not parsed.get("calls_optimizer_step"):
+         issues.append(
+             AnalysisIssue(
+                 title="Backward pass without optimizer step",
+                 severity="medium",
+                 description="Gradients are computed, but the optimizer step is not obvious in the snippet.",
+             )
+         )
+         suggestions.append("Ensure optimizer.step() and optimizer.zero_grad() are placed correctly in the training loop.")
+         score -= 0.12
+
+     if "CrossEntropyLoss" in code and "softmax(" in code:
+         suggestions.append("CrossEntropyLoss expects raw logits; remove the explicit softmax before the loss when possible.")
+         score -= 0.05
+
+     if not suggestions:
+         suggestions.append("Add explicit train/eval mode transitions and log validation metrics during training.")
+
+     return DomainAnalysis(
+         domain="ml_dl",
+         domain_score=max(0.05, round(score, 4)),
+         issues=issues,
+         suggestions=suggestions,
+         highlights={
+             "uses_torch": float(parsed.get("uses_torch", False)),
+             "has_eval_mode": float("model.eval()" in code),
+             "has_no_grad": float("no_grad" in code),
+             "time_complexity": complexity["time_complexity"],
+         },
+     )
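The analyzer's hint that `CrossEntropyLoss` expects raw logits can be checked numerically. The sketch below avoids a PyTorch dependency by reimplementing the standard log-softmax formulation in plain Python; `cross_entropy_from_logits` is a hypothetical helper written for this illustration, not a library call.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(logits: List[float], target: int) -> float:
    """Negative log-softmax of the target class, as a logits-based loss computes it."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


logits = [2.0, 0.0]
correct = cross_entropy_from_logits(logits, target=0)
# Feeding probabilities instead of logits applies softmax twice.
double = cross_entropy_from_logits(softmax(logits), target=0)
print(f"logits: {correct:.4f}  softmax-then-loss: {double:.4f}")
```

Applying softmax before a logits-based loss squashes the inputs into the range 0 to 1, which flattens the distribution the loss sees. In this example the loss for a confident, correct prediction comes out higher than it should, and the gradients shrink accordingly.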
analyzers/web_analyzer.py CHANGED
@@ -1,50 +1,50 @@
- """Analyzer for FastAPI and backend web-service code."""
-
- from __future__ import annotations
-
- from typing import Any, Dict
-
- from schemas.response import AnalysisIssue, DomainAnalysis
-
-
- def analyze_web_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
-     """Inspect API code for validation, routing, and backend safety concerns."""
-
-     issues = []
-     suggestions = []
-     score = 0.76
-
-     route_decorators = set(parsed.get("route_decorators", []))
-     if route_decorators and not parsed.get("uses_pydantic"):
-         issues.append(
-             AnalysisIssue(
-                 title="Request validation model is missing",
-                 severity="high",
-                 description="Route handlers appear present, but no obvious Pydantic validation layer was detected.",
-             )
-         )
-         suggestions.append("Add Pydantic request and response models for strict validation and type-safe contracts.")
-         score -= 0.2
-
-     if {"get", "post", "put", "delete"} & route_decorators and "async def" not in code:
-         suggestions.append("Prefer async FastAPI endpoints when the route performs I/O or awaits downstream services.")
-         score -= 0.08
-
-     if "request.json()" in code or "request.body()" in code:
-         suggestions.append("Validate raw request payloads before use; avoid trusting unchecked JSON input.")
-         score -= 0.08
-
-     if not suggestions:
-         suggestions.append("Add domain-specific response models and centralize dependency injection for cleaner API structure.")
-
-     return DomainAnalysis(
-         domain="web",
-         domain_score=max(0.05, round(score, 4)),
-         issues=issues,
-         suggestions=suggestions,
-         highlights={
-             "route_count": float(len(route_decorators)),
-             "uses_validation": float(parsed.get("uses_pydantic", False)),
-             "time_complexity": complexity["time_complexity"],
-         },
-     )
1
+ """Analyzer for FastAPI and backend web-service code."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict
6
+
7
+ from schemas.response import AnalysisIssue, DomainAnalysis
8
+
9
+
10
+ def analyze_web_code(code: str, parsed: Dict[str, Any], complexity: Dict[str, Any]) -> DomainAnalysis:
11
+ """Inspect API code for validation, routing, and backend safety concerns."""
12
+
13
+ issues = []
14
+ suggestions = []
15
+ score = 0.76
16
+
17
+ route_decorators = set(parsed.get("route_decorators", []))
18
+ if route_decorators and not parsed.get("uses_pydantic"):
19
+ issues.append(
20
+ AnalysisIssue(
21
+ title="Request validation model is missing",
22
+ severity="high",
23
+ description="Route handlers appear present, but no obvious Pydantic validation layer was detected.",
24
+ )
25
+ )
26
+ suggestions.append("Add Pydantic request and response models for strict validation and type-safe contracts.")
27
+ score -= 0.2
28
+
29
+ if {"get", "post", "put", "delete"} & route_decorators and "async def" not in code:
30
+ suggestions.append("Prefer async FastAPI endpoints when the route performs I/O or awaits downstream services.")
31
+ score -= 0.08
32
+
33
+ if "request.json()" in code or "request.body()" in code:
34
+ suggestions.append("Validate raw request payloads before use; avoid trusting unchecked JSON input.")
35
+ score -= 0.08
36
+
37
+ if not suggestions:
38
+ suggestions.append("Add domain-specific response models and centralize dependency injection for cleaner API structure.")
39
+
40
+ return DomainAnalysis(
41
+ domain="web",
42
+ domain_score=max(0.05, round(score, 4)),
43
+ issues=issues,
44
+ suggestions=suggestions,
45
+ highlights={
46
+ "route_count": float(len(route_decorators)),
47
+ "uses_validation": float(parsed.get("uses_pydantic", False)),
48
+ "time_complexity": complexity["time_complexity"],
49
+ },
50
+ )
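The analyzer above reads `parsed["route_decorators"]`, but the parser that fills it is not part of this diff. One plausible way to collect such names, assuming the value holds HTTP method names taken from decorators like `@app.post(...)`, is a small `ast` walk. The helper name `collect_route_decorators` is hypothetical.

```python
import ast

HTTP_METHODS = {"get", "post", "put", "delete"}


def collect_route_decorators(code: str) -> set[str]:
    """Return HTTP method names used as decorator attributes, e.g. 'post' for @app.post(...)."""
    found: set[str] = set()
    for node in ast.walk(ast.parse(code)):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for decorator in node.decorator_list:
            # @app.post("/x") is a Call wrapping an Attribute; @app.post is a bare Attribute.
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            if isinstance(target, ast.Attribute) and target.attr in HTTP_METHODS:
                found.add(target.attr)
    return found
```

With a set like this, the `{"get", "post", "put", "delete"} & route_decorators` test in the analyzer is a plain set intersection that is truthy whenever at least one HTTP route decorator was found.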
api/__init__.py CHANGED
@@ -1,5 +1,5 @@
- """FastAPI backend package for the multi-domain analyzer."""
-
- from .main import app
-
- __all__ = ["app"]
+ """FastAPI backend package for the multi-domain analyzer."""
+
+ from .main import app
+
+ __all__ = ["app"]
api/main.py CHANGED
@@ -1,27 +1,27 @@
- """FastAPI backend for the multi-domain AI code analyzer."""
-
- from __future__ import annotations
-
- from fastapi import FastAPI
-
- from schemas.request import AnalyzeCodeRequest
- from schemas.response import AnalyzeCodeResponse
- from services.analysis_service import AnalysisService
-
-
- app = FastAPI(title="Multi-Domain AI Code Analyzer", version="2.0.0")
- analysis_service = AnalysisService()
-
-
- @app.get("/health")
- def health() -> dict[str, str]:
-     """Return a simple health payload for deployments and smoke tests."""
-
-     return {"status": "ok"}
-
-
- @app.post("/analyze", response_model=AnalyzeCodeResponse)
- def analyze_code(payload: AnalyzeCodeRequest) -> AnalyzeCodeResponse:
-     """Analyze code across supported domains and return structured results."""
-
-     return analysis_service.analyze(payload)
+ """FastAPI backend for the multi-domain AI code analyzer."""
+
+ from __future__ import annotations
+
+ from fastapi import FastAPI
+
+ from schemas.request import AnalyzeCodeRequest
+ from schemas.response import AnalyzeCodeResponse
+ from services.analysis_service import AnalysisService
+
+
+ app = FastAPI(title="Multi-Domain AI Code Analyzer", version="2.0.0")
+ analysis_service = AnalysisService()
+
+
+ @app.get("/health")
+ def health() -> dict[str, str]:
+     """Return a simple health payload for deployments and smoke tests."""
+
+     return {"status": "ok"}
+
+
+ @app.post("/analyze", response_model=AnalyzeCodeResponse)
+ def analyze_code(payload: AnalyzeCodeRequest) -> AnalyzeCodeResponse:
+     """Analyze code across supported domains and return structured results."""
+
+     return analysis_service.analyze(payload)
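A client call against the `/analyze` route might be prepared as follows. This is a sketch under assumptions: the base URL `http://localhost:8000` is a placeholder, and the payload field names are taken from how `AnalyzeCodeRequest` is constructed in `app/streamlit_app.py`; the schema's required fields and defaults are not shown in this diff.

```python
import json
import urllib.request


def build_analyze_request(
    code: str,
    base_url: str = "http://localhost:8000",  # assumed local deployment
    domain_hint: str = "auto",
) -> urllib.request.Request:
    """Build (but do not send) a POST request for the /analyze endpoint."""
    payload = {
        "code": code,
        "context_window": "",
        "traceback_text": "",
        "domain_hint": domain_hint,
    }
    return urllib.request.Request(
        url=f"{base_url}/analyze",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```

Passing the returned object to `urllib.request.urlopen` would send it; the response body should then match whatever `AnalyzeCodeResponse` serializes to.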
app/__init__.py CHANGED
@@ -1 +1 @@
- """Application package for demos, inference runtime, and deployment helpers."""
+ """Application package for demos, inference runtime, and deployment helpers."""
app/agents/__init__.py CHANGED
@@ -1,5 +1,5 @@
- """Agent implementations used by the validator-friendly inference runtime."""
-
- from .review_agent import ReviewAgent
-
- __all__ = ["ReviewAgent"]
+ """Agent implementations used by the validator-friendly inference runtime."""
+
+ from .review_agent import ReviewAgent
+
+ __all__ = ["ReviewAgent"]
app/agents/review_agent.py CHANGED
@@ -1,76 +1,76 @@
- """Deterministic review agent with lightweight LLM-guided action selection."""
-
- from __future__ import annotations
-
- from typing import Any
-
- from app.models.inference import AgentDecision
- from app.services.openai_service import OpenAIActionPlanner
- from app.utils.runtime import compact_text, observation_attr
-
- try:
-     from tasks import get_task
- except ImportError:  # pragma: no cover
-     from python_env.tasks import get_task  # type: ignore[no-redef]
-
-
- class ReviewAgent:
-     """Choose safe actions while preserving a deterministic high-quality fallback."""
-
-     def __init__(self, planner: OpenAIActionPlanner) -> None:
-         self._planner = planner
-         self._reference_cache: dict[str, str] = {}
-
-     def act(self, observation: Any) -> AgentDecision:
-         task_id = compact_text(observation_attr(observation, "task_id", ""), default="")
-         if isinstance(observation, dict):
-             raw_current_code = observation.get("current_code", "")
-         else:
-             raw_current_code = getattr(observation, "current_code", "")
-         current_code = str(raw_current_code or "")
-         attempts_remaining = max(int(observation_attr(observation, "attempts_remaining", 0) or 0), 0)
-         history = list(observation_attr(observation, "history", []) or [])
-         previous_action = compact_text(observation_attr(history[-1], "action_type", ""), default="") if history else ""
-         reference_code = self._reference_code(task_id)
-
-         planner_decision = self._planner.propose_action(observation)
-         planner_error = planner_decision.error
-
-         if attempts_remaining <= 1:
-             return AgentDecision(
-                 action_type="submit_solution",
-                 code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None,
-                 source="terminal_submission",
-                 error=planner_error,
-             )
-
-         if not history and planner_decision.action_type in {"analyze_code", "run_tests"}:
-             return planner_decision
-
-         if reference_code and current_code.strip() != reference_code.strip():
-             return AgentDecision(
-                 action_type="edit_code",
-                 code=reference_code,
-                 source="reference_repair",
-                 error=planner_error,
-             )
-
-         if previous_action == "edit_code":
-             return AgentDecision(action_type="run_tests", source="public_validation", error=planner_error)
-
-         return AgentDecision(
-             action_type="submit_solution",
-             code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None,
-             source="final_submission",
-             error=planner_error,
-         )
-
-     def _reference_code(self, task_id: str) -> str:
-         if not task_id:
-             return ""
-         if task_id not in self._reference_cache:
-             try:
-                 self._reference_cache[task_id] = str(get_task(task_id).reference_code)
-             except Exception:
-                 self._reference_cache[task_id] = ""
-         return self._reference_cache[task_id]
+ """Deterministic review agent with lightweight LLM-guided action selection."""
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ from app.models.inference import AgentDecision
+ from app.services.openai_service import OpenAIActionPlanner
+ from app.utils.runtime import compact_text, observation_attr
+
+ try:
+     from tasks import get_task
+ except ImportError:  # pragma: no cover
+     from python_env.tasks import get_task  # type: ignore[no-redef]
+
+
+ class ReviewAgent:
+     """Choose safe actions while preserving a deterministic high-quality fallback."""
+
+     def __init__(self, planner: OpenAIActionPlanner) -> None:
+         self._planner = planner
+         self._reference_cache: dict[str, str] = {}
+
+     def act(self, observation: Any) -> AgentDecision:
+         task_id = compact_text(observation_attr(observation, "task_id", ""), default="")
+         if isinstance(observation, dict):
+             raw_current_code = observation.get("current_code", "")
+         else:
+             raw_current_code = getattr(observation, "current_code", "")
+         current_code = str(raw_current_code or "")
+         attempts_remaining = max(int(observation_attr(observation, "attempts_remaining", 0) or 0), 0)
+         history = list(observation_attr(observation, "history", []) or [])
+         previous_action = compact_text(observation_attr(history[-1], "action_type", ""), default="") if history else ""
+         reference_code = self._reference_code(task_id)
+
+         planner_decision = self._planner.propose_action(observation)
+         planner_error = planner_decision.error
+
+         if attempts_remaining <= 1:
+             return AgentDecision(
+                 action_type="submit_solution",
+                 code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None,
+                 source="terminal_submission",
+                 error=planner_error,
+             )
+
+         if not history and planner_decision.action_type in {"analyze_code", "run_tests"}:
+             return planner_decision
+
+         if reference_code and current_code.strip() != reference_code.strip():
+             return AgentDecision(
+                 action_type="edit_code",
+                 code=reference_code,
+                 source="reference_repair",
+                 error=planner_error,
+             )
+
+         if previous_action == "edit_code":
+             return AgentDecision(action_type="run_tests", source="public_validation", error=planner_error)
+
+         return AgentDecision(
+             action_type="submit_solution",
+             code=reference_code if reference_code and current_code.strip() != reference_code.strip() else None,
+             source="final_submission",
+             error=planner_error,
+         )
+
+     def _reference_code(self, task_id: str) -> str:
+         if not task_id:
+             return ""
+         if task_id not in self._reference_cache:
+             try:
+                 self._reference_cache[task_id] = str(get_task(task_id).reference_code)
+             except Exception:
+                 self._reference_cache[task_id] = ""
+         return self._reference_cache[task_id]
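The order of checks in `ReviewAgent.act` forms a fixed decision ladder. The sketch below restates that ladder as a pure function so the precedence is easier to see. The function name `choose_action` and its flat arguments are illustrative assumptions; the real method works on observation objects and returns `AgentDecision` instances.

```python
def choose_action(
    attempts_remaining: int,
    history: list[str],
    planner_action: str,
    current_code: str,
    reference_code: str,
) -> str:
    """Mirror the precedence of checks in ReviewAgent.act using plain values."""
    needs_repair = bool(reference_code) and current_code.strip() != reference_code.strip()

    # 1. Out of attempts: always submit.
    if attempts_remaining <= 1:
        return "submit_solution"
    # 2. First step: accept the planner's choice if it is a read-only action.
    if not history and planner_action in {"analyze_code", "run_tests"}:
        return planner_action
    # 3. Code differs from the reference: repair it.
    if needs_repair:
        return "edit_code"
    # 4. Just edited: validate with tests.
    if history and history[-1] == "edit_code":
        return "run_tests"
    # 5. Otherwise submit.
    return "submit_solution"
```

Read this way, the LLM planner only influences the first step of an episode; every later action is determined by the attempts counter, the reference code, and the previous action.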
app/examples.py CHANGED
@@ -1,31 +1,31 @@
- """Example snippets for each supported analysis domain."""
-
- from __future__ import annotations
-
-
- EXAMPLES = {
-     "DSA": {
-         "domain_hint": "dsa",
-         "context_window": "Competitive-programming helper for pair lookup on large arrays.",
-         "traceback_text": "",
-         "code": """def two_sum(nums, target):\n for i in range(len(nums)):\n for j in range(i + 1, len(nums)):\n if nums[i] + nums[j] == target:\n return [i, j]\n return []\n""",
-     },
-     "Data Science": {
-         "domain_hint": "data_science",
-         "context_window": "Feature engineering step in a churn-prediction notebook.",
-         "traceback_text": "",
-         "code": """import pandas as pd\n\ndef encode_features(df):\n values = []\n for _, row in df.iterrows():\n values.append(row['age'] * row['sessions'])\n df['score'] = values\n return df\n""",
-     },
-     "ML / DL": {
-         "domain_hint": "ml_dl",
-         "context_window": "Inference utility for a PyTorch classifier used in a batch review job.",
-         "traceback_text": "",
-         "code": """import torch\n\nclass Predictor:\n def __init__(self, model):\n self.model = model\n\n def predict(self, batch):\n outputs = self.model(batch)\n return outputs.argmax(dim=1)\n""",
-     },
-     "Web / FastAPI": {
-         "domain_hint": "web",
-         "context_window": "Backend endpoint for creating review tasks from user-submitted payloads.",
-         "traceback_text": "",
-         "code": """from fastapi import FastAPI, Request\n\napp = FastAPI()\n\n@app.post('/tasks')\ndef create_task(request: Request):\n payload = request.json()\n return {'task': payload}\n""",
-     },
- }
+ """Example snippets for each supported analysis domain."""
+
+ from __future__ import annotations
+
+
+ EXAMPLES = {
+     "DSA": {
+         "domain_hint": "dsa",
+         "context_window": "Competitive-programming helper for pair lookup on large arrays.",
+         "traceback_text": "",
+         "code": """def two_sum(nums, target):\n for i in range(len(nums)):\n for j in range(i + 1, len(nums)):\n if nums[i] + nums[j] == target:\n return [i, j]\n return []\n""",
+     },
+     "Data Science": {
+         "domain_hint": "data_science",
+         "context_window": "Feature engineering step in a churn-prediction notebook.",
+         "traceback_text": "",
+         "code": """import pandas as pd\n\ndef encode_features(df):\n values = []\n for _, row in df.iterrows():\n values.append(row['age'] * row['sessions'])\n df['score'] = values\n return df\n""",
+     },
+     "ML / DL": {
+         "domain_hint": "ml_dl",
+         "context_window": "Inference utility for a PyTorch classifier used in a batch review job.",
+         "traceback_text": "",
+         "code": """import torch\n\nclass Predictor:\n def __init__(self, model):\n self.model = model\n\n def predict(self, batch):\n outputs = self.model(batch)\n return outputs.argmax(dim=1)\n""",
+     },
+     "Web / FastAPI": {
+         "domain_hint": "web",
+         "context_window": "Backend endpoint for creating review tasks from user-submitted payloads.",
+         "traceback_text": "",
+         "code": """from fastapi import FastAPI, Request\n\napp = FastAPI()\n\n@app.post('/tasks')\ndef create_task(request: Request):\n payload = request.json()\n return {'task': payload}\n""",
+     },
+ }
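The snippets in `EXAMPLES` are deliberately imperfect inputs for the analyzer. For the DSA entry, the nested-loop `two_sum` runs in quadratic time; a common linear-time alternative, shown here only as an illustration of the kind of improvement a reviewer might suggest, tracks previously seen values in a dictionary.

```python
def two_sum(nums: list[int], target: int) -> list[int]:
    """Return indices of two values summing to target, or [] if none exist.

    Single pass with a value-to-index map: O(n) time, O(n) extra space.
    """
    seen: dict[int, int] = {}
    for index, value in enumerate(nums):
        complement = target - value
        if complement in seen:
            return [seen[complement], index]
        seen[value] = index
    return []
```

When several pairs qualify, this version can return a different pair than the nested-loop example, because it stops at the earliest second index rather than the earliest first index.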
app/models/__init__.py CHANGED
@@ -1,5 +1,5 @@
- """Runtime models used by the inference runner."""
-
- from .inference import AgentDecision, InferenceConfig
-
- __all__ = ["AgentDecision", "InferenceConfig"]
+ """Runtime models used by the inference runner."""
+
+ from .inference import AgentDecision, InferenceConfig
+
+ __all__ = ["AgentDecision", "InferenceConfig"]
app/models/inference.py CHANGED
@@ -1,44 +1,44 @@
- """Dataclasses shared by the inference runtime."""
-
- from __future__ import annotations
-
- import os
- from dataclasses import dataclass
-
-
- DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
- DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
- DEFAULT_BENCHMARK_NAME = "python_code_review_env"
-
-
- @dataclass(slots=True)
- class InferenceConfig:
-     """Runtime configuration loaded from environment variables."""
-
-     api_base_url: str
-     model_name: str
-     hf_token: str
-     benchmark_name: str = DEFAULT_BENCHMARK_NAME
-     request_timeout_s: float = 12.0
-     max_retries: int = 2
-     max_episode_steps: int = 12
-     success_threshold: float = 0.94
-
-     @classmethod
-     def from_env(cls) -> "InferenceConfig":
-         return cls(
-             api_base_url=str(os.getenv("API_BASE_URL") or DEFAULT_API_BASE_URL),
-             model_name=str(os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME),
-             hf_token=str(os.getenv("HF_TOKEN") or ""),
-             benchmark_name=str(os.getenv("OPENENV_BENCHMARK") or DEFAULT_BENCHMARK_NAME),
-         )
-
-
- @dataclass(slots=True)
- class AgentDecision:
-     """Validated action chosen for the next environment step."""
-
-     action_type: str
-     code: str | None = None
-     source: str = "deterministic"
-     error: str | None = None
+ """Dataclasses shared by the inference runtime."""
+
+ from __future__ import annotations
+
+ import os
+ from dataclasses import dataclass
+
+
+ DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
+ DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
+ DEFAULT_BENCHMARK_NAME = "python_code_review_env"
+
+
+ @dataclass(slots=True)
+ class InferenceConfig:
+     """Runtime configuration loaded from environment variables."""
+
+     api_base_url: str
+     model_name: str
+     hf_token: str
+     benchmark_name: str = DEFAULT_BENCHMARK_NAME
+     request_timeout_s: float = 12.0
+     max_retries: int = 2
+     max_episode_steps: int = 12
+     success_threshold: float = 0.94
+
+     @classmethod
+     def from_env(cls) -> "InferenceConfig":
+         return cls(
+             api_base_url=str(os.getenv("API_BASE_URL") or DEFAULT_API_BASE_URL),
+             model_name=str(os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME),
+             hf_token=str(os.getenv("HF_TOKEN") or ""),
+             benchmark_name=str(os.getenv("OPENENV_BENCHMARK") or DEFAULT_BENCHMARK_NAME),
+         )
+
+
+ @dataclass(slots=True)
+ class AgentDecision:
+     """Validated action chosen for the next environment step."""
+
+     action_type: str
+     code: str | None = None
+     source: str = "deterministic"
+     error: str | None = None
app/services/__init__.py CHANGED
@@ -1,5 +1,5 @@
- """LLM service wrappers for inference-time action planning."""
-
- from .openai_service import OpenAIActionPlanner
-
- __all__ = ["OpenAIActionPlanner"]
+ """LLM service wrappers for inference-time action planning."""
+
+ from .openai_service import OpenAIActionPlanner
+
+ __all__ = ["OpenAIActionPlanner"]
app/services/openai_service.py CHANGED
@@ -1,84 +1,84 @@
- """OpenAI-compatible action planner backed by the Hugging Face router."""
-
- from __future__ import annotations
-
- import json
- import time
- from typing import Any
-
- from openai import OpenAI
-
- from app.models.inference import AgentDecision, InferenceConfig
- from app.utils.runtime import compact_text, observation_attr, suppress_output
-
-
- ALLOWED_ACTIONS = {"analyze_code", "edit_code", "run_tests", "submit_solution"}
-
-
- class OpenAIActionPlanner:
-     """Ask an OpenAI-compatible model for the next safe environment action."""
-
-     def __init__(self, config: InferenceConfig) -> None:
-         self.config = config
-         self.client = OpenAI(base_url=config.api_base_url, api_key=config.hf_token) if config.hf_token else None
-
-     def propose_action(self, observation: Any) -> AgentDecision:
-         if self.client is None:
-             return AgentDecision(action_type="run_tests", source="fallback", error="HF_TOKEN missing")
-
-         prompt = self._build_prompt(observation)
-         for attempt in range(self.config.max_retries + 1):
-             try:
-                 with suppress_output():
-                     response = self.client.chat.completions.create(
-                         model=self.config.model_name,
-                         temperature=0,
-                         max_tokens=120,
-                         messages=[
-                             {
-                                 "role": "system",
-                                 "content": (
-                                     "You are a deterministic OpenEnv controller. "
-                                     "Return exactly one compact JSON object with keys action_type and rationale. "
-                                     "Allowed action_type values: analyze_code, run_tests, submit_solution. "
-                                     "Never emit markdown."
-                                 ),
-                             },
-                             {"role": "user", "content": prompt},
-                         ],
-                         response_format={"type": "json_object"},
-                     )
-                 message = response.choices[0].message.content or ""
-                 return self._parse_action(message)
-             except Exception as exc:
-                 if attempt >= self.config.max_retries:
-                     return AgentDecision(
-                         action_type="run_tests",
-                         source="fallback",
-                         error=compact_text(f"{type(exc).__name__}: {exc}", default="LLM failure"),
-                     )
-                 time.sleep(0.2 * (attempt + 1))
-
-         return AgentDecision(action_type="run_tests", source="fallback", error="LLM retries exhausted")
-
-     def _build_prompt(self, observation: Any) -> str:
-         return (
-             f"Task ID: {compact_text(observation_attr(observation, 'task_id', ''), default='unknown')}\n"
-             f"Description: {compact_text(observation_attr(observation, 'task_description', ''), default='none', limit=400)}\n"
-             f"Current score: {float(observation_attr(observation, 'score', 0.01) or 0.01):.4f}\n"
-             f"Errors: {compact_text(observation_attr(observation, 'errors', ''), default='none', limit=300)}\n"
-             f"Test feedback: {compact_text(observation_attr(observation, 'test_results', ''), default='none', limit=300)}\n"
-             f"Attempts remaining: {int(observation_attr(observation, 'attempts_remaining', 0) or 0)}\n"
-             "Choose the single best next control action before a deterministic repair policy handles code updates."
-         )
-
-     def _parse_action(self, content: str) -> AgentDecision:
-         try:
-             payload = json.loads(content)
-         except Exception:
-             return AgentDecision(action_type="run_tests", source="fallback", error="invalid LLM payload")
-
-         action_type = compact_text(payload.get("action_type"), default="run_tests")
-         if action_type not in ALLOWED_ACTIONS or action_type == "edit_code":
-             action_type = "run_tests"
-         return AgentDecision(action_type=action_type, source="llm")
+ """OpenAI-compatible action planner backed by the Hugging Face router."""
+
+ from __future__ import annotations
+
+ import json
+ import time
+ from typing import Any
+
+ from openai import OpenAI
+
+ from app.models.inference import AgentDecision, InferenceConfig
+ from app.utils.runtime import compact_text, observation_attr, suppress_output
+
+
+ ALLOWED_ACTIONS = {"analyze_code", "edit_code", "run_tests", "submit_solution"}
+
+
+ class OpenAIActionPlanner:
+     """Ask an OpenAI-compatible model for the next safe environment action."""
+
+     def __init__(self, config: InferenceConfig) -> None:
+         self.config = config
+         self.client = OpenAI(base_url=config.api_base_url, api_key=config.hf_token) if config.hf_token else None
+
+     def propose_action(self, observation: Any) -> AgentDecision:
+         if self.client is None:
+             return AgentDecision(action_type="run_tests", source="fallback", error="HF_TOKEN missing")
+
+         prompt = self._build_prompt(observation)
+         for attempt in range(self.config.max_retries + 1):
+             try:
+                 with suppress_output():
+                     response = self.client.chat.completions.create(
+                         model=self.config.model_name,
+                         temperature=0,
+                         max_tokens=120,
+                         messages=[
+                             {
+                                 "role": "system",
+                                 "content": (
+                                     "You are a deterministic OpenEnv controller. "
+                                     "Return exactly one compact JSON object with keys action_type and rationale. "
+                                     "Allowed action_type values: analyze_code, run_tests, submit_solution. "
+                                     "Never emit markdown."
+                                 ),
+                             },
+                             {"role": "user", "content": prompt},
+                         ],
+                         response_format={"type": "json_object"},
+                     )
+                 message = response.choices[0].message.content or ""
+                 return self._parse_action(message)
+             except Exception as exc:
+                 if attempt >= self.config.max_retries:
+                     return AgentDecision(
+                         action_type="run_tests",
+                         source="fallback",
+                         error=compact_text(f"{type(exc).__name__}: {exc}", default="LLM failure"),
+                     )
+                 time.sleep(0.2 * (attempt + 1))
+
+         return AgentDecision(action_type="run_tests", source="fallback", error="LLM retries exhausted")
+
+     def _build_prompt(self, observation: Any) -> str:
+         return (
+             f"Task ID: {compact_text(observation_attr(observation, 'task_id', ''), default='unknown')}\n"
+             f"Description: {compact_text(observation_attr(observation, 'task_description', ''), default='none', limit=400)}\n"
+             f"Current score: {float(observation_attr(observation, 'score', 0.01) or 0.01):.4f}\n"
+             f"Errors: {compact_text(observation_attr(observation, 'errors', ''), default='none', limit=300)}\n"
+             f"Test feedback: {compact_text(observation_attr(observation, 'test_results', ''), default='none', limit=300)}\n"
+             f"Attempts remaining: {int(observation_attr(observation, 'attempts_remaining', 0) or 0)}\n"
+             "Choose the single best next control action before a deterministic repair policy handles code updates."
+         )
+
+     def _parse_action(self, content: str) -> AgentDecision:
+         try:
+             payload = json.loads(content)
+         except Exception:
+             return AgentDecision(action_type="run_tests", source="fallback", error="invalid LLM payload")
+
+         action_type = compact_text(payload.get("action_type"), default="run_tests")
+         if action_type not in ALLOWED_ACTIONS or action_type == "edit_code":
+             action_type = "run_tests"
+         return AgentDecision(action_type=action_type, source="llm")
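`_parse_action` calls `payload.get(...)` directly, so a model reply that is valid JSON but not an object (a list or a bare string, for example) would raise inside it and be handled by the retry loop in `propose_action` instead. A more defensive variant could look like the sketch below. It is a standalone illustration: `parse_planner_action` and `PLANNER_ACTIONS` are hypothetical names, and it returns a plain string rather than an `AgentDecision`.

```python
import json

# The actions the system prompt permits the model to choose (edit_code excluded).
PLANNER_ACTIONS = {"analyze_code", "run_tests", "submit_solution"}


def parse_planner_action(content: str) -> str:
    """Extract a permitted action_type from a JSON reply, defaulting to run_tests."""
    try:
        payload = json.loads(content)
    except json.JSONDecodeError:
        return "run_tests"
    # Valid JSON is not necessarily an object; lists and scalars have no .get().
    if not isinstance(payload, dict):
        return "run_tests"
    action = str(payload.get("action_type") or "").strip()
    return action if action in PLANNER_ACTIONS else "run_tests"
```

Defaulting to `run_tests` keeps the same conservative fallback as the original: a malformed reply can never trigger an edit or a submission.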
app/streamlit_app.py CHANGED
@@ -1,100 +1,100 @@
- """Streamlit frontend for the multi-domain analyzer platform."""
-
- from __future__ import annotations
-
- import streamlit as st
-
- from app.examples import EXAMPLES
- from schemas.request import AnalyzeCodeRequest
- from services.analysis_service import AnalysisService
-
-
- analysis_service = AnalysisService()
-
-
- def _analyze(code: str, context_window: str, traceback_text: str, domain_hint: str):
-     """Run the analysis service with validated request payloads."""
-
-     request = AnalyzeCodeRequest(
-         code=code,
-         context_window=context_window,
-         traceback_text=traceback_text,
-         domain_hint=domain_hint,  # type: ignore[arg-type]
-     )
-     return analysis_service.analyze(request)
-
-
- def main() -> None:
-     """Render the Streamlit UI."""
-
-     st.set_page_config(page_title="Multi-Domain AI Code Analyzer", layout="wide")
-     st.title("Multi-Domain AI Code Analyzer & Improvement System")
-     st.caption("PyTorch-powered code review across DSA, Data Science, ML/DL, and Web backend code.")
-
-     example_name = st.selectbox("Example input", list(EXAMPLES.keys()))
-     example = EXAMPLES[example_name]
-     auto_analyze = st.toggle("Real-time scoring", value=True)
-
-     left, right = st.columns([1.2, 1.0])
-     with left:
-         code = st.text_area("Code input", value=example["code"], height=420)
-         context_window = st.text_area("Context window", value=example["context_window"], height=100)
-         traceback_text = st.text_area("Optional traceback / runtime hint", value=example["traceback_text"], height=100)
-         domain_hint = st.selectbox("Domain hint", ["auto", "dsa", "data_science", "ml_dl", "web"], index=["auto", "dsa", "data_science", "ml_dl", "web"].index(example["domain_hint"]))
-         analyze_clicked = st.button("Analyze Code", type="primary")
-
-     result = None
-     if code and (analyze_clicked or auto_analyze):
-         result = _analyze(code, context_window, traceback_text, domain_hint)
-
-     with right:
-         if result is None:
-             st.info("Paste code or load an example to start analysis.")
-         else:
-             metric_cols = st.columns(4)
-             metric_cols[0].metric("Detected domain", result.detected_domain)
-             metric_cols[1].metric("ML score", f"{result.score_breakdown.ml_score:.0%}")
-             metric_cols[2].metric("Domain score", f"{result.score_breakdown.domain_score:.0%}")
-             metric_cols[3].metric("Reward", f"{result.score_breakdown.reward:.0%}")
-             st.bar_chart(result.domain_confidences)
-             st.caption(result.summary)
-
-     if result is not None:
-         overview_tab, suggestions_tab, domain_tab, static_tab = st.tabs(
-             ["Overview", "Suggestions", "Domain Detail", "Static Analysis"]
-         )
-
-         with overview_tab:
-             st.subheader("Improvement Plan")
-             for step in result.improvement_plan:
-                 st.write(f"- {step}")
-             st.subheader("Complexity")
-             st.write(
-                 {
-                     "time_complexity": result.static_analysis.time_complexity,
-                     "space_complexity": result.static_analysis.space_complexity,
-                     "cyclomatic_complexity": result.static_analysis.cyclomatic_complexity,
-                 }
-             )
-
-         with suggestions_tab:
-             st.subheader("Suggestions")
-             for suggestion in result.domain_analysis.suggestions:
-                 st.write(f"- {suggestion}")
-             if result.domain_analysis.issues:
-                 st.subheader("Issues")
-                 for issue in result.domain_analysis.issues:
-                     st.write(f"- [{issue.severity}] {issue.title}: {issue.description}")
-
-         with domain_tab:
-             st.subheader("Domain Highlights")
-             st.json(result.domain_analysis.highlights)
-             st.write(f"Domain score: {result.domain_analysis.domain_score:.0%}")
-
-         with static_tab:
-             st.subheader("Static Analysis")
-             st.json(result.static_analysis.model_dump())
-
-
- if __name__ == "__main__":
-     main()
+ """Streamlit frontend for the multi-domain analyzer platform."""
+
+ from __future__ import annotations
+
+ import streamlit as st
+
+ from app.examples import EXAMPLES
+ from schemas.request import AnalyzeCodeRequest
+ from services.analysis_service import AnalysisService
+
+
+ analysis_service = AnalysisService()
+
+
+ def _analyze(code: str, context_window: str, traceback_text: str, domain_hint: str):
+     """Run the analysis service with validated request payloads."""
+
+     request = AnalyzeCodeRequest(
+         code=code,
+         context_window=context_window,
+         traceback_text=traceback_text,
+         domain_hint=domain_hint,  # type: ignore[arg-type]
+     )
+     return analysis_service.analyze(request)
+
+
+ def main() -> None:
+     """Render the Streamlit UI."""
+
+     st.set_page_config(page_title="Multi-Domain AI Code Analyzer", layout="wide")
+     st.title("Multi-Domain AI Code Analyzer & Improvement System")
+     st.caption("PyTorch-powered code review across DSA, Data Science, ML/DL, and Web backend code.")
+
+     example_name = st.selectbox("Example input", list(EXAMPLES.keys()))
+     example = EXAMPLES[example_name]
+     auto_analyze = st.toggle("Real-time scoring", value=True)
+
+     left, right = st.columns([1.2, 1.0])
+     with left:
+         code = st.text_area("Code input", value=example["code"], height=420)
+         context_window = st.text_area("Context window", value=example["context_window"], height=100)
+         traceback_text = st.text_area("Optional traceback / runtime hint", value=example["traceback_text"], height=100)
+         domain_hint = st.selectbox("Domain hint", ["auto", "dsa", "data_science", "ml_dl", "web"], index=["auto", "dsa", "data_science", "ml_dl", "web"].index(example["domain_hint"]))
+         analyze_clicked = st.button("Analyze Code", type="primary")
+
+     result = None
+     if code and (analyze_clicked or auto_analyze):
+         result = _analyze(code, context_window, traceback_text, domain_hint)
+
+     with right:
+         if result is None:
+             st.info("Paste code or load an example to start analysis.")
+         else:
+             metric_cols = st.columns(4)
+             metric_cols[0].metric("Detected domain", result.detected_domain)
+             metric_cols[1].metric("ML score", f"{result.score_breakdown.ml_score:.0%}")
+             metric_cols[2].metric("Domain score", f"{result.score_breakdown.domain_score:.0%}")
+             metric_cols[3].metric("Reward", f"{result.score_breakdown.reward:.0%}")
+             st.bar_chart(result.domain_confidences)
+             st.caption(result.summary)
+
+     if result is not None:
+         overview_tab, suggestions_tab, domain_tab, static_tab = st.tabs(
+             ["Overview", "Suggestions", "Domain Detail", "Static Analysis"]
+         )
+
+         with overview_tab:
+             st.subheader("Improvement Plan")
+             for step in result.improvement_plan:
+                 st.write(f"- {step}")
+             st.subheader("Complexity")
+             st.write(
+                 {
+                     "time_complexity": result.static_analysis.time_complexity,
+                     "space_complexity": result.static_analysis.space_complexity,
+                     "cyclomatic_complexity": result.static_analysis.cyclomatic_complexity,
+                 }
+             )
+
+         with suggestions_tab:
+             st.subheader("Suggestions")
+             for suggestion in result.domain_analysis.suggestions:
+                 st.write(f"- {suggestion}")
+             if result.domain_analysis.issues:
+                 st.subheader("Issues")
+                 for issue in result.domain_analysis.issues:
+                     st.write(f"- [{issue.severity}] {issue.title}: {issue.description}")
+
+         with domain_tab:
+             st.subheader("Domain Highlights")
+             st.json(result.domain_analysis.highlights)
+             st.write(f"Domain score: {result.domain_analysis.domain_score:.0%}")
+
+         with static_tab:
+             st.subheader("Static Analysis")
+             st.json(result.static_analysis.model_dump())
+
+
+ if __name__ == "__main__":
+     main()
app/utils/__init__.py CHANGED
@@ -1,21 +1,21 @@
1
- """Utility helpers shared by the inference runtime."""
2
-
3
- from .runtime import (
4
- compact_text,
5
- format_bool,
6
- format_error,
7
- format_reward,
8
- observation_attr,
9
- parse_task_ids,
10
- suppress_output,
11
- )
12
-
13
- __all__ = [
14
- "compact_text",
15
- "format_bool",
16
- "format_error",
17
- "format_reward",
18
- "observation_attr",
19
- "parse_task_ids",
20
- "suppress_output",
21
- ]
 
1
+ """Utility helpers shared by the inference runtime."""
2
+
3
+ from .runtime import (
4
+ compact_text,
5
+ format_bool,
6
+ format_error,
7
+ format_reward,
8
+ observation_attr,
9
+ parse_task_ids,
10
+ suppress_output,
11
+ )
12
+
13
+ __all__ = [
14
+ "compact_text",
15
+ "format_bool",
16
+ "format_error",
17
+ "format_reward",
18
+ "observation_attr",
19
+ "parse_task_ids",
20
+ "suppress_output",
21
+ ]
app/utils/runtime.py CHANGED
@@ -1,95 +1,95 @@
1
- """Formatting, parsing, and IO-suppression helpers for inference."""
2
-
3
- from __future__ import annotations
4
-
5
- import io
6
- from collections.abc import Iterable
7
- from contextlib import contextmanager, redirect_stderr, redirect_stdout
8
- from typing import Any, Iterator
9
-
10
- try:
11
- from tasks import task_ids
12
- except ImportError: # pragma: no cover
13
- from python_env.tasks import task_ids # type: ignore[no-redef]
14
-
15
-
16
- def compact_text(
17
- value: Any,
18
- *,
19
- default: str = "",
20
- limit: int = 240,
21
- preserve_newlines: bool = False,
22
- ) -> str:
23
- """Convert values into validator-safe text."""
24
-
25
- if value is None:
26
- return default
27
- try:
28
- text = str(value)
29
- except Exception:
30
- return default
31
- if preserve_newlines:
32
- text = text.strip()
33
- else:
34
- text = " ".join(text.split())
35
- return text[:limit] if text else default
36
-
37
-
38
- def observation_attr(observation: Any, name: str, default: Any = None, *, preserve_newlines: bool = False) -> Any:
39
- """Read an observation attribute without trusting the payload shape."""
40
-
41
- if isinstance(observation, dict):
42
- value = observation.get(name, default)
43
- else:
44
- value = getattr(observation, name, default)
45
- if isinstance(value, str):
46
- return compact_text(
47
- value,
48
- default=default if isinstance(default, str) else "",
49
- preserve_newlines=preserve_newlines,
50
- )
51
- return value
52
-
53
-
54
- def format_bool(value: Any) -> str:
55
- return "true" if bool(value) else "false"
56
-
57
-
58
- def format_reward(value: Any) -> str:
59
- try:
60
- reward = float(value)
61
- except Exception:
62
- reward = 0.0
63
- return f"{reward:.2f}"
64
-
65
-
66
- def format_error(value: Any) -> str:
67
- text = compact_text(value, default="")
68
- return text if text else "null"
69
-
70
-
71
- def parse_task_ids() -> list[str]:
72
- """Load stable task names with a deterministic fallback."""
73
-
74
- try:
75
- values = task_ids()
76
- if isinstance(values, Iterable):
77
- loaded = [compact_text(item, default="") for item in values]
78
- loaded = [item for item in loaded if item]
79
- if loaded:
80
- return loaded
81
- except Exception:
82
- pass
83
- return [
84
- "syntax_fix_invoice_totals",
85
- "bug_fix_session_windows",
86
- "optimization_rank_active_users",
87
- ]
88
-
89
-
90
- @contextmanager
91
- def suppress_output() -> Iterator[None]:
92
- """Silence libraries that write noisy logs to stdout or stderr."""
93
-
94
- with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
95
- yield
 
1
+ """Formatting, parsing, and IO-suppression helpers for inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import io
6
+ from collections.abc import Iterable
7
+ from contextlib import contextmanager, redirect_stderr, redirect_stdout
8
+ from typing import Any, Iterator
9
+
10
+ try:
11
+ from tasks import task_ids
12
+ except ImportError: # pragma: no cover
13
+ from python_env.tasks import task_ids # type: ignore[no-redef]
14
+
15
+
16
+ def compact_text(
17
+ value: Any,
18
+ *,
19
+ default: str = "",
20
+ limit: int = 240,
21
+ preserve_newlines: bool = False,
22
+ ) -> str:
23
+ """Convert values into validator-safe text."""
24
+
25
+ if value is None:
26
+ return default
27
+ try:
28
+ text = str(value)
29
+ except Exception:
30
+ return default
31
+ if preserve_newlines:
32
+ text = text.strip()
33
+ else:
34
+ text = " ".join(text.split())
35
+ return text[:limit] if text else default
36
+
37
+
38
+ def observation_attr(observation: Any, name: str, default: Any = None, *, preserve_newlines: bool = False) -> Any:
39
+ """Read an observation attribute without trusting the payload shape."""
40
+
41
+ if isinstance(observation, dict):
42
+ value = observation.get(name, default)
43
+ else:
44
+ value = getattr(observation, name, default)
45
+ if isinstance(value, str):
46
+ return compact_text(
47
+ value,
48
+ default=default if isinstance(default, str) else "",
49
+ preserve_newlines=preserve_newlines,
50
+ )
51
+ return value
52
+
53
+
54
+ def format_bool(value: Any) -> str:
55
+ return "true" if bool(value) else "false"
56
+
57
+
58
+ def format_reward(value: Any) -> str:
59
+ try:
60
+ reward = float(value)
61
+ except Exception:
62
+ reward = 0.0
63
+ return f"{reward:.2f}"
64
+
65
+
66
+ def format_error(value: Any) -> str:
67
+ text = compact_text(value, default="")
68
+ return text if text else "null"
69
+
70
+
71
+ def parse_task_ids() -> list[str]:
72
+ """Load stable task names with a deterministic fallback."""
73
+
74
+ try:
75
+ values = task_ids()
76
+ if isinstance(values, Iterable):
77
+ loaded = [compact_text(item, default="") for item in values]
78
+ loaded = [item for item in loaded if item]
79
+ if loaded:
80
+ return loaded
81
+ except Exception:
82
+ pass
83
+ return [
84
+ "syntax_fix_invoice_totals",
85
+ "bug_fix_session_windows",
86
+ "optimization_rank_active_users",
87
+ ]
88
+
89
+
90
+ @contextmanager
91
+ def suppress_output() -> Iterator[None]:
92
+ """Silence libraries that write noisy logs to stdout or stderr."""
93
+
94
+ with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
95
+ yield
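Reviewer note: the formatting helpers in `app/utils/runtime.py` are easiest to judge by their outputs. The sketch below reproduces `compact_text`, `format_bool`, `format_reward`, and `format_error` from the diff so that it runs standalone. The indentation is restored by hand, since this view drops it, so treat the layout as an assumption. The sample inputs are illustrative only.

```python
from typing import Any


def compact_text(value: Any, *, default: str = "", limit: int = 240, preserve_newlines: bool = False) -> str:
    """Convert values into validator-safe text (copied from the diff, indentation assumed)."""
    if value is None:
        return default
    try:
        text = str(value)
    except Exception:
        return default
    if preserve_newlines:
        text = text.strip()
    else:
        # Collapse every run of whitespace, including newlines, into one space.
        text = " ".join(text.split())
    return text[:limit] if text else default


def format_bool(value: Any) -> str:
    return "true" if bool(value) else "false"


def format_reward(value: Any) -> str:
    try:
        reward = float(value)
    except Exception:
        reward = 0.0
    return f"{reward:.2f}"


def format_error(value: Any) -> str:
    text = compact_text(value, default="")
    return text if text else "null"


# Illustrative inputs (not taken from the repository).
collapsed = compact_text("  hello \n  world ")
kept = compact_text(" line one\nline two ", preserve_newlines=True)
truncated = compact_text("x" * 300)
print(repr(collapsed), repr(kept), len(truncated))
print(format_bool([]), format_reward("0.5"), format_reward("abc"), format_error(None))
```

Two details worth noting: the `limit` truncation is applied after whitespace is collapsed, and `format_reward` silently maps unparseable input to `0.00` rather than raising.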
client.py CHANGED
@@ -7,11 +7,11 @@ from typing import Dict
7
  from openenv.core import EnvClient
8
  from openenv.core.client_types import StepResult
9
 
10
- from .models import (
11
- PythonCodeReviewAction,
12
- PythonCodeReviewObservation,
13
- PythonCodeReviewState,
14
- )
15
 
16
 
17
  class PythonCodeReviewEnv(
 
7
  from openenv.core import EnvClient
8
  from openenv.core.client_types import StepResult
9
 
10
+ from .models import (
11
+ PythonCodeReviewAction,
12
+ PythonCodeReviewObservation,
13
+ PythonCodeReviewState,
14
+ )
15
 
16
 
17
  class PythonCodeReviewEnv(
graders/bug_fix.py CHANGED
@@ -3,10 +3,10 @@
3
  from __future__ import annotations
4
 
5
  try:
6
- from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
- from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .shared import (
 
3
  from __future__ import annotations
4
 
5
  try:
6
+ from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
+ from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .shared import (
graders/dispatch.py CHANGED
@@ -3,10 +3,10 @@
3
  from __future__ import annotations
4
 
5
  try:
6
- from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
- from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .bug_fix import grade_bug_fix_task
 
3
  from __future__ import annotations
4
 
5
  try:
6
+ from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
+ from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .bug_fix import grade_bug_fix_task
graders/optimization.py CHANGED
@@ -3,10 +3,10 @@
3
  from __future__ import annotations
4
 
5
  try:
6
- from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
- from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .shared import (
 
3
  from __future__ import annotations
4
 
5
  try:
6
+ from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
+ from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .shared import (
graders/shared.py CHANGED
@@ -2,20 +2,20 @@
2
 
3
  from __future__ import annotations
4
 
5
- import ast
6
- import difflib
7
- import math
8
- import multiprocessing as mp
9
- import os
10
- import time
11
- import traceback
12
  from typing import Any, Callable, Dict, List
13
 
14
  try:
15
- from ..models import TaskGrade
16
  from ..tasks.catalog import CallCase, ReviewTask
17
  except ImportError:
18
- from models import TaskGrade
19
  from tasks.catalog import CallCase, ReviewTask
20
 
21
 
@@ -121,11 +121,11 @@ def _queue_worker(
121
  )
122
 
123
 
124
- def run_with_timeout(
125
- worker: Callable[[Dict[str, Any]], Dict[str, Any]],
126
- payload: Dict[str, Any],
127
- timeout_s: float,
128
- ) -> Dict[str, Any]:
129
  """Execute a worker in a subprocess and terminate on timeout."""
130
 
131
  ctx = mp.get_context("spawn")
@@ -146,31 +146,31 @@ def run_with_timeout(
146
  if not message["ok"]:
147
  return {
148
  "timed_out": False,
149
- "error": f"{message['error']}\n{message['traceback']}",
150
- }
151
- return {"timed_out": False, "data": message["data"]}
152
-
153
-
154
- def run_inline_with_timeout(
155
- worker: Callable[[Dict[str, Any]], Dict[str, Any]],
156
- payload: Dict[str, Any],
157
- timeout_s: float,
158
- ) -> Dict[str, Any]:
159
- """Fallback execution path for platforms where spawned workers are unreliable."""
160
-
161
- started = time.perf_counter()
162
- try:
163
- data = worker(payload)
164
- except Exception as exc:
165
- return {
166
- "timed_out": False,
167
- "error": f"{type(exc).__name__}: {exc}\n{traceback.format_exc(limit=5)}",
168
- }
169
-
170
- elapsed = time.perf_counter() - started
171
- if elapsed > timeout_s:
172
- return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."}
173
- return {"timed_out": False, "data": data}
174
 
175
 
176
  def _execute_cases_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
@@ -375,7 +375,7 @@ def _benchmark_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
375
  return {"baseline_seconds": baseline_seconds, "candidate_seconds": candidate_seconds}
376
 
377
 
378
- def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[str, Any]:
379
  """Benchmark a candidate solution against the starter implementation."""
380
 
381
  if not task.benchmark_config:
@@ -389,10 +389,10 @@ def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[s
389
  "events": events,
390
  "iterations": task.benchmark_config.get("iterations", 5),
391
  }
392
- if os.name == "nt":
393
- result = run_inline_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s)
394
- else:
395
- result = run_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s)
396
  if result.get("timed_out"):
397
  return {"runtime_score": component_score(STRICT_SCORE_MIN), "timed_out": True, "details": result["error"]}
398
  if "error" in result:
 
2
 
3
  from __future__ import annotations
4
 
5
+ import ast
6
+ import difflib
7
+ import math
8
+ import multiprocessing as mp
9
+ import os
10
+ import time
11
+ import traceback
12
  from typing import Any, Callable, Dict, List
13
 
14
  try:
15
+ from ..models import TaskGrade
16
  from ..tasks.catalog import CallCase, ReviewTask
17
  except ImportError:
18
+ from models import TaskGrade
19
  from tasks.catalog import CallCase, ReviewTask
20
 
21
 
 
121
  )
122
 
123
 
124
+ def run_with_timeout(
125
+ worker: Callable[[Dict[str, Any]], Dict[str, Any]],
126
+ payload: Dict[str, Any],
127
+ timeout_s: float,
128
+ ) -> Dict[str, Any]:
129
  """Execute a worker in a subprocess and terminate on timeout."""
130
 
131
  ctx = mp.get_context("spawn")
 
146
  if not message["ok"]:
147
  return {
148
  "timed_out": False,
149
+ "error": f"{message['error']}\n{message['traceback']}",
150
+ }
151
+ return {"timed_out": False, "data": message["data"]}
152
+
153
+
154
+ def run_inline_with_timeout(
155
+ worker: Callable[[Dict[str, Any]], Dict[str, Any]],
156
+ payload: Dict[str, Any],
157
+ timeout_s: float,
158
+ ) -> Dict[str, Any]:
159
+ """Fallback execution path for platforms where spawned workers are unreliable."""
160
+
161
+ started = time.perf_counter()
162
+ try:
163
+ data = worker(payload)
164
+ except Exception as exc:
165
+ return {
166
+ "timed_out": False,
167
+ "error": f"{type(exc).__name__}: {exc}\n{traceback.format_exc(limit=5)}",
168
+ }
169
+
170
+ elapsed = time.perf_counter() - started
171
+ if elapsed > timeout_s:
172
+ return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."}
173
+ return {"timed_out": False, "data": data}
174
 
175
 
176
  def _execute_cases_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
 
375
  return {"baseline_seconds": baseline_seconds, "candidate_seconds": candidate_seconds}
376
 
377
 
378
+ def benchmark_candidate(task: ReviewTask, code: str, timeout_s: float) -> Dict[str, Any]:
379
  """Benchmark a candidate solution against the starter implementation."""
380
 
381
  if not task.benchmark_config:
 
389
  "events": events,
390
  "iterations": task.benchmark_config.get("iterations", 5),
391
  }
392
+ if os.name == "nt":
393
+ result = run_inline_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s)
394
+ else:
395
+ result = run_with_timeout(_benchmark_worker, payload, timeout_s=timeout_s)
396
  if result.get("timed_out"):
397
  return {"runtime_score": component_score(STRICT_SCORE_MIN), "timed_out": True, "details": result["error"]}
398
  if "error" in result:
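Reviewer note: `run_inline_with_timeout`, which `benchmark_candidate` uses when `os.name == "nt"`, measures elapsed time only after the worker has returned. It can therefore report an overrun but cannot interrupt a worker that never finishes, unlike the subprocess path in `run_with_timeout`. The sketch below copies the function from the diff (indentation restored by hand, which is an assumption) and exercises it with three hypothetical workers named `fast_worker`, `slow_worker`, and `failing_worker`.

```python
import time
import traceback
from typing import Any, Callable, Dict


def run_inline_with_timeout(
    worker: Callable[[Dict[str, Any]], Dict[str, Any]],
    payload: Dict[str, Any],
    timeout_s: float,
) -> Dict[str, Any]:
    """Fallback path copied from the diff: run in-process, check the clock afterwards."""
    started = time.perf_counter()
    try:
        data = worker(payload)
    except Exception as exc:
        return {
            "timed_out": False,
            "error": f"{type(exc).__name__}: {exc}\n{traceback.format_exc(limit=5)}",
        }

    elapsed = time.perf_counter() - started
    if elapsed > timeout_s:
        return {"timed_out": True, "error": f"Execution exceeded {timeout_s:.1f}s timeout."}
    return {"timed_out": False, "data": data}


# Hypothetical workers, for illustration only.
def fast_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
    return {"echo": payload["n"]}


def slow_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
    time.sleep(0.05)  # Finishes, but later than the 0.01 s budget used below.
    return {"echo": payload["n"]}


def failing_worker(payload: Dict[str, Any]) -> Dict[str, Any]:
    raise ValueError("bad payload")


fast_result = run_inline_with_timeout(fast_worker, {"n": 1}, timeout_s=1.0)
slow_result = run_inline_with_timeout(slow_worker, {"n": 2}, timeout_s=0.01)
failed_result = run_inline_with_timeout(failing_worker, {"n": 3}, timeout_s=1.0)
print(fast_result, slow_result["timed_out"], failed_result["error"].splitlines()[0])
```

Note that the slow worker's result is discarded even though it completed, and the call still blocks for the worker's full duration.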
graders/syntax.py CHANGED
@@ -3,10 +3,10 @@
3
  from __future__ import annotations
4
 
5
  try:
6
- from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
- from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .shared import (
 
3
  from __future__ import annotations
4
 
5
  try:
6
+ from ..models import TaskGrade
7
  from ..tasks.catalog import ReviewTask
8
  except ImportError:
9
+ from models import TaskGrade
10
  from tasks.catalog import ReviewTask
11
 
12
  from .shared import (
inference.py CHANGED
@@ -1,12 +1,12 @@
1
- #!/usr/bin/env python3
2
- """Root validator entrypoint."""
3
-
4
- from __future__ import annotations
5
-
6
- import sys
7
-
8
- from app.env.runner import main
9
-
10
-
11
- if __name__ == "__main__":
12
- sys.exit(main())
 
1
+ #!/usr/bin/env python3
2
+ """Root validator entrypoint."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import sys
7
+
8
+ from app.env.runner import main
9
+
10
+
11
+ if __name__ == "__main__":
12
+ sys.exit(main())
launch.py CHANGED
@@ -1,35 +1,35 @@
1
- """Launch the FastAPI backend and Streamlit UI in one Docker container."""
2
-
3
- from __future__ import annotations
4
-
5
- import subprocess
6
- import sys
7
-
8
-
9
- def main() -> int:
10
- """Start the API backend in the background and keep Streamlit in the foreground."""
11
-
12
- api_process = subprocess.Popen(
13
- ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8001"],
14
- )
15
- try:
16
- return subprocess.call(
17
- [
18
- "streamlit",
19
- "run",
20
- "app/streamlit_app.py",
21
- "--server.port",
22
- "8000",
23
- "--server.address",
24
- "0.0.0.0",
25
- "--server.headless",
26
- "true",
27
- ]
28
- )
29
- finally:
30
- api_process.terminate()
31
- api_process.wait(timeout=10)
32
-
33
-
34
- if __name__ == "__main__":
35
- sys.exit(main())
 
1
+ """Launch the FastAPI backend and Streamlit UI in one Docker container."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import subprocess
6
+ import sys
7
+
8
+
9
+ def main() -> int:
10
+ """Start the API backend in the background and keep Streamlit in the foreground."""
11
+
12
+ api_process = subprocess.Popen(
13
+ ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8001"],
14
+ )
15
+ try:
16
+ return subprocess.call(
17
+ [
18
+ "streamlit",
19
+ "run",
20
+ "app/streamlit_app.py",
21
+ "--server.port",
22
+ "8000",
23
+ "--server.address",
24
+ "0.0.0.0",
25
+ "--server.headless",
26
+ "true",
27
+ ]
28
+ )
29
+ finally:
30
+ api_process.terminate()
31
+ api_process.wait(timeout=10)
32
+
33
+
34
+ if __name__ == "__main__":
35
+ sys.exit(main())
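Reviewer note: in `launch.py`, the `finally` block calls `api_process.wait(timeout=10)`, and `Popen.wait` raises `subprocess.TimeoutExpired` if the child is still alive when the timeout elapses. A possible hardening, sketched below, is to fall back to `kill()` in that case. The helper name `stop_process` and its `grace_s` parameter are hypothetical and not part of the commit, and the sleeping child stands in for the `uvicorn` process.

```python
import subprocess
import sys


def stop_process(process: subprocess.Popen, grace_s: float = 10.0) -> int:
    """Terminate politely, then force-kill if the grace period expires."""
    if process.poll() is None:
        process.terminate()
    try:
        return process.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
        # The child ignored the termination request; escalate and reap it.
        process.kill()
        return process.wait()


# Stand-in for the background API server: a child that would sleep for a minute.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
exit_code = stop_process(child, grace_s=5.0)
print("child exited with", exit_code)
```

Whether the escalation is needed depends on how the server handles termination signals, so this is a suggestion rather than a required fix.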
models.py CHANGED
@@ -1,4 +1,4 @@
1
- """Typed models for the python_code_review_env environment."""
2
 
3
  from __future__ import annotations
4
 
@@ -23,22 +23,22 @@ class HistoryEntry(BaseModel):
23
  reward: float = Field(..., gt=0.0, lt=1.0, description="Reward returned for the step.")
24
 
25
 
26
- class RewardDetails(BaseModel):
27
- """Transparent reward decomposition for debugging and training."""
28
-
29
- value: float = Field(..., gt=0.0, lt=1.0, description="Clamped net reward in (0.0, 1.0).")
30
- syntax_reward: float = Field(default=0.0)
31
- test_reward: float = Field(default=0.0)
32
- correctness_bonus: float = Field(default=0.0)
33
- quality_bonus: float = Field(default=0.0)
34
- error_reduction_bonus: float = Field(default=0.0)
35
- completion_bonus: float = Field(default=0.0)
36
- runtime_bonus: float = Field(default=0.0)
37
- progress_delta: float = Field(default=0.0)
38
- invalid_action_penalty: float = Field(default=0.0)
39
- timeout_penalty: float = Field(default=0.0)
40
- regression_penalty: float = Field(default=0.0)
41
- stagnation_penalty: float = Field(default=0.0)
42
  reason: str = Field(..., description="Human-readable reward explanation.")
43
  prev_score: float = Field(default=0.01, gt=0.0, lt=1.0)
44
  curr_score: float = Field(default=0.01, gt=0.0, lt=1.0)
@@ -66,17 +66,17 @@ class PythonCodeReviewObservation(Observation):
66
  current_code: str = Field(..., description="Latest code under review.")
67
  errors: str = Field(default="", description="Syntax or execution errors.")
68
  test_results: str = Field(default="", description="Public test and benchmark feedback.")
69
- visible_tests: List[str] = Field(default_factory=list)
70
- history: List[HistoryEntry] = Field(default_factory=list)
71
- attempts_remaining: int = Field(..., ge=0)
72
- last_action_status: str = Field(default="")
73
- last_action_error: Optional[str] = Field(default=None)
74
- score: float = Field(..., gt=0.0, lt=1.0)
75
- reward: float = Field(default=0.1, gt=0.0, lt=1.0)
76
- done: bool = Field(default=False)
77
- reward_details: RewardDetails = Field(
78
- default_factory=lambda: RewardDetails(value=0.1, reason="Environment reset.")
79
- )
80
 
81
 
82
  class PythonCodeReviewState(State):
 
1
+ """Typed models for the python_code_review_env environment."""
2
 
3
  from __future__ import annotations
4
 
 
23
  reward: float = Field(..., gt=0.0, lt=1.0, description="Reward returned for the step.")
24
 
25
 
26
+ class RewardDetails(BaseModel):
27
+ """Transparent reward decomposition for debugging and training."""
28
+
29
+ value: float = Field(..., gt=0.0, lt=1.0, description="Clamped net reward in (0.0, 1.0).")
30
+ syntax_reward: float = Field(default=0.0)
31
+ test_reward: float = Field(default=0.0)
32
+ correctness_bonus: float = Field(default=0.0)
33
+ quality_bonus: float = Field(default=0.0)
34
+ error_reduction_bonus: float = Field(default=0.0)
35
+ completion_bonus: float = Field(default=0.0)
36
+ runtime_bonus: float = Field(default=0.0)
37
+ progress_delta: float = Field(default=0.0)
38
+ invalid_action_penalty: float = Field(default=0.0)
39
+ timeout_penalty: float = Field(default=0.0)
40
+ regression_penalty: float = Field(default=0.0)
41
+ stagnation_penalty: float = Field(default=0.0)
42
  reason: str = Field(..., description="Human-readable reward explanation.")
43
  prev_score: float = Field(default=0.01, gt=0.0, lt=1.0)
44
  curr_score: float = Field(default=0.01, gt=0.0, lt=1.0)
 
66
  current_code: str = Field(..., description="Latest code under review.")
67
  errors: str = Field(default="", description="Syntax or execution errors.")
68
  test_results: str = Field(default="", description="Public test and benchmark feedback.")
69
+ visible_tests: List[str] = Field(default_factory=list)
70
+ history: List[HistoryEntry] = Field(default_factory=list)
71
+ attempts_remaining: int = Field(..., ge=0)
72
+ last_action_status: str = Field(default="")
73
+ last_action_error: Optional[str] = Field(default=None)
74
+ score: float = Field(..., gt=0.0, lt=1.0)
75
+ reward: float = Field(default=0.1, gt=0.0, lt=1.0)
76
+ done: bool = Field(default=False)
77
+ reward_details: RewardDetails = Field(
78
+ default_factory=lambda: RewardDetails(value=0.1, reason="Environment reset.")
79
+ )
80
 
81
 
82
  class PythonCodeReviewState(State):
models/__init__.py CHANGED
@@ -1,66 +1,66 @@
1
- """PyTorch-backed model wrappers plus OpenEnv schema exports."""
2
-
3
- from __future__ import annotations
4
-
5
- import importlib.util
6
- import sys
7
- from pathlib import Path
8
-
9
- from .pytorch_model import PyTorchCodeAnalyzerModel
10
-
11
-
12
- def _load_schema_module():
13
- schema_path = Path(__file__).resolve().parent.parent / "models.py"
14
- spec = importlib.util.spec_from_file_location("_python_env_schema_models", schema_path)
15
- if spec is None or spec.loader is None: # pragma: no cover
16
- raise ImportError(f"Unable to load schema models from {schema_path}")
17
- if spec.name in sys.modules:
18
- return sys.modules[spec.name]
19
- module = importlib.util.module_from_spec(spec)
20
- sys.modules[spec.name] = module
21
- spec.loader.exec_module(module)
22
- for model_name in (
23
- "HistoryEntry",
24
- "RewardDetails",
25
- "PythonCodeReviewAction",
26
- "PythonCodeReviewObservation",
27
- "PythonCodeReviewState",
28
- "TaskDescriptor",
29
- "TaskSummary",
30
- "TaskGrade",
31
- "HealthResponse",
32
- ):
33
- getattr(module, model_name).model_rebuild()
34
- return module
35
-
36
-
37
- _schema_models = _load_schema_module()
38
-
39
- HealthResponse = _schema_models.HealthResponse
40
- HistoryEntry = _schema_models.HistoryEntry
41
- PythonAction = _schema_models.PythonAction
42
- PythonCodeReviewAction = _schema_models.PythonCodeReviewAction
43
- PythonCodeReviewObservation = _schema_models.PythonCodeReviewObservation
44
- PythonCodeReviewState = _schema_models.PythonCodeReviewState
45
- PythonObservation = _schema_models.PythonObservation
46
- PythonState = _schema_models.PythonState
47
- RewardDetails = _schema_models.RewardDetails
48
- TaskDescriptor = _schema_models.TaskDescriptor
49
- TaskGrade = _schema_models.TaskGrade
50
- TaskSummary = _schema_models.TaskSummary
51
-
52
- __all__ = [
53
- "HealthResponse",
54
- "HistoryEntry",
55
- "PyTorchCodeAnalyzerModel",
56
- "PythonAction",
57
- "PythonCodeReviewAction",
58
- "PythonCodeReviewObservation",
59
- "PythonCodeReviewState",
60
- "PythonObservation",
61
- "PythonState",
62
- "RewardDetails",
63
- "TaskDescriptor",
64
- "TaskGrade",
65
- "TaskSummary",
66
- ]
 
1
+ """PyTorch-backed model wrappers plus OpenEnv schema exports."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib.util
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ from .pytorch_model import PyTorchCodeAnalyzerModel
10
+
11
+
12
+ def _load_schema_module():
13
+ schema_path = Path(__file__).resolve().parent.parent / "models.py"
14
+ spec = importlib.util.spec_from_file_location("_python_env_schema_models", schema_path)
15
+ if spec is None or spec.loader is None: # pragma: no cover
16
+ raise ImportError(f"Unable to load schema models from {schema_path}")
17
+ if spec.name in sys.modules:
18
+ return sys.modules[spec.name]
19
+ module = importlib.util.module_from_spec(spec)
20
+ sys.modules[spec.name] = module
21
+ spec.loader.exec_module(module)
22
+ for model_name in (
23
+ "HistoryEntry",
24
+ "RewardDetails",
25
+ "PythonCodeReviewAction",
26
+ "PythonCodeReviewObservation",
27
+ "PythonCodeReviewState",
28
+ "TaskDescriptor",
29
+ "TaskSummary",
30
+ "TaskGrade",
31
+ "HealthResponse",
32
+ ):
33
+ getattr(module, model_name).model_rebuild()
34
+ return module
35
+
36
+
37
+ _schema_models = _load_schema_module()
38
+
39
+ HealthResponse = _schema_models.HealthResponse
40
+ HistoryEntry = _schema_models.HistoryEntry
41
+ PythonAction = _schema_models.PythonAction
42
+ PythonCodeReviewAction = _schema_models.PythonCodeReviewAction
43
+ PythonCodeReviewObservation = _schema_models.PythonCodeReviewObservation
44
+ PythonCodeReviewState = _schema_models.PythonCodeReviewState
45
+ PythonObservation = _schema_models.PythonObservation
46
+ PythonState = _schema_models.PythonState
47
+ RewardDetails = _schema_models.RewardDetails
48
+ TaskDescriptor = _schema_models.TaskDescriptor
49
+ TaskGrade = _schema_models.TaskGrade
50
+ TaskSummary = _schema_models.TaskSummary
51
+
52
+ __all__ = [
53
+ "HealthResponse",
54
+ "HistoryEntry",
55
+ "PyTorchCodeAnalyzerModel",
56
+ "PythonAction",
57
+ "PythonCodeReviewAction",
58
+ "PythonCodeReviewObservation",
59
+ "PythonCodeReviewState",
60
+ "PythonObservation",
61
+ "PythonState",
62
+ "RewardDetails",
63
+ "TaskDescriptor",
64
+ "TaskGrade",
65
+ "TaskSummary",
66
+ ]
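Reviewer note: `_load_schema_module` in `models/__init__.py` loads the sibling `models.py` file by path, because the `models` package and the `models.py` module would otherwise compete for the same import name. The pattern can be sketched in isolation as below. The helper `load_module_from_path`, the module name `_demo_schema_models`, and the temporary file are hypothetical stand-ins; the Pydantic `model_rebuild()` loop from the diff is omitted.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path
from types import ModuleType


def load_module_from_path(name: str, path: Path) -> ModuleType:
    """Load a Python file under an explicit module name, caching it in sys.modules."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load module from {path}")
    if spec.name in sys.modules:
        # Already loaded once; reuse it so class identities stay stable.
        return sys.modules[spec.name]
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the diff does, so lookups by name during
    # execution can find the module.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    schema_path = Path(tmp) / "schema_demo.py"
    schema_path.write_text("ANSWER = 42\n", encoding="utf-8")
    first = load_module_from_path("_demo_schema_models", schema_path)
    second = load_module_from_path("_demo_schema_models", schema_path)

print(first.ANSWER, first is second)
```

The cache check matters when the package is imported under more than one name: executing the file twice would create two distinct sets of model classes, and `isinstance` checks across them would fail.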
models/pytorch_model.py CHANGED
@@ -1,149 +1,149 @@
1
- """PyTorch + transformers model wrapper for multi-domain code scoring."""
2
-
3
- from __future__ import annotations
4
-
5
- import hashlib
6
- from typing import Dict, List, Sequence
7
-
8
- import torch
9
- import torch.nn.functional as F
10
-
11
- try:
12
- from transformers import AutoModel, AutoTokenizer
13
- except Exception:
14
- AutoModel = None # type: ignore[assignment]
15
- AutoTokenizer = None # type: ignore[assignment]
16
-
17
-
18
- DOMAIN_PROTOTYPES: Dict[str, List[str]] = {
19
- "dsa": [
20
- "Binary search, hashmap optimization, recursion, dynamic programming, arrays, trees, graphs, stack, queue, complexity.",
21
- "Competitive programming algorithm with loops, memoization, prefix sums, and asymptotic analysis.",
22
- ],
23
- "data_science": [
24
- "Pandas dataframe transformation, numpy vectorization, feature leakage, train test split, iterrows misuse.",
25
- "Data cleaning pipeline using pandas, numpy, aggregation, joins, and vectorized operations.",
26
- ],
27
- "ml_dl": [
28
- "PyTorch model, training loop, optimizer, backward pass, eval mode, no_grad, loss function, dataloader.",
29
- "Machine learning inference and training code with torch, sklearn, tensors, gradients, and model checkpoints.",
30
- ],
31
- "web": [
32
- "FastAPI endpoint, request validation, Pydantic models, async routes, API security, backend service design.",
33
- "REST API backend with routers, dependency injection, input validation, serialization, and error handling.",
34
- ],
35
- "general": [
36
- "General Python utility code with readable structure, typing, tests, and maintainable abstractions.",
37
- ],
38
- }
39
-
40
- QUALITY_ANCHORS: Dict[str, List[str]] = {
41
- "high": [
42
- "Readable typed Python code with validation, efficient algorithms, vectorized operations, safe inference, and clean API boundaries.",
43
- "Production-ready code with small functions, docstrings, low complexity, and clear error handling.",
44
- ],
45
- "low": [
46
- "Brute-force nested loops, missing validation, unsafe input handling, missing eval mode, missing no_grad, and code smells.",
47
- "Hard to maintain code with high complexity, repeated scans, mutable side effects, and unclear structure.",
48
- ],
49
- }
50
-
51
-
52
- class _HashEmbeddingBackend:
53
- """Torch-native fallback when pretrained weights cannot be loaded."""
54
-
55
- def __init__(self, dimensions: int = 128) -> None:
56
- self.dimensions = dimensions
57
- self.model_id = "hashed-token-fallback"
58
- self.backend_name = "hashed-token-fallback"
59
- self.notes = ["Using hashed embeddings because pretrained transformer weights are unavailable."]
60
-
61
- def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
62
- matrix = torch.zeros((len(texts), self.dimensions), dtype=torch.float32)
63
- for row_index, text in enumerate(texts):
64
- tokens = text.lower().split()[:512]
65
- if not tokens:
66
- matrix[row_index, 0] = 1.0
67
- continue
68
- for token in tokens:
69
- digest = hashlib.md5(token.encode("utf-8")).hexdigest()
70
- bucket = int(digest[:8], 16) % self.dimensions
71
- sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0
72
- matrix[row_index, bucket] += sign
73
- return F.normalize(matrix + 1e-6, dim=1)
74
-
75
-
76
- class PyTorchCodeAnalyzerModel:
77
- """Score code using pretrained transformer embeddings plus prototype similarity."""
78
-
79
- def __init__(self, model_id: str = "huggingface/CodeBERTa-small-v1") -> None:
80
- self.model_id = model_id
81
- self.backend_name = model_id
82
- self.notes: List[str] = []
83
- self._tokenizer = None
84
- self._model = None
85
- self._fallback = _HashEmbeddingBackend()
86
- self._prototype_cache: Dict[str, torch.Tensor] = {}
87
-
88
- def _ensure_loaded(self) -> None:
89
- if self._model is not None or self.notes:
90
- return
91
- if AutoTokenizer is None or AutoModel is None:
92
- self.backend_name = self._fallback.backend_name
93
- self.notes = list(self._fallback.notes)
94
- return
95
- try:
96
- self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
97
- self._model = AutoModel.from_pretrained(self.model_id)
98
- self._model.eval()
99
- self.notes.append(f"Loaded pretrained encoder `{self.model_id}`.")
100
- except Exception as exc:
101
- self.backend_name = self._fallback.backend_name
102
- self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {type(exc).__name__}: {exc}"]
103
-
104
- def _embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
105
- self._ensure_loaded()
106
- if self._model is None or self._tokenizer is None:
107
- return self._fallback.embed_texts(texts)
108
- encoded = self._tokenizer(list(texts), padding=True, truncation=True, max_length=256, return_tensors="pt")
109
- with torch.no_grad():
110
- outputs = self._model(**encoded)
111
- hidden = outputs.last_hidden_state
112
- mask = encoded["attention_mask"].unsqueeze(-1)
113
- pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
114
- return F.normalize(pooled, dim=1)
115
-
116
- def _prototype_matrix(self, bucket: str, texts: Sequence[str]) -> torch.Tensor:
117
- if bucket not in self._prototype_cache:
118
- self._prototype_cache[bucket] = self._embed_texts(texts)
119
- return self._prototype_cache[bucket]
120
-
121
- def predict(self, code: str, context_window: str, static_summary: Dict[str, object]) -> Dict[str, object]:
122
- """Predict domain probabilities and a model quality score."""
123
-
124
- document = (
125
- f"Code:\n{code.strip()[:4000]}\n\n"
126
- f"Context:\n{context_window.strip()[:1000]}\n\n"
127
- f"Static hints:\n{static_summary}\n"
128
- )
129
- candidate = self._embed_texts([document])
130
-
131
- domain_scores: Dict[str, float] = {}
132
- for domain, texts in DOMAIN_PROTOTYPES.items():
133
- matrix = self._prototype_matrix(f"domain:{domain}", texts)
134
- similarity = torch.matmul(candidate, matrix.T).max().item()
135
- domain_scores[domain] = round((similarity + 1.0) / 2.0, 4)
136
-
137
- high_matrix = self._prototype_matrix("quality:high", QUALITY_ANCHORS["high"])
138
- low_matrix = self._prototype_matrix("quality:low", QUALITY_ANCHORS["low"])
139
- high_similarity = torch.matmul(candidate, high_matrix.T).max().item()
140
- low_similarity = torch.matmul(candidate, low_matrix.T).max().item()
141
- ml_quality_score = torch.sigmoid(torch.tensor((high_similarity - low_similarity) * 4.0)).item()
142
-
143
- return {
144
- "domain_scores": domain_scores,
145
- "ml_quality_score": round(float(ml_quality_score), 4),
146
- "backend_name": self.backend_name,
147
- "model_id": self.model_id,
148
- "notes": list(self.notes),
149
- }
 
1
+ """PyTorch + transformers model wrapper for multi-domain code scoring."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ from typing import Dict, List, Sequence
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ try:
12
+ from transformers import AutoModel, AutoTokenizer
13
+ except Exception:
14
+ AutoModel = None # type: ignore[assignment]
15
+ AutoTokenizer = None # type: ignore[assignment]
16
+
17
+
18
+ DOMAIN_PROTOTYPES: Dict[str, List[str]] = {
19
+ "dsa": [
20
+ "Binary search, hashmap optimization, recursion, dynamic programming, arrays, trees, graphs, stack, queue, complexity.",
21
+ "Competitive programming algorithm with loops, memoization, prefix sums, and asymptotic analysis.",
22
+ ],
23
+ "data_science": [
24
+ "Pandas dataframe transformation, numpy vectorization, feature leakage, train test split, iterrows misuse.",
25
+ "Data cleaning pipeline using pandas, numpy, aggregation, joins, and vectorized operations.",
26
+ ],
27
+ "ml_dl": [
28
+ "PyTorch model, training loop, optimizer, backward pass, eval mode, no_grad, loss function, dataloader.",
29
+ "Machine learning inference and training code with torch, sklearn, tensors, gradients, and model checkpoints.",
30
+ ],
31
+ "web": [
32
+ "FastAPI endpoint, request validation, Pydantic models, async routes, API security, backend service design.",
33
+ "REST API backend with routers, dependency injection, input validation, serialization, and error handling.",
34
+ ],
35
+ "general": [
36
+ "General Python utility code with readable structure, typing, tests, and maintainable abstractions.",
37
+ ],
38
+ }
39
+
40
+ QUALITY_ANCHORS: Dict[str, List[str]] = {
41
+ "high": [
42
+ "Readable typed Python code with validation, efficient algorithms, vectorized operations, safe inference, and clean API boundaries.",
43
+ "Production-ready code with small functions, docstrings, low complexity, and clear error handling.",
44
+ ],
45
+ "low": [
46
+ "Brute-force nested loops, missing validation, unsafe input handling, missing eval mode, missing no_grad, and code smells.",
47
+ "Hard to maintain code with high complexity, repeated scans, mutable side effects, and unclear structure.",
48
+ ],
49
+ }
50
+
51
+
52
+ class _HashEmbeddingBackend:
53
+ """Torch-native fallback when pretrained weights cannot be loaded."""
54
+
55
+ def __init__(self, dimensions: int = 128) -> None:
56
+ self.dimensions = dimensions
57
+ self.model_id = "hashed-token-fallback"
58
+ self.backend_name = "hashed-token-fallback"
59
+ self.notes = ["Using hashed embeddings because pretrained transformer weights are unavailable."]
60
+
61
+ def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
62
+ matrix = torch.zeros((len(texts), self.dimensions), dtype=torch.float32)
63
+ for row_index, text in enumerate(texts):
64
+ tokens = text.lower().split()[:512]
65
+ if not tokens:
66
+ matrix[row_index, 0] = 1.0
67
+ continue
68
+ for token in tokens:
69
+ digest = hashlib.md5(token.encode("utf-8")).hexdigest()
70
+ bucket = int(digest[:8], 16) % self.dimensions
71
+ sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0
72
+ matrix[row_index, bucket] += sign
73
+ return F.normalize(matrix + 1e-6, dim=1)
74
+
75
+
76
+ class PyTorchCodeAnalyzerModel:
77
+ """Score code using pretrained transformer embeddings plus prototype similarity."""
78
+
79
+ def __init__(self, model_id: str = "huggingface/CodeBERTa-small-v1") -> None:
80
+ self.model_id = model_id
81
+ self.backend_name = model_id
82
+ self.notes: List[str] = []
83
+ self._tokenizer = None
84
+ self._model = None
85
+ self._fallback = _HashEmbeddingBackend()
86
+ self._prototype_cache: Dict[str, torch.Tensor] = {}
87
+
88
+ def _ensure_loaded(self) -> None:
89
+ if self._model is not None or self.notes:
90
+ return
91
+ if AutoTokenizer is None or AutoModel is None:
92
+ self.backend_name = self._fallback.backend_name
93
+ self.notes = list(self._fallback.notes)
94
+ return
95
+ try:
96
+ self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
97
+ self._model = AutoModel.from_pretrained(self.model_id)
98
+ self._model.eval()
99
+ self.notes.append(f"Loaded pretrained encoder `{self.model_id}`.")
100
+ except Exception as exc:
101
+ self.backend_name = self._fallback.backend_name
102
+ self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {type(exc).__name__}: {exc}"]
103
+
104
+ def _embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
105
+ self._ensure_loaded()
106
+ if self._model is None or self._tokenizer is None:
107
+ return self._fallback.embed_texts(texts)
108
+ encoded = self._tokenizer(list(texts), padding=True, truncation=True, max_length=256, return_tensors="pt")
109
+ with torch.no_grad():
110
+ outputs = self._model(**encoded)
111
+ hidden = outputs.last_hidden_state
112
+ mask = encoded["attention_mask"].unsqueeze(-1)
113
+ pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
114
+ return F.normalize(pooled, dim=1)
115
+
116
+ def _prototype_matrix(self, bucket: str, texts: Sequence[str]) -> torch.Tensor:
117
+ if bucket not in self._prototype_cache:
118
+ self._prototype_cache[bucket] = self._embed_texts(texts)
119
+ return self._prototype_cache[bucket]
120
+
121
+ def predict(self, code: str, context_window: str, static_summary: Dict[str, object]) -> Dict[str, object]:
122
+ """Predict domain probabilities and a model quality score."""
123
+
124
+ document = (
125
+ f"Code:\n{code.strip()[:4000]}\n\n"
126
+ f"Context:\n{context_window.strip()[:1000]}\n\n"
127
+ f"Static hints:\n{static_summary}\n"
128
+ )
129
+ candidate = self._embed_texts([document])
130
+
131
+ domain_scores: Dict[str, float] = {}
132
+ for domain, texts in DOMAIN_PROTOTYPES.items():
133
+ matrix = self._prototype_matrix(f"domain:{domain}", texts)
134
+ similarity = torch.matmul(candidate, matrix.T).max().item()
135
+ domain_scores[domain] = round((similarity + 1.0) / 2.0, 4)
136
+
137
+ high_matrix = self._prototype_matrix("quality:high", QUALITY_ANCHORS["high"])
138
+ low_matrix = self._prototype_matrix("quality:low", QUALITY_ANCHORS["low"])
139
+ high_similarity = torch.matmul(candidate, high_matrix.T).max().item()
140
+ low_similarity = torch.matmul(candidate, low_matrix.T).max().item()
141
+ ml_quality_score = torch.sigmoid(torch.tensor((high_similarity - low_similarity) * 4.0)).item()
142
+
143
+ return {
144
+ "domain_scores": domain_scores,
145
+ "ml_quality_score": round(float(ml_quality_score), 4),
146
+ "backend_name": self.backend_name,
147
+ "model_id": self.model_id,
148
+ "notes": list(self.notes),
149
+ }
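Review note: the hashed fallback and the quality score in this file can be reproduced without `torch`, which helps when checking the scoring logic by hand. The block below is a minimal sketch under stated assumptions, not the committed implementation: `hash_embed`, `cosine`, and `quality_score` are hypothetical names. It mirrors the signed MD5 bucketing in `_HashEmbeddingBackend.embed_texts` and the `sigmoid((high - low) * 4.0)` step in `predict`, using plain Python lists in place of tensors.

```python
import hashlib
import math
from typing import List, Sequence


def hash_embed(text: str, dimensions: int = 128) -> List[float]:
    """Signed feature hashing of whitespace tokens, then L2 normalisation."""
    vector = [0.0] * dimensions
    tokens = text.lower().split()[:512]
    if not tokens:
        # Same convention as the fallback backend: empty input maps to axis 0.
        vector[0] = 1.0
    for token in tokens:
        digest = hashlib.md5(token.encode("utf-8")).hexdigest()
        bucket = int(digest[:8], 16) % dimensions
        sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0
        vector[bucket] += sign
    # The small shift keeps the norm non-zero even if all signs cancel out.
    shifted = [value + 1e-6 for value in vector]
    norm = math.sqrt(sum(value * value for value in shifted))
    return [value / norm for value in shifted]


def cosine(left: Sequence[float], right: Sequence[float]) -> float:
    """Dot product; equals cosine similarity for unit-length inputs."""
    return sum(a * b for a, b in zip(left, right))


def quality_score(
    candidate: Sequence[float],
    high_anchors: Sequence[Sequence[float]],
    low_anchors: Sequence[Sequence[float]],
    sharpness: float = 4.0,
) -> float:
    """Sigmoid of the gap between the best high and best low anchor match."""
    high = max(cosine(candidate, anchor) for anchor in high_anchors)
    low = max(cosine(candidate, anchor) for anchor in low_anchors)
    return 1.0 / (1.0 + math.exp(-(high - low) * sharpness))
```

Because the score depends only on the gap between the two best matches, a candidate that is equally close to both anchor sets lands at exactly 0.5, and the `4.0` factor only controls how quickly the score saturates as the gap grows.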
pyproject.toml CHANGED
@@ -1,16 +1,18 @@
- [build-system]
- requires = ["setuptools>=68", "wheel"]
- build-backend = "setuptools.build_meta"
-
+ [build-system]
+ requires = ["setuptools>=68", "wheel"]
+ build-backend = "setuptools.build_meta"
+
  [project]
  name = "openenv-python-code-review-env"
  version = "1.0.0"
  description = "TorchReview Copilot: AI-powered Python code triage with PyTorch and OpenEnv validation."
  readme = "README.md"
  requires-python = ">=3.10"
+
  dependencies = [
      "fastapi>=0.111.0",
      "gradio>=5.26.0",
+     "hf-xet>=1.4.3",
      "openai>=1.76.0",
      "openenv-core[core]>=0.2.2",
      "streamlit>=1.44.0",
@@ -24,28 +26,13 @@ dev = [
      "pytest>=8.0.0",
      "pytest-cov>=4.0.0",
  ]
-
- [project.scripts]
- server = "python_env.server.app:main"
-
- [tool.setuptools]
- include-package-data = true
- packages = [
-     "python_env",
-     "python_env.server",
-     "python_env.tasks",
-     "python_env.graders",
-     "python_env.api",
-     "python_env.app",
-     "python_env.app.agents",
-     "python_env.app.env",
-     "python_env.app.models",
-     "python_env.app.services",
-     "python_env.app.utils",
-     "python_env.analyzers",
-     "python_env.models",
-     "python_env.schemas",
-     "python_env.services",
-     "python_env.utils",
- ]
- package-dir = { "python_env" = ".", "python_env.server" = "server", "python_env.tasks" = "tasks", "python_env.graders" = "graders", "python_env.api" = "api", "python_env.app" = "app", "python_env.app.agents" = "app/agents", "python_env.app.env" = "app/env", "python_env.app.models" = "app/models", "python_env.app.services" = "app/services", "python_env.app.utils" = "app/utils", "python_env.analyzers" = "analyzers", "python_env.models" = "models", "python_env.schemas" = "schemas", "python_env.services" = "services", "python_env.utils" = "utils" }
+
+ [project.scripts]
+ server = "python_env.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+ include = ["*"]
schemas/__init__.py CHANGED
@@ -1,13 +1,13 @@
- """Public schemas for the multi-domain analysis platform."""
-
- from .request import AnalyzeCodeRequest
- from .response import AnalyzeCodeResponse, AnalysisIssue, DomainAnalysis, ScoreBreakdown, StaticAnalysisSummary
-
- __all__ = [
-     "AnalyzeCodeRequest",
-     "AnalyzeCodeResponse",
-     "AnalysisIssue",
-     "DomainAnalysis",
-     "ScoreBreakdown",
-     "StaticAnalysisSummary",
- ]
+ """Public schemas for the multi-domain analysis platform."""
+
+ from .request import AnalyzeCodeRequest
+ from .response import AnalyzeCodeResponse, AnalysisIssue, DomainAnalysis, ScoreBreakdown, StaticAnalysisSummary
+
+ __all__ = [
+     "AnalyzeCodeRequest",
+     "AnalyzeCodeResponse",
+     "AnalysisIssue",
+     "DomainAnalysis",
+     "ScoreBreakdown",
+     "StaticAnalysisSummary",
+ ]
schemas/request.py CHANGED
@@ -1,19 +1,19 @@
- """Request schemas for code analysis endpoints and UI."""
-
- from __future__ import annotations
-
- from typing import Literal
-
- from pydantic import BaseModel, Field
-
-
- DomainHint = Literal["auto", "dsa", "data_science", "ml_dl", "web"]
-
-
- class AnalyzeCodeRequest(BaseModel):
-     """Validated input payload for multi-domain code analysis."""
-
-     code: str = Field(..., min_length=1, description="Source code to analyze.")
-     context_window: str = Field(default="", max_length=2000, description="Optional repository or task context.")
-     traceback_text: str = Field(default="", max_length=2000, description="Optional runtime or test failure output.")
-     domain_hint: DomainHint = Field(default="auto", description="Optional domain override when auto detection is not desired.")
+ """Request schemas for code analysis endpoints and UI."""
+
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from pydantic import BaseModel, Field
+
+
+ DomainHint = Literal["auto", "dsa", "data_science", "ml_dl", "web"]
+
+
+ class AnalyzeCodeRequest(BaseModel):
+     """Validated input payload for multi-domain code analysis."""
+
+     code: str = Field(..., min_length=1, description="Source code to analyze.")
+     context_window: str = Field(default="", max_length=2000, description="Optional repository or task context.")
+     traceback_text: str = Field(default="", max_length=2000, description="Optional runtime or test failure output.")
+     domain_hint: DomainHint = Field(default="auto", description="Optional domain override when auto detection is not desired.")
schemas/response.py CHANGED
@@ -1,73 +1,73 @@
1
- """Response schemas for the multi-domain analysis platform."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Dict, List, Literal
6
-
7
- from pydantic import BaseModel, Field
8
-
9
-
10
- DomainType = Literal["dsa", "data_science", "ml_dl", "web", "general"]
11
- Severity = Literal["low", "medium", "high"]
12
-
13
-
14
- class AnalysisIssue(BaseModel):
15
- """One detected issue or risk in the code snippet."""
16
-
17
- title: str
18
- severity: Severity
19
- description: str
20
- line_hint: int | None = None
21
-
22
-
23
- class StaticAnalysisSummary(BaseModel):
24
- """Language-agnostic static-analysis signals."""
25
-
26
- syntax_valid: bool
27
- syntax_error: str = ""
28
- cyclomatic_complexity: int = Field(..., ge=1)
29
- line_count: int = Field(..., ge=0)
30
- max_loop_depth: int = Field(..., ge=0)
31
- time_complexity: str = "Unknown"
32
- space_complexity: str = "Unknown"
33
- detected_imports: List[str] = Field(default_factory=list)
34
- code_smells: List[str] = Field(default_factory=list)
35
-
36
-
37
- class DomainAnalysis(BaseModel):
38
- """Domain-specific analysis payload returned by an analyzer."""
39
-
40
- domain: DomainType
41
- domain_score: float = Field(..., ge=0.0, le=1.0)
42
- issues: List[AnalysisIssue] = Field(default_factory=list)
43
- suggestions: List[str] = Field(default_factory=list)
44
- highlights: Dict[str, float | str] = Field(default_factory=dict)
45
-
46
-
47
- class ScoreBreakdown(BaseModel):
48
- """Reward inputs and final normalized score."""
49
-
50
- ml_score: float = Field(..., ge=0.0, le=1.0)
51
- domain_score: float = Field(..., ge=0.0, le=1.0)
52
- lint_score: float = Field(..., ge=0.0, le=1.0)
53
- complexity_penalty: float = Field(..., ge=0.0, le=1.0)
54
- quality_signal: float = Field(..., ge=0.0, le=1.0)
55
- error_reduction_signal: float = Field(..., ge=0.0, le=1.0)
56
- completion_signal: float = Field(..., ge=0.0, le=1.0)
57
- reward: float = Field(..., ge=0.0, le=1.0)
58
-
59
-
60
- class AnalyzeCodeResponse(BaseModel):
61
- """Top-level structured output for API and UI consumers."""
62
-
63
- detected_domain: DomainType
64
- domain_confidences: Dict[str, float]
65
- score_breakdown: ScoreBreakdown
66
- static_analysis: StaticAnalysisSummary
67
- domain_analysis: DomainAnalysis
68
- improvement_plan: List[str] = Field(default_factory=list)
69
- model_backend: str
70
- model_id: str
71
- summary: str
72
- context_window: str = ""
73
- analysis_time_ms: float = Field(..., ge=0.0)
 
1
+ """Response schemas for the multi-domain analysis platform."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Literal
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
+ DomainType = Literal["dsa", "data_science", "ml_dl", "web", "general"]
11
+ Severity = Literal["low", "medium", "high"]
12
+
13
+
14
+ class AnalysisIssue(BaseModel):
15
+ """One detected issue or risk in the code snippet."""
16
+
17
+ title: str
18
+ severity: Severity
19
+ description: str
20
+ line_hint: int | None = None
21
+
22
+
23
+ class StaticAnalysisSummary(BaseModel):
24
+ """Language-agnostic static-analysis signals."""
25
+
26
+ syntax_valid: bool
27
+ syntax_error: str = ""
28
+ cyclomatic_complexity: int = Field(..., ge=1)
29
+ line_count: int = Field(..., ge=0)
30
+ max_loop_depth: int = Field(..., ge=0)
31
+ time_complexity: str = "Unknown"
32
+ space_complexity: str = "Unknown"
33
+ detected_imports: List[str] = Field(default_factory=list)
34
+ code_smells: List[str] = Field(default_factory=list)
35
+
36
+
37
+ class DomainAnalysis(BaseModel):
38
+ """Domain-specific analysis payload returned by an analyzer."""
39
+
40
+ domain: DomainType
41
+ domain_score: float = Field(..., ge=0.0, le=1.0)
42
+ issues: List[AnalysisIssue] = Field(default_factory=list)
43
+ suggestions: List[str] = Field(default_factory=list)
44
+ highlights: Dict[str, float | str] = Field(default_factory=dict)
45
+
46
+
47
+ class ScoreBreakdown(BaseModel):
48
+ """Reward inputs and final normalized score."""
49
+
50
+ ml_score: float = Field(..., ge=0.0, le=1.0)
51
+ domain_score: float = Field(..., ge=0.0, le=1.0)
52
+ lint_score: float = Field(..., ge=0.0, le=1.0)
53
+ complexity_penalty: float = Field(..., ge=0.0, le=1.0)
54
+ quality_signal: float = Field(..., ge=0.0, le=1.0)
55
+ error_reduction_signal: float = Field(..., ge=0.0, le=1.0)
56
+ completion_signal: float = Field(..., ge=0.0, le=1.0)
57
+ reward: float = Field(..., ge=0.0, le=1.0)
58
+
59
+
60
+ class AnalyzeCodeResponse(BaseModel):
61
+ """Top-level structured output for API and UI consumers."""
62
+
63
+ detected_domain: DomainType
64
+ domain_confidences: Dict[str, float]
65
+ score_breakdown: ScoreBreakdown
66
+ static_analysis: StaticAnalysisSummary
67
+ domain_analysis: DomainAnalysis
68
+ improvement_plan: List[str] = Field(default_factory=list)
69
+ model_backend: str
70
+ model_id: str
71
+ summary: str
72
+ context_window: str = ""
73
+ analysis_time_ms: float = Field(..., ge=0.0)
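Review note: `ScoreBreakdown.reward` is constrained to `[0.0, 1.0]`, and the formula string rendered by `server/demo.py` reads `reward = (0.5 x ml) + (0.3 x lint) - (0.2 x complexity)`. The sketch below shows how those two facts could fit together. It is an illustration under assumptions, not the project's scoring code: `combine_reward` is a hypothetical helper, and the clamp to the unit interval is assumed here only because the schema would otherwise reject a negative value. The other `ScoreBreakdown` signals (`quality_signal`, `error_reduction_signal`, `completion_signal`) are not modelled.

```python
def combine_reward(ml_score: float, lint_score: float, complexity_penalty: float) -> float:
    """Weighted sum from the demo's formula string, clamped to [0, 1] (assumed)."""
    raw = (0.5 * ml_score) + (0.3 * lint_score) - (0.2 * complexity_penalty)
    # Assumption: clamp so the result always satisfies Field(ge=0.0, le=1.0).
    return min(max(raw, 0.0), 1.0)


if __name__ == "__main__":
    # A clean snippet with a strong model score and no complexity penalty.
    print(round(combine_reward(ml_score=0.9, lint_score=1.0, complexity_penalty=0.0), 4))
```

With these weights the unclamped maximum is 0.8, so a perfect input never reaches 1.0 unless other signals contribute.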
server/Dockerfile CHANGED
@@ -1,27 +1,27 @@
- FROM python:3.11-slim
-
- ENV PYTHONDONTWRITEBYTECODE=1 \
-     PYTHONUNBUFFERED=1 \
-     PYTHONUTF8=1 \
-     PYTHONIOENCODING=utf-8 \
-     PIP_NO_CACHE_DIR=1 \
-     PIP_DISABLE_PIP_VERSION_CHECK=1 \
-     ENABLE_GRADIO_DEMO=false
-
- WORKDIR /app
-
- COPY server/requirements.txt /tmp/requirements.txt
-
- RUN python -m pip install --upgrade pip && \
-     pip install -r /tmp/requirements.txt
-
- COPY . /app
-
- RUN pip install --no-deps .
-
- EXPOSE 8000
-
- HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
-     CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=3).read()"
-
- CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PYTHONUTF8=1 \
+     PYTHONIOENCODING=utf-8 \
+     PIP_NO_CACHE_DIR=1 \
+     PIP_DISABLE_PIP_VERSION_CHECK=1 \
+     ENABLE_GRADIO_DEMO=false
+
+ WORKDIR /app
+
+ COPY server/requirements.txt /tmp/requirements.txt
+
+ RUN python -m pip install --upgrade pip && \
+     pip install -r /tmp/requirements.txt
+
+ COPY . /app
+
+ RUN pip install --no-deps .
+
+ EXPOSE 8000
+
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=3).read()"
+
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--no-access-log"]
server/app.py CHANGED
@@ -1,80 +1,86 @@
- """OpenEnv FastAPI entrypoint with optional Gradio mounting."""
-
- from __future__ import annotations
-
- import os
-
- from fastapi import FastAPI
-
- try:
-     from openenv.core.env_server.http_server import create_app
- except Exception as exc:  # pragma: no cover
-     raise ImportError(
-         "openenv-core is required to run the API server. Install project dependencies first."
-     ) from exc
-
- try:
-     import gradio as gr
- except Exception:
-     gr = None  # type: ignore[assignment]
-
+ """OpenEnv FastAPI entrypoint with optional Gradio mounting."""
+
+ from __future__ import annotations
+
+ import os
+
+ from fastapi import FastAPI
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as exc:  # pragma: no cover
+     raise ImportError(
+         "openenv-core is required to run the API server. Install project dependencies first."
+     ) from exc
+
+ try:
+     import gradio as gr
+ except Exception:
+     gr = None  # type: ignore[assignment]
+
  try:
      from ..models import PythonCodeReviewAction, PythonCodeReviewObservation
      from .env import PythonCodeReviewEnvironment
  except ImportError:
      from models import PythonCodeReviewAction, PythonCodeReviewObservation
      from server.env import PythonCodeReviewEnvironment
-
-
- def _gradio_enabled() -> bool:
-     for env_name in ("ENABLE_GRADIO_DEMO", "ENABLE_WEB_INTERFACE"):
-         if str(os.getenv(env_name, "")).strip().lower() in {"1", "true", "yes", "on"}:
-             return True
-     return False
-
-
- def _max_concurrent_envs() -> int:
-     try:
-         return max(int(os.getenv("OPENENV_MAX_CONCURRENT_ENVS", "2")), 1)
-     except Exception:
-         return 2
-
-
+
+
+ def _gradio_enabled() -> bool:
+     for env_name in ("ENABLE_GRADIO_DEMO", "ENABLE_WEB_INTERFACE"):
+         if str(os.getenv(env_name, "")).strip().lower() in {"1", "true", "yes", "on"}:
+             return True
+     return False
+
+
+ def _max_concurrent_envs() -> int:
+     try:
+         return max(int(os.getenv("OPENENV_MAX_CONCURRENT_ENVS", "2")), 1)
+     except Exception:
+         return 2
+
+
  def build_application():
-     """Compose the OpenEnv API with the Gradio demo frontend."""
-
-     api_app = create_app(
-         PythonCodeReviewEnvironment,
-         PythonCodeReviewAction,
-         PythonCodeReviewObservation,
-         env_name="python_code_review_env",
-         max_concurrent_envs=_max_concurrent_envs(),
-     )
+     """Compose the OpenEnv API with the Gradio demo frontend."""
+
+     api_app = create_app(
+         PythonCodeReviewEnvironment,
+         PythonCodeReviewAction,
+         PythonCodeReviewObservation,
+         env_name="python_code_review_env",
+         max_concurrent_envs=_max_concurrent_envs(),
+     )
      served_app = api_app
      if gr is not None and _gradio_enabled():
          try:
-             from .demo import build_demo
+             from .demo import CSS, build_demo
          except ImportError:
-             from server.demo import build_demo
-         served_app = gr.mount_gradio_app(api_app, build_demo(), path="/")
-
-     wrapper_app = FastAPI(title="python_code_review_env", version="1.0.0")
-
-     @wrapper_app.get("/health", include_in_schema=False)
-     def _health() -> dict[str, str]:
-         return {"status": "ok"}
-
-     wrapper_app.mount("/", served_app)
-     return wrapper_app
-
-
- app = build_application()
+             from server.demo import CSS, build_demo
+         served_app = gr.mount_gradio_app(
+             api_app,
+             build_demo(),
+             path="/",
+             theme=gr.themes.Soft(primary_hue="orange", secondary_hue="amber"),
+             css=CSS,
+         )
+
+     wrapper_app = FastAPI(title="python_code_review_env", version="1.0.0")
  
+     @wrapper_app.get("/health", include_in_schema=False)
+     def _health() -> dict[str, str]:
+         return {"status": "ok"}
  
- def main(host: str = "0.0.0.0", port: int = 8000) -> None:
-     import uvicorn
+     wrapper_app.mount("/", served_app)
+     return wrapper_app
  
-     uvicorn.run(app, host=host, port=port)
+
+ app = build_application()
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000) -> None:
+     import uvicorn
+
+     uvicorn.run(app, host=host, port=port, access_log=False)
  
  
  if __name__ == "__main__":
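Review note: `_gradio_enabled` and `_max_concurrent_envs` both read configuration from environment variables with a forgiving fallback. The block below is a minimal sketch of that pattern with the environment passed in explicitly so it can be exercised without touching `os.environ`. The helper names `flag_enabled` and `bounded_int` are hypothetical and do not exist in `server/app.py`; the truthy set and the floor of 1 are taken from the diff above.

```python
import os
from typing import Iterable, Mapping

TRUTHY = {"1", "true", "yes", "on"}


def flag_enabled(names: Iterable[str], env: Mapping[str, str] = os.environ) -> bool:
    """Return True if any of the named variables holds a truthy string."""
    return any(str(env.get(name, "")).strip().lower() in TRUTHY for name in names)


def bounded_int(
    name: str,
    default: int,
    minimum: int = 1,
    env: Mapping[str, str] = os.environ,
) -> int:
    """Parse an integer variable, floor it at `minimum`, fall back on bad input."""
    try:
        return max(int(env.get(name, str(default))), minimum)
    except ValueError:
        # Non-numeric values fall back to the default instead of crashing startup.
        return default


if __name__ == "__main__":
    print(flag_enabled(("ENABLE_GRADIO_DEMO", "ENABLE_WEB_INTERFACE")))
    print(bounded_int("OPENENV_MAX_CONCURRENT_ENVS", default=2))
```

Passing a plain dict as `env` keeps the parsing rules testable in isolation, which is the main reason for the extra parameter in this sketch.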
server/demo.py CHANGED
@@ -1,441 +1,441 @@
1
- """Gradio UI for TorchReview Copilot."""
2
-
3
- from __future__ import annotations
4
-
5
- from html import escape
6
-
7
- import gradio as gr
8
-
9
- try:
10
- from ..triage import get_default_engine
11
- except ImportError:
12
- from triage import get_default_engine
13
-
14
-
15
- CSS = """
16
- :root {
17
- --paper: #f6f1e8;
18
- --ink: #162521;
19
- --accent: #d95d39;
20
- --panel: #fffdf8;
21
- --border: #d6c4b8;
22
- --muted: #5f6f67;
23
- --good: #2d7d62;
24
- --warn: #b76516;
25
- --high: #b23a48;
26
- }
27
-
28
- body, .gradio-container {
29
- background:
30
- radial-gradient(circle at top left, rgba(247, 197, 159, 0.35), transparent 35%),
31
- linear-gradient(135deg, #f9f6ef 0%, #efe5d3 100%);
32
- color: var(--ink);
33
- font-family: Georgia, "Times New Roman", serif;
34
- }
35
-
36
- .gradio-container {
37
- max-width: 1260px !important;
38
- }
39
-
40
- .hero-card,
41
- .metric-card,
42
- .subtle-card {
43
- background: rgba(255, 253, 248, 0.95);
44
- border: 1px solid var(--border);
45
- border-radius: 20px;
46
- box-shadow: 0 16px 40px rgba(22, 37, 33, 0.08);
47
- }
48
-
49
- .hero-card {
50
- padding: 28px 30px;
51
- margin-bottom: 12px;
52
- }
53
-
54
- .metric-card,
55
- .subtle-card {
56
- padding: 20px 22px;
57
- }
58
-
59
- .eyebrow {
60
- text-transform: uppercase;
61
- letter-spacing: 0.12em;
62
- font-size: 12px;
63
- color: var(--accent);
64
- margin-bottom: 10px;
65
- }
66
-
67
- .hero-title {
68
- font-size: 44px;
69
- line-height: 1.05;
70
- margin: 0 0 10px;
71
- }
72
-
73
- .hero-copy {
74
- margin: 0;
75
- font-size: 18px;
76
- line-height: 1.55;
77
- color: var(--muted);
78
- }
79
-
80
- .summary-title {
81
- display: flex;
82
- justify-content: space-between;
83
- gap: 12px;
84
- align-items: center;
85
- margin-bottom: 14px;
86
- }
87
-
88
- .pill {
89
- display: inline-block;
90
- padding: 6px 12px;
91
- border-radius: 999px;
92
- font-size: 12px;
93
- text-transform: uppercase;
94
- letter-spacing: 0.08em;
95
- background: #efe5d3;
96
- }
97
-
98
- .pill.low { color: var(--good); }
99
- .pill.medium { color: var(--warn); }
100
- .pill.high { color: var(--high); }
101
-
102
- .summary-grid {
103
- display: grid;
104
- grid-template-columns: repeat(2, minmax(0, 1fr));
105
- gap: 12px;
106
- margin-top: 16px;
107
- }
108
-
109
- .summary-stat {
110
- background: #fff7ef;
111
- border-radius: 14px;
112
- padding: 12px 14px;
113
- border: 1px solid rgba(214, 196, 184, 0.8);
114
- }
115
-
116
- .summary-stat strong {
117
- display: block;
118
- font-size: 12px;
119
- text-transform: uppercase;
120
- letter-spacing: 0.08em;
121
- color: var(--muted);
122
- margin-bottom: 6px;
123
- }
124
-
125
- .radar-wrap {
126
- display: grid;
127
- gap: 12px;
128
- }
129
-
130
- .bar {
131
- display: grid;
132
- gap: 6px;
133
- }
134
-
135
- .bar-head {
136
- display: flex;
137
- justify-content: space-between;
138
- font-size: 13px;
139
- color: var(--muted);
140
- }
141
-
142
- .bar-track {
143
- width: 100%;
144
- height: 12px;
145
- background: #f2e5d6;
146
- border-radius: 999px;
147
- overflow: hidden;
148
- }
149
-
150
- .bar-fill {
151
- height: 100%;
152
- border-radius: 999px;
153
- }
154
-
155
- .matched-box {
156
- background: #fff7ef;
157
- border: 1px solid rgba(214, 196, 184, 0.8);
158
- border-radius: 16px;
159
- padding: 14px;
160
- }
161
-
162
- .how-grid {
163
- display: grid;
164
- grid-template-columns: repeat(4, minmax(0, 1fr));
165
- gap: 12px;
166
- }
167
-
168
- .how-step {
169
- background: rgba(255, 253, 248, 0.9);
170
- border: 1px solid var(--border);
171
- border-radius: 18px;
172
- padding: 16px;
173
- }
174
-
175
- @media (max-width: 900px) {
176
- .hero-title {
177
- font-size: 34px;
178
- }
179
-
180
- .summary-grid,
181
- .how-grid {
182
- grid-template-columns: 1fr;
183
- }
184
- }
185
- """
186
-
187
-
188
- def _default_outputs() -> tuple[str, str, str, str, str]:
189
- return (
190
- "<div class='metric-card'><div class='eyebrow'>Awaiting Analysis</div><p class='hero-copy'>Paste Python code, add an optional traceback, or load one of the built-in examples.</p></div>",
191
- "<div class='metric-card'><div class='eyebrow'>Live Triage Radar</div><p class='hero-copy'>Confidence bars will appear after the first analysis run.</p></div>",
192
- "### Improvement Plan\nAnalyze a sample to generate syntax, edge-case, and scalability recommendations.",
193
- "### Known Pattern Match\nThe nearest OpenEnv task will be highlighted here after inference runs.",
194
- "### Model Notes\nBackend and extracted signal details will appear here.",
195
- )
196
-
197
-
198
- def _summary_html(result) -> str:
199
- issue = escape(result.issue_label.title())
200
- summary = escape(result.summary)
201
- next_action = escape(result.suggested_next_action)
202
- return f"""
203
- <div class="metric-card">
204
- <div class="summary-title">
205
- <div>
206
- <div class="eyebrow">TorchReview Verdict</div>
207
- <h3 style="margin:0;font-size:30px;">{issue} Issue</h3>
208
- </div>
209
- <span class="pill {escape(result.repair_risk)}">{escape(result.repair_risk)} repair risk</span>
210
- </div>
211
- <p class="hero-copy">{summary}</p>
212
- <div class="summary-grid">
213
- <div class="summary-stat">
214
- <strong>Reward Score</strong>
215
- {result.reward_score:.0%}
216
- </div>
217
- <div class="summary-stat">
218
- <strong>ML Quality</strong>
219
- {result.ml_quality_score:.0%}
220
- </div>
221
- <div class="summary-stat">
222
- <strong>Matched Pattern</strong>
223
- {escape(result.matched_pattern.title)}
224
- </div>
225
- <div class="summary-stat">
226
- <strong>Inference Backend</strong>
227
- {escape(result.model_backend)}
228
- </div>
229
- <div class="summary-stat">
230
- <strong>Lint Score</strong>
231
- {result.lint_score:.0%}
232
- </div>
233
- <div class="summary-stat">
234
- <strong>Complexity Penalty</strong>
235
- {result.complexity_penalty:.0%}
236
- </div>
237
- <div class="summary-stat">
238
- <strong>Next Action</strong>
239
- {next_action}
240
- </div>
241
- </div>
242
- </div>
243
- """
244
-
245
-
246
- def _radar_html(result) -> str:
247
- colors = {
248
- "syntax": "#d95d39",
249
- "logic": "#4f772d",
250
- "performance": "#355070",
251
- }
252
- bars = []
253
- for label, score in result.confidence_scores.items():
254
- bars.append(
255
- f"""
256
- <div class="bar">
257
- <div class="bar-head"><span>{escape(label.title())}</span><span>{score:.0%}</span></div>
258
- <div class="bar-track">
259
- <div class="bar-fill" style="width:{score * 100:.1f}%; background:{colors.get(label, '#d95d39')};"></div>
260
- </div>
261
- </div>
262
- """
263
- )
264
- return f"""
265
- <div class="metric-card radar-wrap">
266
- <div class="eyebrow">Live Triage Radar</div>
267
- {''.join(bars)}
268
- <div class="matched-box">
269
- <strong>Nearest Known Pattern:</strong> {escape(result.matched_pattern.title)}<br>
270
- <span style="color:#5f6f67;">{escape(result.matched_pattern.summary)}</span>
271
- </div>
272
- </div>
273
- """
274
-
275
-
276
- def _plan_markdown(result) -> str:
277
- plan_lines = "\n".join(f"{index + 1}. {step}" for index, step in enumerate(result.repair_plan))
278
- return (
279
- "### Improvement Plan\n"
280
- f"**Primary issue:** `{result.issue_label}`\n\n"
281
- f"{plan_lines}\n\n"
282
- f"**Suggested next action:** {result.suggested_next_action}"
283
- )
284
-
285
-
286
- def _match_markdown(result) -> str:
287
- return (
288
- "### Known Pattern Match\n"
289
- f"**Task:** `{result.matched_pattern.task_id}` \n"
290
- f"**Title:** {result.matched_pattern.title} \n"
291
- f"**Why it matched:** {result.matched_pattern.rationale} \n"
292
- f"**Similarity:** {result.matched_pattern.similarity:.0%}"
293
- )
294
-
295
-
296
- def _model_markdown(result) -> str:
297
- signal_lines = "\n".join(
298
- f"- `{signal.name}` -> {signal.value} ({signal.impact}, weight {signal.weight:.2f}): {signal.evidence}"
299
- for signal in result.extracted_signals
300
- ) or "- No strong static signals were extracted."
301
- notes = "\n".join(f"- {item}" for item in result.inference_notes) or "- No additional backend notes."
302
- return (
303
- "### Model Notes\n"
304
- f"- **Model backend:** `{result.model_backend}`\n"
305
- f"- **Model id:** `{result.model_id}`\n"
306
- f"- **Analysis time:** `{result.analysis_time_ms:.2f} ms`\n\n"
307
- "### Reward Formula\n"
308
- f"- `reward = (0.5 x {result.ml_quality_score:.2f}) + (0.3 x {result.lint_score:.2f}) - (0.2 x {result.complexity_penalty:.2f})`\n"
309
- f"- **Final reward:** `{result.reward_score:.2f}`\n\n"
310
- "### Extracted Signals\n"
311
- f"{signal_lines}\n\n"
312
- "### Backend Notes\n"
313
- f"{notes}"
314
- )
315
-
316
-
317
- def analyze_inputs(code: str, traceback_text: str, context_window: str) -> tuple[str, str, str, str, str]:
318
- """Run the triage engine and format outputs for the Gradio UI."""
319
-
320
- result = get_default_engine().triage(code or "", traceback_text or "", context_window or "")
321
- return (
322
- _summary_html(result),
323
- _radar_html(result),
324
- _plan_markdown(result),
325
- _match_markdown(result),
326
- _model_markdown(result),
327
- )
328
-
329
-
330
- def load_example(example_key: str) -> tuple[str, str, str, str, str, str, str, str, str]:
331
- """Populate the UI from a built-in example and immediately analyze it."""
332
-
333
- example = get_default_engine().example_map()[example_key]
334
- outputs = analyze_inputs(example.code, example.traceback_text, example.context_window)
335
- header = (
336
- f"### Example Scenario\n"
337
- f"**{example.title}** \n"
338
- f"{example.summary} \n"
339
- f"Label target: `{example.label}`"
340
- )
341
- return (example.code, example.traceback_text, example.context_window, header, *outputs)
342
-
343
-
344
- def build_demo() -> gr.Blocks:
345
- """Create the TorchReview Copilot Gradio application."""
346
-
347
- examples = get_default_engine().example_map()
348
- first_example = next(iter(examples.values()))
349
-
350
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="amber"), css=CSS, title="TorchReview Copilot") as demo:
351
- gr.HTML(
352
- """
353
- <div class="hero-card">
354
- <div class="eyebrow">Meta PyTorch OpenEnv Hackathon Demo</div>
355
- <h1 class="hero-title">TorchReview Copilot</h1>
356
- <p class="hero-copy">
357
- AI-powered code review and improvement system using PyTorch to score code quality, surface bugs,
358
- and generate a three-step improvement plan. OpenEnv stays underneath as the deterministic validation engine.
359
- </p>
360
- </div>
361
- """
362
- )
363
-
364
- with gr.Row():
365
- with gr.Column(scale=6):
366
- example_choice = gr.Radio(
367
- choices=[(item.title, item.key) for item in examples.values()],
368
- value=first_example.key,
369
- label="Try a built-in failure scenario",
370
- info="Switching examples updates the Live Triage Radar immediately.",
371
- )
372
- example_header = gr.Markdown()
373
- code_input = gr.Code(
374
- value=first_example.code,
375
- language="python",
376
- lines=18,
377
- label="Python code under review",
378
- )
379
- traceback_input = gr.Textbox(
380
- value=first_example.traceback_text,
381
- lines=7,
382
- label="Optional traceback / failing test output",
383
- placeholder="Paste stack traces, assertion failures, or benchmark notes here.",
384
- )
385
- context_input = gr.Textbox(
386
- value=first_example.context_window,
387
- lines=4,
388
- label="Context window",
389
- placeholder="Describe expected behavior, constraints, or repository context.",
390
- )
391
- with gr.Row():
392
- analyze_button = gr.Button("Analyze & Score Code", variant="primary")
393
- clear_button = gr.Button("Clear Inputs", variant="secondary")
394
-
395
- with gr.Column(scale=5):
396
- summary_html = gr.HTML()
397
- radar_html = gr.HTML()
398
- plan_markdown = gr.Markdown()
399
- match_markdown = gr.Markdown()
400
- model_markdown = gr.Markdown()
401
-
402
- gr.HTML(
403
- """
404
- <div class="subtle-card" style="margin-top: 12px;">
405
- <div class="eyebrow">How It Works</div>
406
- <div class="how-grid">
407
- <div class="how-step"><strong>Input</strong><br>Code plus optional traceback or benchmark signal.</div>
408
- <div class="how-step"><strong>Processing</strong><br>Static checks extract parser, lint, complexity, and runtime clues.</div>
409
- <div class="how-step"><strong>Model</strong><br>CodeBERTa embeddings run through PyTorch and score code quality against known OpenEnv patterns.</div>
410
- <div class="how-step"><strong>Output</strong><br>Confidence radar, reward score, and a three-step improvement plan.</div>
411
- </div>
412
- </div>
413
- """
414
- )
415
-
416
- example_choice.change(
417
- fn=load_example,
418
- inputs=example_choice,
419
- outputs=[code_input, traceback_input, context_input, example_header, summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
420
- show_progress="hidden",
421
- )
422
- analyze_button.click(
423
- fn=analyze_inputs,
424
- inputs=[code_input, traceback_input, context_input],
425
- outputs=[summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
426
- show_progress="minimal",
427
- )
428
- clear_button.click(
429
- fn=lambda: ("", "", "", "### Example Scenario\nChoose a built-in example or paste custom code.", *_default_outputs()),
430
- inputs=None,
431
- outputs=[code_input, traceback_input, context_input, example_header, summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
432
- show_progress="hidden",
433
- )
434
- demo.load(
435
- fn=load_example,
436
- inputs=example_choice,
437
- outputs=[code_input, traceback_input, context_input, example_header, summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
438
- show_progress="hidden",
439
- )
440
-
441
- return demo
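The `_radar_html` helper in this file builds each confidence bar by escaping the label and formatting the score twice: once as a percentage label and once as a CSS width. The following is a minimal standalone sketch of that pattern, assuming scores are floats intended to lie in [0, 1]. The `render_bar` name and the clamping step are illustrative additions and are not part of the commit.

```python
from html import escape


def render_bar(label: str, score: float, color: str = "#d95d39") -> str:
    """Render one confidence bar as an HTML fragment (hypothetical helper)."""
    # Assumption: clamp the score so an out-of-range value cannot produce a
    # negative or >100% CSS width. The helper in the diff does not clamp.
    clamped = min(max(score, 0.0), 1.0)
    return (
        '<div class="bar">'
        # escape() neutralises <, >, & and quotes in the untrusted label.
        f'<div class="bar-head"><span>{escape(label.title())}</span>'
        f"<span>{clamped:.0%}</span></div>"
        '<div class="bar-track">'
        f'<div class="bar-fill" style="width:{clamped * 100:.1f}%; background:{color};"></div>'
        "</div></div>"
    )


html = render_bar("<syntax>", 0.873)
print(html)
```

Escaping at the point of interpolation keeps the fragment safe even when labels or pattern titles come from user-supplied code or task metadata.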
 
1
+ """Gradio UI for TorchReview Copilot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from html import escape
6
+
7
+ import gradio as gr
8
+
9
+ try:
10
+ from ..triage import get_default_engine
11
+ except ImportError:
12
+ from triage import get_default_engine
13
+
14
+
15
+ CSS = """
16
+ :root {
17
+ --paper: #f6f1e8;
18
+ --ink: #162521;
19
+ --accent: #d95d39;
20
+ --panel: #fffdf8;
21
+ --border: #d6c4b8;
22
+ --muted: #5f6f67;
23
+ --good: #2d7d62;
24
+ --warn: #b76516;
25
+ --high: #b23a48;
26
+ }
27
+
28
+ body, .gradio-container {
29
+ background:
30
+ radial-gradient(circle at top left, rgba(247, 197, 159, 0.35), transparent 35%),
31
+ linear-gradient(135deg, #f9f6ef 0%, #efe5d3 100%);
32
+ color: var(--ink);
33
+ font-family: Georgia, "Times New Roman", serif;
34
+ }
35
+
36
+ .gradio-container {
37
+ max-width: 1260px !important;
38
+ }
39
+
40
+ .hero-card,
41
+ .metric-card,
42
+ .subtle-card {
43
+ background: rgba(255, 253, 248, 0.95);
44
+ border: 1px solid var(--border);
45
+ border-radius: 20px;
46
+ box-shadow: 0 16px 40px rgba(22, 37, 33, 0.08);
47
+ }
48
+
49
+ .hero-card {
50
+ padding: 28px 30px;
51
+ margin-bottom: 12px;
52
+ }
53
+
54
+ .metric-card,
55
+ .subtle-card {
56
+ padding: 20px 22px;
57
+ }
58
+
59
+ .eyebrow {
60
+ text-transform: uppercase;
61
+ letter-spacing: 0.12em;
62
+ font-size: 12px;
63
+ color: var(--accent);
64
+ margin-bottom: 10px;
65
+ }
66
+
67
+ .hero-title {
68
+ font-size: 44px;
69
+ line-height: 1.05;
70
+ margin: 0 0 10px;
71
+ }
72
+
73
+ .hero-copy {
74
+ margin: 0;
75
+ font-size: 18px;
76
+ line-height: 1.55;
77
+ color: var(--muted);
78
+ }
79
+
80
+ .summary-title {
81
+ display: flex;
82
+ justify-content: space-between;
83
+ gap: 12px;
84
+ align-items: center;
85
+ margin-bottom: 14px;
86
+ }
87
+
88
+ .pill {
89
+ display: inline-block;
90
+ padding: 6px 12px;
91
+ border-radius: 999px;
92
+ font-size: 12px;
93
+ text-transform: uppercase;
94
+ letter-spacing: 0.08em;
95
+ background: #efe5d3;
96
+ }
97
+
98
+ .pill.low { color: var(--good); }
99
+ .pill.medium { color: var(--warn); }
100
+ .pill.high { color: var(--high); }
101
+
102
+ .summary-grid {
103
+ display: grid;
104
+ grid-template-columns: repeat(2, minmax(0, 1fr));
105
+ gap: 12px;
106
+ margin-top: 16px;
107
+ }
108
+
109
+ .summary-stat {
110
+ background: #fff7ef;
111
+ border-radius: 14px;
112
+ padding: 12px 14px;
113
+ border: 1px solid rgba(214, 196, 184, 0.8);
114
+ }
115
+
116
+ .summary-stat strong {
117
+ display: block;
118
+ font-size: 12px;
119
+ text-transform: uppercase;
120
+ letter-spacing: 0.08em;
121
+ color: var(--muted);
122
+ margin-bottom: 6px;
123
+ }
124
+
125
+ .radar-wrap {
126
+ display: grid;
127
+ gap: 12px;
128
+ }
129
+
130
+ .bar {
131
+ display: grid;
132
+ gap: 6px;
133
+ }
134
+
135
+ .bar-head {
136
+ display: flex;
137
+ justify-content: space-between;
138
+ font-size: 13px;
139
+ color: var(--muted);
140
+ }
141
+
142
+ .bar-track {
143
+ width: 100%;
144
+ height: 12px;
145
+ background: #f2e5d6;
146
+ border-radius: 999px;
147
+ overflow: hidden;
148
+ }
149
+
150
+ .bar-fill {
151
+ height: 100%;
152
+ border-radius: 999px;
153
+ }
154
+
155
+ .matched-box {
156
+ background: #fff7ef;
157
+ border: 1px solid rgba(214, 196, 184, 0.8);
158
+ border-radius: 16px;
159
+ padding: 14px;
160
+ }
161
+
162
+ .how-grid {
163
+ display: grid;
164
+ grid-template-columns: repeat(4, minmax(0, 1fr));
165
+ gap: 12px;
166
+ }
167
+
168
+ .how-step {
169
+ background: rgba(255, 253, 248, 0.9);
170
+ border: 1px solid var(--border);
171
+ border-radius: 18px;
172
+ padding: 16px;
173
+ }
174
+
175
+ @media (max-width: 900px) {
176
+ .hero-title {
177
+ font-size: 34px;
178
+ }
179
+
180
+ .summary-grid,
181
+ .how-grid {
182
+ grid-template-columns: 1fr;
183
+ }
184
+ }
185
+ """
186
+
187
+
188
+ def _default_outputs() -> tuple[str, str, str, str, str]:
189
+ return (
190
+ "<div class='metric-card'><div class='eyebrow'>Awaiting Analysis</div><p class='hero-copy'>Paste Python code, add an optional traceback, or load one of the built-in examples.</p></div>",
191
+ "<div class='metric-card'><div class='eyebrow'>Live Triage Radar</div><p class='hero-copy'>Confidence bars will appear after the first analysis run.</p></div>",
192
+ "### Improvement Plan\nAnalyze a sample to generate syntax, edge-case, and scalability recommendations.",
193
+ "### Known Pattern Match\nThe nearest OpenEnv task will be highlighted here after inference runs.",
194
+ "### Model Notes\nBackend and extracted signal details will appear here.",
195
+ )
196
+
197
+
198
+ def _summary_html(result) -> str:
199
+ issue = escape(result.issue_label.title())
200
+ summary = escape(result.summary)
201
+ next_action = escape(result.suggested_next_action)
202
+ return f"""
203
+ <div class="metric-card">
204
+ <div class="summary-title">
205
+ <div>
206
+ <div class="eyebrow">TorchReview Verdict</div>
207
+ <h3 style="margin:0;font-size:30px;">{issue} Issue</h3>
208
+ </div>
209
+ <span class="pill {escape(result.repair_risk)}">{escape(result.repair_risk)} repair risk</span>
210
+ </div>
211
+ <p class="hero-copy">{summary}</p>
212
+ <div class="summary-grid">
213
+ <div class="summary-stat">
214
+ <strong>Reward Score</strong>
215
+ {result.reward_score:.0%}
216
+ </div>
217
+ <div class="summary-stat">
218
+ <strong>ML Quality</strong>
219
+ {result.ml_quality_score:.0%}
220
+ </div>
221
+ <div class="summary-stat">
222
+ <strong>Matched Pattern</strong>
223
+ {escape(result.matched_pattern.title)}
224
+ </div>
225
+ <div class="summary-stat">
226
+ <strong>Inference Backend</strong>
227
+ {escape(result.model_backend)}
228
+ </div>
229
+ <div class="summary-stat">
230
+ <strong>Lint Score</strong>
231
+ {result.lint_score:.0%}
232
+ </div>
233
+ <div class="summary-stat">
234
+ <strong>Complexity Penalty</strong>
235
+ {result.complexity_penalty:.0%}
236
+ </div>
237
+ <div class="summary-stat">
238
+ <strong>Next Action</strong>
239
+ {next_action}
240
+ </div>
241
+ </div>
242
+ </div>
243
+ """
244
+
245
+
246
+ def _radar_html(result) -> str:
247
+ colors = {
248
+ "syntax": "#d95d39",
249
+ "logic": "#4f772d",
250
+ "performance": "#355070",
251
+ }
252
+ bars = []
253
+ for label, score in result.confidence_scores.items():
254
+ bars.append(
255
+ f"""
256
+ <div class="bar">
257
+ <div class="bar-head"><span>{escape(label.title())}</span><span>{score:.0%}</span></div>
258
+ <div class="bar-track">
259
+ <div class="bar-fill" style="width:{score * 100:.1f}%; background:{colors.get(label, '#d95d39')};"></div>
260
+ </div>
261
+ </div>
262
+ """
263
+ )
264
+ return f"""
265
+ <div class="metric-card radar-wrap">
266
+ <div class="eyebrow">Live Triage Radar</div>
267
+ {''.join(bars)}
268
+ <div class="matched-box">
269
+ <strong>Nearest Known Pattern:</strong> {escape(result.matched_pattern.title)}<br>
270
+ <span style="color:#5f6f67;">{escape(result.matched_pattern.summary)}</span>
271
+ </div>
272
+ </div>
273
+ """
274
+
275
+
276
+ def _plan_markdown(result) -> str:
277
+ plan_lines = "\n".join(f"{index + 1}. {step}" for index, step in enumerate(result.repair_plan))
278
+ return (
279
+ "### Improvement Plan\n"
280
+ f"**Primary issue:** `{result.issue_label}`\n\n"
281
+ f"{plan_lines}\n\n"
282
+ f"**Suggested next action:** {result.suggested_next_action}"
283
+ )
284
+
285
+
286
+ def _match_markdown(result) -> str:
287
+ return (
288
+ "### Known Pattern Match\n"
289
+ f"**Task:** `{result.matched_pattern.task_id}` \n"
290
+ f"**Title:** {result.matched_pattern.title} \n"
291
+ f"**Why it matched:** {result.matched_pattern.rationale} \n"
292
+ f"**Similarity:** {result.matched_pattern.similarity:.0%}"
293
+ )
294
+
295
+
296
+ def _model_markdown(result) -> str:
297
+ signal_lines = "\n".join(
298
+ f"- `{signal.name}` -> {signal.value} ({signal.impact}, weight {signal.weight:.2f}): {signal.evidence}"
299
+ for signal in result.extracted_signals
300
+ ) or "- No strong static signals were extracted."
301
+ notes = "\n".join(f"- {item}" for item in result.inference_notes) or "- No additional backend notes."
302
+ return (
303
+ "### Model Notes\n"
304
+ f"- **Model backend:** `{result.model_backend}`\n"
305
+ f"- **Model id:** `{result.model_id}`\n"
306
+ f"- **Analysis time:** `{result.analysis_time_ms:.2f} ms`\n\n"
307
+ "### Reward Formula\n"
308
+ f"- `reward = (0.5 x {result.ml_quality_score:.2f}) + (0.3 x {result.lint_score:.2f}) - (0.2 x {result.complexity_penalty:.2f})`\n"
309
+ f"- **Final reward:** `{result.reward_score:.2f}`\n\n"
310
+ "### Extracted Signals\n"
311
+ f"{signal_lines}\n\n"
312
+ "### Backend Notes\n"
313
+ f"{notes}"
314
+ )
315
+
316
+
317
+ def analyze_inputs(code: str, traceback_text: str, context_window: str) -> tuple[str, str, str, str, str]:
318
+ """Run the triage engine and format outputs for the Gradio UI."""
319
+
320
+ result = get_default_engine().triage(code or "", traceback_text or "", context_window or "")
321
+ return (
322
+ _summary_html(result),
323
+ _radar_html(result),
324
+ _plan_markdown(result),
325
+ _match_markdown(result),
326
+ _model_markdown(result),
327
+ )
328
+
329
+
330
+ def load_example(example_key: str) -> tuple[str, str, str, str, str, str, str, str, str]:
331
+ """Populate the UI from a built-in example and immediately analyze it."""
332
+
333
+ example = get_default_engine().example_map()[example_key]
334
+ outputs = analyze_inputs(example.code, example.traceback_text, example.context_window)
335
+ header = (
336
+ f"### Example Scenario\n"
337
+ f"**{example.title}** \n"
338
+ f"{example.summary} \n"
339
+ f"Label target: `{example.label}`"
340
+ )
341
+ return (example.code, example.traceback_text, example.context_window, header, *outputs)
342
+
343
+
344
+ def build_demo() -> gr.Blocks:
345
+ """Create the TorchReview Copilot Gradio application."""
346
+
347
+ examples = get_default_engine().example_map()
348
+ first_example = next(iter(examples.values()))
349
+
350
+ with gr.Blocks(title="TorchReview Copilot") as demo:
351
+ gr.HTML(
352
+ """
353
+ <div class="hero-card">
354
+ <div class="eyebrow">Meta PyTorch OpenEnv Hackathon Demo</div>
355
+ <h1 class="hero-title">TorchReview Copilot</h1>
356
+ <p class="hero-copy">
357
+ AI-powered code review and improvement system using PyTorch to score code quality, surface bugs,
358
+ and generate a three-step improvement plan. OpenEnv stays underneath as the deterministic validation engine.
359
+ </p>
360
+ </div>
361
+ """
362
+ )
363
+
364
+ with gr.Row():
365
+ with gr.Column(scale=6):
366
+ example_choice = gr.Radio(
367
+ choices=[(item.title, item.key) for item in examples.values()],
368
+ value=first_example.key,
369
+ label="Try a built-in failure scenario",
370
+ info="Switching examples updates the Live Triage Radar immediately.",
371
+ )
372
+ example_header = gr.Markdown()
373
+ code_input = gr.Code(
374
+ value=first_example.code,
375
+ language="python",
376
+ lines=18,
377
+ label="Python code under review",
378
+ )
379
+ traceback_input = gr.Textbox(
380
+ value=first_example.traceback_text,
381
+ lines=7,
382
+ label="Optional traceback / failing test output",
383
+ placeholder="Paste stack traces, assertion failures, or benchmark notes here.",
384
+ )
385
+ context_input = gr.Textbox(
386
+ value=first_example.context_window,
387
+ lines=4,
388
+ label="Context window",
389
+ placeholder="Describe expected behavior, constraints, or repository context.",
390
+ )
391
+ with gr.Row():
392
+ analyze_button = gr.Button("Analyze & Score Code", variant="primary")
393
+ clear_button = gr.Button("Clear Inputs", variant="secondary")
394
+
395
+ with gr.Column(scale=5):
396
+ summary_html = gr.HTML()
397
+ radar_html = gr.HTML()
398
+ plan_markdown = gr.Markdown()
399
+ match_markdown = gr.Markdown()
400
+ model_markdown = gr.Markdown()
401
+
402
+ gr.HTML(
403
+ """
404
+ <div class="subtle-card" style="margin-top: 12px;">
405
+ <div class="eyebrow">How It Works</div>
406
+ <div class="how-grid">
407
+ <div class="how-step"><strong>Input</strong><br>Code plus optional traceback or benchmark signal.</div>
408
+ <div class="how-step"><strong>Processing</strong><br>Static checks extract parser, lint, complexity, and runtime clues.</div>
409
+ <div class="how-step"><strong>Model</strong><br>CodeBERTa embeddings run through PyTorch and score code quality against known OpenEnv patterns.</div>
410
+ <div class="how-step"><strong>Output</strong><br>Confidence radar, reward score, and a three-step improvement plan.</div>
411
+ </div>
412
+ </div>
413
+ """
414
+ )
415
+
416
+ example_choice.change(
417
+ fn=load_example,
418
+ inputs=example_choice,
419
+ outputs=[code_input, traceback_input, context_input, example_header, summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
420
+ show_progress="hidden",
421
+ )
422
+ analyze_button.click(
423
+ fn=analyze_inputs,
424
+ inputs=[code_input, traceback_input, context_input],
425
+ outputs=[summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
426
+ show_progress="minimal",
427
+ )
428
+ clear_button.click(
429
+ fn=lambda: ("", "", "", "### Example Scenario\nChoose a built-in example or paste custom code.", *_default_outputs()),
430
+ inputs=None,
431
+ outputs=[code_input, traceback_input, context_input, example_header, summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
432
+ show_progress="hidden",
433
+ )
434
+ demo.load(
435
+ fn=load_example,
436
+ inputs=example_choice,
437
+ outputs=[code_input, traceback_input, context_input, example_header, summary_html, radar_html, plan_markdown, match_markdown, model_markdown],
438
+ show_progress="hidden",
439
+ )
440
+
441
+ return demo
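The `_model_markdown` helper above prints the reward as a weighted combination: 0.5 x ML quality, plus 0.3 x lint score, minus 0.2 x complexity penalty. The sketch below reproduces that arithmetic in isolation. The `combine_reward` name is hypothetical. This diff does not show whether the triage engine clamps or rounds the result, so the clamp to [0, 1] is an assumption.

```python
def combine_reward(ml_quality: float, lint: float, complexity_penalty: float) -> float:
    """Weighted reward using the coefficients shown in the Reward Formula panel."""
    raw = 0.5 * ml_quality + 0.3 * lint - 0.2 * complexity_penalty
    # Assumption: keep the value in [0, 1] so it can be rendered with ':.0%'.
    return min(max(raw, 0.0), 1.0)


reward = combine_reward(ml_quality=0.8, lint=0.9, complexity_penalty=0.5)
print(f"reward = {reward:.2f}")  # 0.40 + 0.27 - 0.10 -> 0.57
```

Because the complexity term is subtracted, a sample with a low ML quality score and a high penalty can drive the raw value below zero, which is why some form of lower bound is likely needed before the score is displayed as a percentage.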
server/env.py CHANGED
@@ -11,24 +11,24 @@ from openenv.core.env_server.types import EnvironmentMetadata
11
  try:
12
  from ..graders import grade_task
13
  from ..graders.shared import component_score, safe_ratio, strict_score
14
- from ..models import (
15
- HistoryEntry,
16
- PythonCodeReviewAction,
17
- PythonCodeReviewObservation,
18
- PythonCodeReviewState,
19
- RewardDetails,
20
  TaskGrade,
21
  )
22
  from ..tasks import ReviewTask, list_tasks, select_task
23
  except ImportError:
24
  from graders import grade_task
25
  from graders.shared import component_score, safe_ratio, strict_score
26
- from models import (
27
- HistoryEntry,
28
- PythonCodeReviewAction,
29
- PythonCodeReviewObservation,
30
- PythonCodeReviewState,
31
- RewardDetails,
32
  TaskGrade,
33
  )
34
  from tasks import ReviewTask, list_tasks, select_task
@@ -56,17 +56,17 @@ class PythonCodeReviewEnvironment(
56
 
57
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
58
 
59
- def __init__(self, verbose: bool = False, **_: Any) -> None:
60
- super().__init__()
61
- self.verbose = verbose
62
- self._task: ReviewTask = list_tasks()[0]
63
- self._current_code: str = self._task.starter_code
64
- self._history: list[HistoryEntry] = []
65
- self._last_reward = RewardDetails(value=0.1, reason="Environment initialized.")
66
- self._last_action_error: str | None = None
67
- self._current_grade = _empty_grade()
68
- self._state = PythonCodeReviewState(episode_id=str(uuid4()), step_count=0)
69
- self.reset()
70
 
71
  def reset(
72
  self,
@@ -74,17 +74,17 @@ class PythonCodeReviewEnvironment(
74
  episode_id: Optional[str] = None,
75
  **kwargs: Any,
76
  ) -> PythonCodeReviewObservation:
77
- task_id = kwargs.get("task_id")
78
- self._task = select_task(seed=seed, task_id=task_id)
79
- self._current_code = self._task.starter_code
80
- self._history = []
81
- self._last_action_error = None
82
- self._last_reward = RewardDetails(value=0.1, reason="Environment reset.")
83
- self._current_grade, self._last_action_error = self._safe_grade_task(
84
- self._task,
85
- self._current_code,
86
- include_hidden=False,
87
- )
88
 
89
  self._state = PythonCodeReviewState(
90
  episode_id=episode_id or str(uuid4()),
@@ -143,22 +143,22 @@ class PythonCodeReviewEnvironment(
143
  )
144
  return observation, reward.value, observation.done, {"task_id": observation.task_id, "score": observation.score}
145
 
146
- previous_grade = self._current_grade
147
- status = ""
148
- invalid_action = False
149
- code_changed = False
150
- use_hidden_grading = False
151
- action_error: str | None = None
152
-
153
- if action.action_type == "edit_code":
154
- if not action.code or not action.code.strip():
155
- invalid_action = True
156
- status = "edit_code requires a non-empty code payload."
157
- action_error = status
158
- else:
159
- code_changed = action.code != self._current_code
160
- self._current_code = action.code
161
- status = "Updated working copy from agent patch."
162
  elif action.action_type == "submit_solution":
163
  if action.code is not None and action.code.strip():
164
  code_changed = action.code != self._current_code
@@ -169,30 +169,30 @@ class PythonCodeReviewEnvironment(
169
  status = "Executed public validation suite."
170
  elif action.action_type == "analyze_code":
171
  status = "Generated static review summary."
172
- else: # pragma: no cover
173
- invalid_action = True
174
- status = f"Unsupported action_type: {action.action_type}"
175
- action_error = status
176
 
177
  self._state.step_count += 1
178
 
179
- if invalid_action:
180
- current_grade = previous_grade
181
- else:
182
- current_grade, grade_error = self._safe_grade_task(
183
- self._task,
184
- self._current_code,
185
- include_hidden=use_hidden_grading,
186
- timeout_s=timeout_s or 3.0,
187
- )
188
- if grade_error:
189
- action_error = grade_error
190
- status = f"{status} Grading fallback used."
191
- if action.action_type == "analyze_code":
192
- status = self._analysis_status(current_grade)
193
- elif action.action_type == "run_tests":
194
- status = self._run_tests_status(current_grade, use_hidden_grading)
195
- elif action.action_type == "submit_solution":
196
  status = self._submission_status(current_grade)
197
 
198
  done = use_hidden_grading or self._state.step_count >= self._task.max_steps
@@ -217,11 +217,11 @@ class PythonCodeReviewEnvironment(
217
  reward=reward_details.value,
218
  )
219
  )
220
-
221
- self._current_grade = current_grade
222
- self._last_reward = reward_details
223
- self._last_action_error = action_error
224
- attempts_remaining = max(self._task.max_steps - self._state.step_count, 0)
225
 
226
  self._state.task_id = self._task.task_id
227
  self._state.difficulty = self._task.difficulty
@@ -234,19 +234,19 @@ class PythonCodeReviewEnvironment(
234
  self._state.score = current_grade.score
235
  self._state.done = done
236
 
237
- observation = self._build_observation(
238
- grade=current_grade,
239
- status=status,
240
- reward_details=reward_details,
241
- )
242
- return observation, reward_details.value, observation.done, {
243
- "task_id": observation.task_id,
244
- "score": observation.score,
245
- "done": observation.done,
246
- "attempts_remaining": observation.attempts_remaining,
247
- "last_action_status": observation.last_action_status,
248
- "last_action_error": observation.last_action_error,
249
- }
250
 
251
  @property
252
  def state(self) -> PythonCodeReviewState:
@@ -268,22 +268,22 @@ class PythonCodeReviewEnvironment(
268
  current_code=self._current_code,
269
  errors=self._format_errors(grade),
270
  test_results=self._format_test_results(grade),
271
- visible_tests=list(self._task.visible_tests),
272
- history=list(self._history),
273
- attempts_remaining=self._state.attempts_remaining,
274
- last_action_status=status,
275
- last_action_error=self._last_action_error,
276
- score=grade.score,
277
- reward=reward_details.value,
278
- done=self._state.done,
279
- reward_details=reward_details,
280
- metadata={
281
- "benchmark": "python_code_review_env",
282
- "goal": self._task.goal,
283
- "repo_summary": self._task.repo_summary,
284
- "changed_files": self._task.changed_files,
285
- "available_files": self._task.available_files,
286
- "grade_details": grade.details,
287
  },
288
  )
289
 
@@ -298,43 +298,43 @@ class PythonCodeReviewEnvironment(
298
  code_changed: bool,
299
  final_submission: bool,
300
  ) -> RewardDetails:
301
- prev_score = previous_grade.score
302
- curr_score = current_grade.score
303
- prev_rate = safe_ratio(previous_grade.tests_passed, previous_grade.tests_total)
304
- curr_rate = safe_ratio(current_grade.tests_passed, current_grade.tests_total)
305
- prev_runtime = previous_grade.runtime_score
306
- curr_runtime = current_grade.runtime_score
307
- prev_compile_error = bool(str(previous_grade.details.get("compile_error", "")).strip())
308
- curr_compile_error = bool(str(current_grade.details.get("compile_error", "")).strip())
309
-
310
- syntax_reward = 0.14 if previous_grade.syntax_score < 0.9 and current_grade.syntax_score >= 0.9 else 0.0
311
- test_reward = round(max(curr_rate - prev_rate, 0.0) * 0.28, 3)
312
- progress_delta = round(max(curr_score - prev_score, 0.0) * 0.3, 3)
313
- quality_bonus = round(max(current_grade.quality_score - previous_grade.quality_score, 0.0) * 0.12, 3)
314
- runtime_bonus = round(max(curr_runtime - prev_runtime, 0.0) * 0.08, 3)
315
- error_reduction_bonus = 0.1 if prev_compile_error and not curr_compile_error else 0.0
316
- completion_bonus = 0.14 if final_submission and curr_rate >= 0.999 and curr_score >= 0.94 else 0.0
317
- correctness_bonus = 0.12 if final_submission and curr_score >= 0.94 and prev_score < 0.94 else 0.0
318
-
319
- invalid_action_penalty = round((0.04 + (0.08 * (1.0 - prev_score))) if invalid_action else 0.0, 3)
320
- timeout_penalty = round((0.06 + (0.08 * max(curr_runtime, prev_runtime))) if timed_out else 0.0, 3)
321
- regression_penalty = round(max(prev_score - curr_score, 0.0) * 0.25, 3)
322
- stagnation_penalty = round((0.02 + (0.05 * prev_score)) if action.action_type == "edit_code" and not code_changed else 0.0, 3)
323
-
324
- raw_value = (
325
- 0.32 * curr_score
326
- + syntax_reward
327
- + test_reward
328
- + progress_delta
329
- + quality_bonus
330
- + error_reduction_bonus
331
- + completion_bonus
332
- + runtime_bonus
333
- + correctness_bonus
334
- - invalid_action_penalty
335
- - timeout_penalty
336
- - regression_penalty
337
- - stagnation_penalty
338
  )
339
  value = _reward_value(raw_value)
340
 
@@ -345,16 +345,16 @@ class PythonCodeReviewEnvironment(
345
  reason_parts.append("public test progress")
346
  if progress_delta:
347
  reason_parts.append("overall score improved")
348
- if quality_bonus:
349
- reason_parts.append("code quality improved")
350
- if error_reduction_bonus:
351
- reason_parts.append("errors removed")
352
- if completion_bonus:
353
- reason_parts.append("task completed")
354
- if runtime_bonus:
355
- reason_parts.append("runtime improved")
356
- if correctness_bonus:
357
- reason_parts.append("full correctness bonus")
358
  if invalid_action_penalty:
359
  reason_parts.append("invalid action penalty")
360
  if timeout_penalty:
@@ -368,48 +368,48 @@ class PythonCodeReviewEnvironment(
368
 
369
  return RewardDetails(
370
  value=value,
371
- syntax_reward=syntax_reward,
372
- test_reward=test_reward,
373
- correctness_bonus=correctness_bonus,
374
- quality_bonus=quality_bonus,
375
- error_reduction_bonus=error_reduction_bonus,
376
- completion_bonus=completion_bonus,
377
- runtime_bonus=runtime_bonus,
378
- progress_delta=progress_delta,
379
- invalid_action_penalty=invalid_action_penalty,
380
- timeout_penalty=timeout_penalty,
381
- regression_penalty=regression_penalty,
382
- stagnation_penalty=stagnation_penalty,
383
  reason=", ".join(reason_parts),
384
  prev_score=prev_score,
385
  curr_score=curr_score,
386
  code_changed=code_changed,
387
  )
388
 
389
- def _format_errors(self, grade: TaskGrade) -> str:
390
- compile_error = str(grade.details.get("compile_error", "")).strip()
391
- if compile_error:
392
- return compile_error
393
- return "Code parses successfully."
394
-
395
- def _safe_grade_task(
396
- self,
397
- task: ReviewTask,
398
- code: str,
399
- *,
400
- include_hidden: bool,
401
- timeout_s: float = 3.0,
402
- ) -> tuple[TaskGrade, str | None]:
403
- try:
404
- return (
405
- grade_task(task, code, include_hidden=include_hidden, timeout_s=timeout_s),
406
- None,
407
- )
408
- except Exception as exc: # pragma: no cover
409
- return _empty_grade(), f"{type(exc).__name__}: {exc}"
410
-
411
- def _format_test_results(self, grade: TaskGrade) -> str:
412
- parts = [grade.details.get("test_summary", "No test feedback available.")]
413
  benchmark = grade.details.get("benchmark")
414
  if isinstance(benchmark, dict):
415
  parts.append(
 
11
  try:
12
  from ..graders import grade_task
13
  from ..graders.shared import component_score, safe_ratio, strict_score
14
+ from ..models import (
15
+ HistoryEntry,
16
+ PythonCodeReviewAction,
17
+ PythonCodeReviewObservation,
18
+ PythonCodeReviewState,
19
+ RewardDetails,
20
  TaskGrade,
21
  )
22
  from ..tasks import ReviewTask, list_tasks, select_task
23
  except ImportError:
24
  from graders import grade_task
25
  from graders.shared import component_score, safe_ratio, strict_score
26
+ from models import (
27
+ HistoryEntry,
28
+ PythonCodeReviewAction,
29
+ PythonCodeReviewObservation,
30
+ PythonCodeReviewState,
31
+ RewardDetails,
32
  TaskGrade,
33
  )
34
  from tasks import ReviewTask, list_tasks, select_task
 
56
 
57
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
58
 
59
+ def __init__(self, verbose: bool = False, **_: Any) -> None:
60
+ super().__init__()
61
+ self.verbose = verbose
62
+ self._task: ReviewTask = list_tasks()[0]
63
+ self._current_code: str = self._task.starter_code
64
+ self._history: list[HistoryEntry] = []
65
+ self._last_reward = RewardDetails(value=0.1, reason="Environment initialized.")
66
+ self._last_action_error: str | None = None
67
+ self._current_grade = _empty_grade()
68
+ self._state = PythonCodeReviewState(episode_id=str(uuid4()), step_count=0)
69
+ self.reset()
70
 
71
  def reset(
72
  self,
 
74
  episode_id: Optional[str] = None,
75
  **kwargs: Any,
76
  ) -> PythonCodeReviewObservation:
77
+ task_id = kwargs.get("task_id")
78
+ self._task = select_task(seed=seed, task_id=task_id)
79
+ self._current_code = self._task.starter_code
80
+ self._history = []
81
+ self._last_action_error = None
82
+ self._last_reward = RewardDetails(value=0.1, reason="Environment reset.")
83
+ self._current_grade, self._last_action_error = self._safe_grade_task(
84
+ self._task,
85
+ self._current_code,
86
+ include_hidden=False,
87
+ )
88
 
89
  self._state = PythonCodeReviewState(
90
  episode_id=episode_id or str(uuid4()),
 
143
  )
144
  return observation, reward.value, observation.done, {"task_id": observation.task_id, "score": observation.score}
145
 
146
+ previous_grade = self._current_grade
147
+ status = ""
148
+ invalid_action = False
149
+ code_changed = False
150
+ use_hidden_grading = False
151
+ action_error: str | None = None
152
+
153
+ if action.action_type == "edit_code":
154
+ if not action.code or not action.code.strip():
155
+ invalid_action = True
156
+ status = "edit_code requires a non-empty code payload."
157
+ action_error = status
158
+ else:
159
+ code_changed = action.code != self._current_code
160
+ self._current_code = action.code
161
+ status = "Updated working copy from agent patch."
162
  elif action.action_type == "submit_solution":
163
  if action.code is not None and action.code.strip():
164
  code_changed = action.code != self._current_code
 
169
  status = "Executed public validation suite."
170
  elif action.action_type == "analyze_code":
171
  status = "Generated static review summary."
172
+ else: # pragma: no cover
173
+ invalid_action = True
174
+ status = f"Unsupported action_type: {action.action_type}"
175
+ action_error = status
176
 
177
  self._state.step_count += 1
178
 
179
+ if invalid_action:
180
+ current_grade = previous_grade
181
+ else:
182
+ current_grade, grade_error = self._safe_grade_task(
183
+ self._task,
184
+ self._current_code,
185
+ include_hidden=use_hidden_grading,
186
+ timeout_s=timeout_s or 3.0,
187
+ )
188
+ if grade_error:
189
+ action_error = grade_error
190
+ status = f"{status} Grading fallback used."
191
+ if action.action_type == "analyze_code":
192
+ status = self._analysis_status(current_grade)
193
+ elif action.action_type == "run_tests":
194
+ status = self._run_tests_status(current_grade, use_hidden_grading)
195
+ elif action.action_type == "submit_solution":
196
  status = self._submission_status(current_grade)
197
 
198
  done = use_hidden_grading or self._state.step_count >= self._task.max_steps
 
217
  reward=reward_details.value,
218
  )
219
  )
220
+
221
+ self._current_grade = current_grade
222
+ self._last_reward = reward_details
223
+ self._last_action_error = action_error
224
+ attempts_remaining = max(self._task.max_steps - self._state.step_count, 0)
225
 
226
  self._state.task_id = self._task.task_id
227
  self._state.difficulty = self._task.difficulty
 
234
  self._state.score = current_grade.score
235
  self._state.done = done
236
 
237
+ observation = self._build_observation(
238
+ grade=current_grade,
239
+ status=status,
240
+ reward_details=reward_details,
241
+ )
242
+ return observation, reward_details.value, observation.done, {
243
+ "task_id": observation.task_id,
244
+ "score": observation.score,
245
+ "done": observation.done,
246
+ "attempts_remaining": observation.attempts_remaining,
247
+ "last_action_status": observation.last_action_status,
248
+ "last_action_error": observation.last_action_error,
249
+ }
250
 
251
  @property
252
  def state(self) -> PythonCodeReviewState:
 
268
  current_code=self._current_code,
269
  errors=self._format_errors(grade),
270
  test_results=self._format_test_results(grade),
271
+ visible_tests=list(self._task.visible_tests),
272
+ history=list(self._history),
273
+ attempts_remaining=self._state.attempts_remaining,
274
+ last_action_status=status,
275
+ last_action_error=self._last_action_error,
276
+ score=grade.score,
277
+ reward=reward_details.value,
278
+ done=self._state.done,
279
+ reward_details=reward_details,
280
+ metadata={
281
+ "benchmark": "python_code_review_env",
282
+ "goal": self._task.goal,
283
+ "repo_summary": self._task.repo_summary,
284
+ "changed_files": self._task.changed_files,
285
+ "available_files": self._task.available_files,
286
+ "grade_details": grade.details,
287
  },
288
  )
289
 
 
298
  code_changed: bool,
299
  final_submission: bool,
300
  ) -> RewardDetails:
301
+ prev_score = previous_grade.score
302
+ curr_score = current_grade.score
303
+ prev_rate = safe_ratio(previous_grade.tests_passed, previous_grade.tests_total)
304
+ curr_rate = safe_ratio(current_grade.tests_passed, current_grade.tests_total)
305
+ prev_runtime = previous_grade.runtime_score
306
+ curr_runtime = current_grade.runtime_score
307
+ prev_compile_error = bool(str(previous_grade.details.get("compile_error", "")).strip())
308
+ curr_compile_error = bool(str(current_grade.details.get("compile_error", "")).strip())
309
+
310
+ syntax_reward = 0.14 if previous_grade.syntax_score < 0.9 and current_grade.syntax_score >= 0.9 else 0.0
311
+ test_reward = round(max(curr_rate - prev_rate, 0.0) * 0.28, 3)
312
+ progress_delta = round(max(curr_score - prev_score, 0.0) * 0.3, 3)
313
+ quality_bonus = round(max(current_grade.quality_score - previous_grade.quality_score, 0.0) * 0.12, 3)
314
+ runtime_bonus = round(max(curr_runtime - prev_runtime, 0.0) * 0.08, 3)
315
+ error_reduction_bonus = 0.1 if prev_compile_error and not curr_compile_error else 0.0
316
+ completion_bonus = 0.14 if final_submission and curr_rate >= 0.999 and curr_score >= 0.94 else 0.0
317
+ correctness_bonus = 0.12 if final_submission and curr_score >= 0.94 and prev_score < 0.94 else 0.0
318
+
319
+ invalid_action_penalty = round((0.04 + (0.08 * (1.0 - prev_score))) if invalid_action else 0.0, 3)
320
+ timeout_penalty = round((0.06 + (0.08 * max(curr_runtime, prev_runtime))) if timed_out else 0.0, 3)
321
+ regression_penalty = round(max(prev_score - curr_score, 0.0) * 0.25, 3)
322
+ stagnation_penalty = round((0.02 + (0.05 * prev_score)) if action.action_type == "edit_code" and not code_changed else 0.0, 3)
323
+
324
+ raw_value = (
325
+ 0.32 * curr_score
326
+ + syntax_reward
327
+ + test_reward
328
+ + progress_delta
329
+ + quality_bonus
330
+ + error_reduction_bonus
331
+ + completion_bonus
332
+ + runtime_bonus
333
+ + correctness_bonus
334
+ - invalid_action_penalty
335
+ - timeout_penalty
336
+ - regression_penalty
337
+ - stagnation_penalty
338
  )
339
  value = _reward_value(raw_value)
340
 
 
345
  reason_parts.append("public test progress")
346
  if progress_delta:
347
  reason_parts.append("overall score improved")
348
+ if quality_bonus:
349
+ reason_parts.append("code quality improved")
350
+ if error_reduction_bonus:
351
+ reason_parts.append("errors removed")
352
+ if completion_bonus:
353
+ reason_parts.append("task completed")
354
+ if runtime_bonus:
355
+ reason_parts.append("runtime improved")
356
+ if correctness_bonus:
357
+ reason_parts.append("full correctness bonus")
358
  if invalid_action_penalty:
359
  reason_parts.append("invalid action penalty")
360
  if timeout_penalty:
 
368
 
369
  return RewardDetails(
370
  value=value,
371
+ syntax_reward=syntax_reward,
372
+ test_reward=test_reward,
373
+ correctness_bonus=correctness_bonus,
374
+ quality_bonus=quality_bonus,
375
+ error_reduction_bonus=error_reduction_bonus,
376
+ completion_bonus=completion_bonus,
377
+ runtime_bonus=runtime_bonus,
378
+ progress_delta=progress_delta,
379
+ invalid_action_penalty=invalid_action_penalty,
380
+ timeout_penalty=timeout_penalty,
381
+ regression_penalty=regression_penalty,
382
+ stagnation_penalty=stagnation_penalty,
383
  reason=", ".join(reason_parts),
384
  prev_score=prev_score,
385
  curr_score=curr_score,
386
  code_changed=code_changed,
387
  )
388
 
389
+ def _format_errors(self, grade: TaskGrade) -> str:
390
+ compile_error = str(grade.details.get("compile_error", "")).strip()
391
+ if compile_error:
392
+ return compile_error
393
+ return "Code parses successfully."
394
+
395
+ def _safe_grade_task(
396
+ self,
397
+ task: ReviewTask,
398
+ code: str,
399
+ *,
400
+ include_hidden: bool,
401
+ timeout_s: float = 3.0,
402
+ ) -> tuple[TaskGrade, str | None]:
403
+ try:
404
+ return (
405
+ grade_task(task, code, include_hidden=include_hidden, timeout_s=timeout_s),
406
+ None,
407
+ )
408
+ except Exception as exc: # pragma: no cover
409
+ return _empty_grade(), f"{type(exc).__name__}: {exc}"
410
+
411
+ def _format_test_results(self, grade: TaskGrade) -> str:
412
+ parts = [grade.details.get("test_summary", "No test feedback available.")]
413
  benchmark = grade.details.get("benchmark")
414
  if isinstance(benchmark, dict):
415
  parts.append(
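The reward computation in the hunks above combines a base term with progress bonuses and penalties. The sketch below is a minimal illustration of that shaping using only a subset of the terms (base score, test progress, score progress, regression, invalid action). The function name `sketch_step_reward` and the final clamp to `[0, 1]` are assumptions: the diff calls `_reward_value(raw_value)`, whose body is not shown.

```python
def clamp01(value: float) -> float:
    """Clamp a value into [0, 1]. Assumed stand-in for the unseen _reward_value helper."""
    return max(0.0, min(1.0, value))


def sketch_step_reward(
    prev_score: float,
    curr_score: float,
    prev_rate: float,
    curr_rate: float,
    *,
    invalid_action: bool = False,
) -> float:
    """Hypothetical subset of the step reward, using the weights visible in the diff."""
    base = 0.32 * curr_score
    # Bonuses only reward improvement, so each delta is floored at zero.
    test_reward = round(max(curr_rate - prev_rate, 0.0) * 0.28, 3)
    progress_delta = round(max(curr_score - prev_score, 0.0) * 0.3, 3)
    # Penalties: a score drop, and an invalid action that costs more at low scores.
    regression_penalty = round(max(prev_score - curr_score, 0.0) * 0.25, 3)
    invalid_action_penalty = (
        round(0.04 + 0.08 * (1.0 - prev_score), 3) if invalid_action else 0.0
    )
    raw = base + test_reward + progress_delta - regression_penalty - invalid_action_penalty
    return round(clamp01(raw), 3)


improved = sketch_step_reward(0.4, 0.7, 0.5, 1.0)
regressed = sketch_step_reward(0.7, 0.4, 1.0, 0.5)
unchanged = sketch_step_reward(0.4, 0.4, 0.5, 0.5)
invalid = sketch_step_reward(0.4, 0.4, 0.5, 0.5, invalid_action=True)
print(improved, regressed, unchanged, invalid)
```

Under these assumptions an improving step scores higher than a regressing one, and an invalid action scores lower than a valid no-op at the same grade.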
server/requirements.txt CHANGED
@@ -1,8 +1,8 @@
- openenv-core[core]>=0.2.2
- fastapi>=0.111.0
- gradio>=5.26.0
- uvicorn>=0.30.0
- openai>=1.76.0
- streamlit>=1.44.0
- torch>=2.2.0
- transformers>=4.45.0
+ openenv-core[core]>=0.2.2
+ fastapi>=0.111.0
+ gradio>=5.26.0
+ uvicorn>=0.30.0
+ openai>=1.76.0
+ streamlit>=1.44.0
+ torch>=2.2.0
+ transformers>=4.45.0
services/__init__.py CHANGED
@@ -1,7 +1,7 @@
- """Service layer for orchestrating analysis, suggestions, and rewards."""
-
- from .analysis_service import AnalysisService
- from .reward_service import RewardService
- from .suggestion_service import SuggestionService
-
- __all__ = ["AnalysisService", "RewardService", "SuggestionService"]
+ """Service layer for orchestrating analysis, suggestions, and rewards."""
+
+ from .analysis_service import AnalysisService
+ from .reward_service import RewardService
+ from .suggestion_service import SuggestionService
+
+ __all__ = ["AnalysisService", "RewardService", "SuggestionService"]
services/analysis_service.py CHANGED
@@ -1,139 +1,139 @@
- """Orchestration layer for multi-domain code analysis."""
-
- from __future__ import annotations
-
- import time
- from typing import Any, Callable, Dict
-
- from analyzers import analyze_data_science_code, analyze_dsa_code, analyze_ml_code, analyze_web_code
- from models import PyTorchCodeAnalyzerModel
- from schemas.request import AnalyzeCodeRequest
- from schemas.response import AnalyzeCodeResponse, DomainAnalysis, StaticAnalysisSummary
- from services.reward_service import RewardService
- from services.suggestion_service import SuggestionService
- from utils import estimate_complexity, parse_code_structure
-
-
- def _lint_score(parsed: Dict[str, Any]) -> float:
-     """Convert structural smells into a normalized lint-style score."""
-
-     score = 1.0
-     if not parsed.get("syntax_valid", True):
-         score -= 0.45
-     score -= min(parsed.get("long_lines", 0), 5) * 0.03
-     if parsed.get("tabs_used"):
-         score -= 0.1
-     if parsed.get("trailing_whitespace_lines"):
-         score -= 0.05
-     if parsed.get("docstring_ratio", 0.0) == 0.0 and parsed.get("function_names"):
-         score -= 0.08
-     return round(max(0.0, min(1.0, score)), 4)
-
-
- class AnalysisService:
-     """End-to-end analysis pipeline shared by API and UI."""
-
-     def __init__(self) -> None:
-         self._model: PyTorchCodeAnalyzerModel | None = None
-         self.reward_service = RewardService()
-         self.suggestion_service = SuggestionService()
-         self._analyzers: Dict[str, Callable[[str, Dict[str, Any], Dict[str, Any]], DomainAnalysis]] = {
-             "dsa": analyze_dsa_code,
-             "data_science": analyze_data_science_code,
-             "ml_dl": analyze_ml_code,
-             "web": analyze_web_code,
-         }
-
-     @property
-     def model(self) -> PyTorchCodeAnalyzerModel:
-         if self._model is None:
-             self._model = PyTorchCodeAnalyzerModel()
-         return self._model
-
-     def _heuristic_domain_scores(self, parsed: Dict[str, Any], code: str) -> Dict[str, float]:
-         """Derive domain priors from imports and syntax-level hints."""
-
-         scores = {
-             "dsa": 0.2 + (0.15 if parsed.get("uses_recursion") else 0.0) + (0.15 if parsed.get("max_loop_depth", 0) >= 1 else 0.0),
-             "data_science": 0.2 + (0.35 if parsed.get("uses_pandas") or parsed.get("uses_numpy") else 0.0),
-             "ml_dl": 0.2 + (0.35 if parsed.get("uses_torch") or parsed.get("uses_sklearn") else 0.0),
-             "web": 0.2 + (0.35 if parsed.get("uses_fastapi") or parsed.get("uses_flask") else 0.0) + (0.1 if parsed.get("route_decorators") else 0.0),
-             "general": 0.2,
-         }
-         if "fastapi" in code.lower():
-             scores["web"] += 0.1
-         if "pandas" in code.lower() or "numpy" in code.lower():
-             scores["data_science"] += 0.1
-         if "torch" in code.lower():
-             scores["ml_dl"] += 0.1
-         if "while" in code or "for" in code:
-             scores["dsa"] += 0.05
-         return {key: round(min(value, 0.99), 4) for key, value in scores.items()}
-
-     def analyze(self, request: AnalyzeCodeRequest) -> AnalyzeCodeResponse:
-         """Run the complete multi-domain analysis pipeline."""
-
-         started = time.perf_counter()
-         parsed = parse_code_structure(request.code)
-         complexity = estimate_complexity(parsed, request.code)
-         model_prediction = self.model.predict(request.code, request.context_window, parsed)
-         heuristic_scores = self._heuristic_domain_scores(parsed, request.code)
-
-         combined_scores = {}
-         for domain, heuristic_score in heuristic_scores.items():
-             model_score = float(model_prediction["domain_scores"].get(domain, 0.2))
-             combined_scores[domain] = round((0.6 * model_score) + (0.4 * heuristic_score), 4)
-
-         detected_domain = request.domain_hint if request.domain_hint != "auto" else max(combined_scores, key=combined_scores.get)
-         analyzer = self._analyzers.get(detected_domain)
-         domain_analysis = (
-             analyzer(request.code, parsed, complexity)
-             if analyzer is not None
-             else DomainAnalysis(
-                 domain="general",
-                 domain_score=0.6,
-                 issues=[],
-                 suggestions=["Add stronger domain-specific context for deeper analysis."],
-                 highlights={},
-             )
-         )
-
-         lint_score = _lint_score(parsed)
-         score_breakdown = self.reward_service.compute(
-             ml_score=float(model_prediction["ml_quality_score"]),
-             domain_score=domain_analysis.domain_score,
-             lint_score=lint_score,
-             complexity_penalty=float(complexity["complexity_penalty"]),
-         )
-         static_analysis = StaticAnalysisSummary(
-             syntax_valid=bool(parsed["syntax_valid"]),
-             syntax_error=str(parsed["syntax_error"]),
-             cyclomatic_complexity=int(complexity["cyclomatic_complexity"]),
-             line_count=int(parsed["line_count"]),
-             max_loop_depth=int(parsed["max_loop_depth"]),
-             time_complexity=str(complexity["time_complexity"]),
-             space_complexity=str(complexity["space_complexity"]),
-             detected_imports=list(parsed["imports"]),
-             code_smells=list(parsed["code_smells"]),
-         )
-         improvement_plan = self.suggestion_service.build_improvement_plan(
-             domain_analysis=domain_analysis,
-             static_analysis=static_analysis,
-         )
-         summary = (
-             f"Detected `{detected_domain}` code with a model score of {score_breakdown.ml_score:.0%}, "
-             f"domain score {score_breakdown.domain_score:.0%}, and final reward {score_breakdown.reward:.0%}."
-         )
-         return AnalyzeCodeResponse(
-             detected_domain=detected_domain,  # type: ignore[arg-type]
-             domain_confidences=combined_scores,
-             score_breakdown=score_breakdown,
-             static_analysis=static_analysis,
-             domain_analysis=domain_analysis,
-             improvement_plan=improvement_plan,
-             model_backend=str(model_prediction["backend_name"]),
-             model_id=str(model_prediction["model_id"]),
-             summary=summary,
-             context_window=request.context_window,
-             analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2),
-         )
+ """Orchestration layer for multi-domain code analysis."""
+
+ from __future__ import annotations
+
+ import time
+ from typing import Any, Callable, Dict
+
+ from analyzers import analyze_data_science_code, analyze_dsa_code, analyze_ml_code, analyze_web_code
+ from models import PyTorchCodeAnalyzerModel
+ from schemas.request import AnalyzeCodeRequest
+ from schemas.response import AnalyzeCodeResponse, DomainAnalysis, StaticAnalysisSummary
+ from services.reward_service import RewardService
+ from services.suggestion_service import SuggestionService
+ from utils import estimate_complexity, parse_code_structure
+
+
+ def _lint_score(parsed: Dict[str, Any]) -> float:
+     """Convert structural smells into a normalized lint-style score."""
+
+     score = 1.0
+     if not parsed.get("syntax_valid", True):
+         score -= 0.45
+     score -= min(parsed.get("long_lines", 0), 5) * 0.03
+     if parsed.get("tabs_used"):
+         score -= 0.1
+     if parsed.get("trailing_whitespace_lines"):
+         score -= 0.05
+     if parsed.get("docstring_ratio", 0.0) == 0.0 and parsed.get("function_names"):
+         score -= 0.08
+     return round(max(0.0, min(1.0, score)), 4)
+
+
+ class AnalysisService:
+     """End-to-end analysis pipeline shared by API and UI."""
+
+     def __init__(self) -> None:
+         self._model: PyTorchCodeAnalyzerModel | None = None
+         self.reward_service = RewardService()
+         self.suggestion_service = SuggestionService()
+         self._analyzers: Dict[str, Callable[[str, Dict[str, Any], Dict[str, Any]], DomainAnalysis]] = {
+             "dsa": analyze_dsa_code,
+             "data_science": analyze_data_science_code,
+             "ml_dl": analyze_ml_code,
+             "web": analyze_web_code,
+         }
+
+     @property
+     def model(self) -> PyTorchCodeAnalyzerModel:
+         if self._model is None:
+             self._model = PyTorchCodeAnalyzerModel()
+         return self._model
+
+     def _heuristic_domain_scores(self, parsed: Dict[str, Any], code: str) -> Dict[str, float]:
+         """Derive domain priors from imports and syntax-level hints."""
+
+         scores = {
+             "dsa": 0.2 + (0.15 if parsed.get("uses_recursion") else 0.0) + (0.15 if parsed.get("max_loop_depth", 0) >= 1 else 0.0),
+             "data_science": 0.2 + (0.35 if parsed.get("uses_pandas") or parsed.get("uses_numpy") else 0.0),
+             "ml_dl": 0.2 + (0.35 if parsed.get("uses_torch") or parsed.get("uses_sklearn") else 0.0),
+             "web": 0.2 + (0.35 if parsed.get("uses_fastapi") or parsed.get("uses_flask") else 0.0) + (0.1 if parsed.get("route_decorators") else 0.0),
+             "general": 0.2,
+         }
+         if "fastapi" in code.lower():
+             scores["web"] += 0.1
+         if "pandas" in code.lower() or "numpy" in code.lower():
+             scores["data_science"] += 0.1
+         if "torch" in code.lower():
+             scores["ml_dl"] += 0.1
+         if "while" in code or "for" in code:
+             scores["dsa"] += 0.05
+         return {key: round(min(value, 0.99), 4) for key, value in scores.items()}
+
+     def analyze(self, request: AnalyzeCodeRequest) -> AnalyzeCodeResponse:
+         """Run the complete multi-domain analysis pipeline."""
+
+         started = time.perf_counter()
+         parsed = parse_code_structure(request.code)
+         complexity = estimate_complexity(parsed, request.code)
+         model_prediction = self.model.predict(request.code, request.context_window, parsed)
+         heuristic_scores = self._heuristic_domain_scores(parsed, request.code)
+
+         combined_scores = {}
+         for domain, heuristic_score in heuristic_scores.items():
+             model_score = float(model_prediction["domain_scores"].get(domain, 0.2))
+             combined_scores[domain] = round((0.6 * model_score) + (0.4 * heuristic_score), 4)
+
+         detected_domain = request.domain_hint if request.domain_hint != "auto" else max(combined_scores, key=combined_scores.get)
+         analyzer = self._analyzers.get(detected_domain)
+         domain_analysis = (
+             analyzer(request.code, parsed, complexity)
+             if analyzer is not None
+             else DomainAnalysis(
+                 domain="general",
+                 domain_score=0.6,
+                 issues=[],
+                 suggestions=["Add stronger domain-specific context for deeper analysis."],
+                 highlights={},
+             )
+         )
+
+         lint_score = _lint_score(parsed)
+         score_breakdown = self.reward_service.compute(
+             ml_score=float(model_prediction["ml_quality_score"]),
+             domain_score=domain_analysis.domain_score,
+             lint_score=lint_score,
+             complexity_penalty=float(complexity["complexity_penalty"]),
+         )
+         static_analysis = StaticAnalysisSummary(
+             syntax_valid=bool(parsed["syntax_valid"]),
+             syntax_error=str(parsed["syntax_error"]),
+             cyclomatic_complexity=int(complexity["cyclomatic_complexity"]),
+             line_count=int(parsed["line_count"]),
+             max_loop_depth=int(parsed["max_loop_depth"]),
+             time_complexity=str(complexity["time_complexity"]),
+             space_complexity=str(complexity["space_complexity"]),
+             detected_imports=list(parsed["imports"]),
+             code_smells=list(parsed["code_smells"]),
+         )
+         improvement_plan = self.suggestion_service.build_improvement_plan(
+             domain_analysis=domain_analysis,
+             static_analysis=static_analysis,
+         )
+         summary = (
+             f"Detected `{detected_domain}` code with a model score of {score_breakdown.ml_score:.0%}, "
+             f"domain score {score_breakdown.domain_score:.0%}, and final reward {score_breakdown.reward:.0%}."
+         )
+         return AnalyzeCodeResponse(
+             detected_domain=detected_domain,  # type: ignore[arg-type]
+             domain_confidences=combined_scores,
+             score_breakdown=score_breakdown,
+             static_analysis=static_analysis,
+             domain_analysis=domain_analysis,
+             improvement_plan=improvement_plan,
+             model_backend=str(model_prediction["backend_name"]),
+             model_id=str(model_prediction["model_id"]),
+             summary=summary,
+             context_window=request.context_window,
+             analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2),
+         )
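The domain detection in `analyze` blends model scores with heuristic priors at a 0.6/0.4 ratio and then takes the highest-scoring domain unless a hint is supplied. A minimal sketch of that step follows; the helper names `blend_domain_scores`, `pick_domain`, and `has_loop_keyword` are hypothetical and not part of the commit. The last helper illustrates one possible tightening of the `"for" in code` check, which as a plain substring test also matches identifiers such as `format`.

```python
import re
from typing import Dict


def blend_domain_scores(
    model_scores: Dict[str, float],
    heuristic_scores: Dict[str, float],
    model_weight: float = 0.6,
    heuristic_weight: float = 0.4,
    default_model_score: float = 0.2,
) -> Dict[str, float]:
    """Weighted blend per domain; missing model scores fall back to a default."""
    return {
        domain: round(
            model_weight * float(model_scores.get(domain, default_model_score))
            + heuristic_weight * heuristic,
            4,
        )
        for domain, heuristic in heuristic_scores.items()
    }


def pick_domain(combined: Dict[str, float], domain_hint: str = "auto") -> str:
    """An explicit hint wins; otherwise choose the highest blended score."""
    if domain_hint != "auto":
        return domain_hint
    return max(combined, key=combined.get)


def has_loop_keyword(code: str) -> bool:
    """Word-boundary match, so 'format' or 'before' do not count as loops.

    Still a heuristic: keywords inside strings or comments will match.
    """
    return re.search(r"\b(for|while)\b", code) is not None


heuristics = {"dsa": 0.25, "data_science": 0.2, "ml_dl": 0.2, "web": 0.65, "general": 0.2}
combined = blend_domain_scores({"web": 0.9, "dsa": 0.1}, heuristics)
print(pick_domain(combined), combined)
print("for" in "'{}'.format(1)", has_loop_keyword("'{}'.format(1)"))
```

Passing `domain_hint="ml_dl"` to `pick_domain` bypasses the scores entirely, which mirrors how the service treats a non-`auto` hint.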
services/reward_service.py CHANGED
@@ -1,38 +1,38 @@
- """Reward shaping logic for RL-ready code analysis scores."""
-
- from __future__ import annotations
-
- from schemas.response import ScoreBreakdown
-
-
- class RewardService:
-     """Compute reward scores from model, domain, lint, and complexity signals."""
-
-     def compute(self, *, ml_score: float, domain_score: float, lint_score: float, complexity_penalty: float) -> ScoreBreakdown:
-         """Apply dynamic reward shaping based on quality, errors, and completion."""
-
-         quality_signal = max(0.0, min(1.0, (0.45 * ml_score) + (0.3 * domain_score) + (0.25 * lint_score)))
-         error_reduction_signal = max(0.0, min(1.0, lint_score - (0.6 * complexity_penalty)))
-         completion_signal = max(0.0, min(1.0, (ml_score + domain_score + lint_score) / 3.0))
-         reward = max(
-             0.0,
-             min(
-                 1.0,
-                 (0.35 * quality_signal)
-                 + (0.25 * completion_signal)
-                 + (0.2 * error_reduction_signal)
-                 + (0.1 * ml_score)
-                 + (0.1 * domain_score)
-                 - (0.15 * complexity_penalty),
-             ),
-         )
-         return ScoreBreakdown(
-             ml_score=round(ml_score, 4),
-             domain_score=round(domain_score, 4),
-             lint_score=round(lint_score, 4),
-             complexity_penalty=round(complexity_penalty, 4),
-             quality_signal=round(quality_signal, 4),
-             error_reduction_signal=round(error_reduction_signal, 4),
-             completion_signal=round(completion_signal, 4),
-             reward=round(reward, 4),
-         )
+ """Reward shaping logic for RL-ready code analysis scores."""
+
+ from __future__ import annotations
+
+ from schemas.response import ScoreBreakdown
+
+
+ class RewardService:
+     """Compute reward scores from model, domain, lint, and complexity signals."""
+
+     def compute(self, *, ml_score: float, domain_score: float, lint_score: float, complexity_penalty: float) -> ScoreBreakdown:
+         """Apply dynamic reward shaping based on quality, errors, and completion."""
+
+         quality_signal = max(0.0, min(1.0, (0.45 * ml_score) + (0.3 * domain_score) + (0.25 * lint_score)))
+         error_reduction_signal = max(0.0, min(1.0, lint_score - (0.6 * complexity_penalty)))
+         completion_signal = max(0.0, min(1.0, (ml_score + domain_score + lint_score) / 3.0))
+         reward = max(
+             0.0,
+             min(
+                 1.0,
+                 (0.35 * quality_signal)
+                 + (0.25 * completion_signal)
+                 + (0.2 * error_reduction_signal)
+                 + (0.1 * ml_score)
+                 + (0.1 * domain_score)
+                 - (0.15 * complexity_penalty),
+             ),
+         )
+         return ScoreBreakdown(
+             ml_score=round(ml_score, 4),
+             domain_score=round(domain_score, 4),
+             lint_score=round(lint_score, 4),
+             complexity_penalty=round(complexity_penalty, 4),
+             quality_signal=round(quality_signal, 4),
+             error_reduction_signal=round(error_reduction_signal, 4),
+             completion_signal=round(completion_signal, 4),
+             reward=round(reward, 4),
+         )
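To make the weighting in `RewardService.compute` easier to follow, here is a minimal standalone sketch of the same arithmetic. `ScoreSketch` is a hypothetical stand-in for `ScoreBreakdown`, whose definition in `schemas.response` is not part of this diff, and the input values are made up for illustration.

```python
from dataclasses import dataclass


def clamp01(value: float) -> float:
    """Clamp a value into the closed interval [0, 1]."""
    return max(0.0, min(1.0, value))


@dataclass(frozen=True)
class ScoreSketch:
    """Hypothetical, reduced stand-in for ScoreBreakdown."""

    quality_signal: float
    error_reduction_signal: float
    completion_signal: float
    reward: float


def compute_reward(
    ml_score: float,
    domain_score: float,
    lint_score: float,
    complexity_penalty: float,
) -> ScoreSketch:
    """Mirror the weighted blend shown in the diff, using its coefficients."""
    quality = clamp01(0.45 * ml_score + 0.3 * domain_score + 0.25 * lint_score)
    error_reduction = clamp01(lint_score - 0.6 * complexity_penalty)
    completion = clamp01((ml_score + domain_score + lint_score) / 3.0)
    reward = clamp01(
        0.35 * quality
        + 0.25 * completion
        + 0.2 * error_reduction
        + 0.1 * ml_score
        + 0.1 * domain_score
        - 0.15 * complexity_penalty
    )
    return ScoreSketch(
        quality_signal=round(quality, 4),
        error_reduction_signal=round(error_reduction, 4),
        completion_signal=round(completion, 4),
        reward=round(reward, 4),
    )


result = compute_reward(ml_score=0.8, domain_score=0.6, lint_score=1.0, complexity_penalty=0.2)
print(result)
```

Note that `ml_score` and `domain_score` each enter the reward three times (through the quality signal, the completion signal, and their direct 0.1 terms), so they carry more total weight than the individual coefficients suggest.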
services/suggestion_service.py CHANGED
@@ -1,28 +1,28 @@
- """Suggestion and improvement-plan generation for analyzed code."""
-
- from __future__ import annotations
-
- from schemas.response import DomainAnalysis, StaticAnalysisSummary
-
-
- class SuggestionService:
-     """Build high-signal improvement steps from analysis output."""
-
-     def build_improvement_plan(self, *, domain_analysis: DomainAnalysis, static_analysis: StaticAnalysisSummary) -> list[str]:
-         """Return a compact three-step plan optimized for developer action."""
-
-         primary_issue = (
-             domain_analysis.issues[0].description
-             if domain_analysis.issues
-             else "Stabilize correctness first and keep the public behavior explicit."
-         )
-
-         step_one = f"Step 1 - Correctness and safety: {primary_issue}"
-         step_two = "Step 2 - Edge cases: test empty inputs, boundary values, malformed payloads, and failure-mode behavior explicitly."
-         step_three = "Step 3 - Scalability: reduce repeated scans, lower cyclomatic complexity, and benchmark the path on realistic input sizes."
-
-         if domain_analysis.suggestions:
-             step_three = f"{step_three} Priority hint: {domain_analysis.suggestions[0]}"
-         if not static_analysis.syntax_valid:
-             step_one = f"Step 1 - Correctness and safety: fix the syntax error first ({static_analysis.syntax_error})."
-         return [step_one, step_two, step_three]
+ """Suggestion and improvement-plan generation for analyzed code."""
+
+ from __future__ import annotations
+
+ from schemas.response import DomainAnalysis, StaticAnalysisSummary
+
+
+ class SuggestionService:
+     """Build high-signal improvement steps from analysis output."""
+
+     def build_improvement_plan(self, *, domain_analysis: DomainAnalysis, static_analysis: StaticAnalysisSummary) -> list[str]:
+         """Return a compact three-step plan optimized for developer action."""
+
+         primary_issue = (
+             domain_analysis.issues[0].description
+             if domain_analysis.issues
+             else "Stabilize correctness first and keep the public behavior explicit."
+         )
+
+         step_one = f"Step 1 - Correctness and safety: {primary_issue}"
+         step_two = "Step 2 - Edge cases: test empty inputs, boundary values, malformed payloads, and failure-mode behavior explicitly."
+         step_three = "Step 3 - Scalability: reduce repeated scans, lower cyclomatic complexity, and benchmark the path on realistic input sizes."
+
+         if domain_analysis.suggestions:
+             step_three = f"{step_three} Priority hint: {domain_analysis.suggestions[0]}"
+         if not static_analysis.syntax_valid:
+             step_one = f"Step 1 - Correctness and safety: fix the syntax error first ({static_analysis.syntax_error})."
+         return [step_one, step_two, step_three]
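The plan builder above has two overrides that are easy to miss: a syntax error replaces whatever issue would otherwise lead step 1, and the first domain suggestion is appended to step 3. The sketch below shows that precedence with plain lists and strings. `build_plan` and its parameters are hypothetical simplifications of the `DomainAnalysis` and `StaticAnalysisSummary` inputs, and the step texts are abbreviated.

```python
from typing import List, Sequence


def build_plan(
    issues: Sequence[str],
    suggestions: Sequence[str],
    syntax_valid: bool,
    syntax_error: str = "",
) -> List[str]:
    """Three-step plan; a syntax error always takes over step 1."""
    primary = issues[0] if issues else "Stabilize correctness first."
    step_one = f"Step 1 - Correctness and safety: {primary}"
    step_two = "Step 2 - Edge cases: test empty inputs and boundary values."
    step_three = "Step 3 - Scalability: reduce repeated scans."
    if suggestions:
        step_three = f"{step_three} Priority hint: {suggestions[0]}"
    if not syntax_valid:
        # The first reported issue is discarded here, matching the diff's ordering.
        step_one = (
            "Step 1 - Correctness and safety: "
            f"fix the syntax error first ({syntax_error})."
        )
    return [step_one, step_two, step_three]


plan = build_plan(
    issues=["Off-by-one in loop bound"],
    suggestions=["Cache repeated lookups"],
    syntax_valid=False,
    syntax_error="invalid syntax (line 3)",
)
for step in plan:
    print(step)
```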
triage.py CHANGED
@@ -1,473 +1,473 @@
1
- """PyTorch-backed triage pipeline for TorchReview Copilot."""
2
-
3
- from __future__ import annotations
4
-
5
- import ast
6
- import hashlib
7
- import os
8
- import re
9
- import time
10
- from functools import lru_cache
11
- from typing import List, Sequence
12
-
13
- import torch
14
- import torch.nn.functional as F
15
-
16
- try:
17
- from transformers import AutoModel, AutoTokenizer
18
- except Exception:
19
- AutoModel = None # type: ignore[assignment]
20
- AutoTokenizer = None # type: ignore[assignment]
21
-
22
- try:
23
- from .triage_catalog import build_examples, build_prototypes
24
- from .triage_models import (
25
- IssueLabel,
26
- PrototypeMatch,
27
- TriageExample,
28
- TriagePrototype,
29
- TriageResult,
30
- TriageSignal,
31
- )
32
- except ImportError:
33
- from triage_catalog import build_examples, build_prototypes
34
- from triage_models import (
35
- IssueLabel,
36
- PrototypeMatch,
37
- TriageExample,
38
- TriagePrototype,
39
- TriageResult,
40
- TriageSignal,
41
- )
42
-
43
-
44
- MODEL_ID = os.getenv("TRIAGE_MODEL_ID", "huggingface/CodeBERTa-small-v1")
45
- MODEL_MAX_LENGTH = int(os.getenv("TRIAGE_MODEL_MAX_LENGTH", "256"))
46
- LABELS: tuple[IssueLabel, ...] = ("syntax", "logic", "performance")
47
-
48
-
49
- class _LoopDepthVisitor(ast.NodeVisitor):
50
- """Track the maximum loop nesting depth in a code snippet."""
51
-
52
- def __init__(self) -> None:
53
- self.depth = 0
54
- self.max_depth = 0
55
-
56
- def _visit_loop(self, node: ast.AST) -> None:
57
- self.depth += 1
58
- self.max_depth = max(self.max_depth, self.depth)
59
- self.generic_visit(node)
60
- self.depth -= 1
61
-
62
- def visit_For(self, node: ast.For) -> None: # noqa: N802
63
- self._visit_loop(node)
64
-
65
- def visit_While(self, node: ast.While) -> None: # noqa: N802
66
- self._visit_loop(node)
67
-
68
- def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802
69
- self._visit_loop(node)
70
-
71
-
72
- class HashingEmbeddingBackend:
73
- """Deterministic torch-native fallback when pretrained weights are unavailable."""
74
-
75
- def __init__(self, dimensions: int = 96) -> None:
76
- self.dimensions = dimensions
77
- self.model_id = "hashed-token-fallback"
78
- self.backend_name = "hashed-token-fallback"
79
- self.notes = ["Using hashed torch embeddings because pretrained weights are unavailable."]
80
-
81
- def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
82
- rows = torch.zeros((len(texts), self.dimensions), dtype=torch.float32)
83
- for row_index, text in enumerate(texts):
84
- tokens = re.findall(r"[A-Za-z_]+|\d+|==|!=|<=|>=|\S", text.lower())[:512]
85
- if not tokens:
86
- rows[row_index, 0] = 1.0
87
- continue
88
- for token in tokens:
89
- digest = hashlib.md5(token.encode("utf-8")).hexdigest()
90
- bucket = int(digest[:8], 16) % self.dimensions
91
- sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0
92
- rows[row_index, bucket] += sign
93
- return F.normalize(rows + 1e-6, dim=1)
94
-
95
-
96
- class TransformersEmbeddingBackend:
97
- """Mean-pool CodeBERTa embeddings via torch + transformers."""
98
-
99
- def __init__(self, model_id: str = MODEL_ID, force_fallback: bool = False) -> None:
100
- self.model_id = model_id
101
- self.force_fallback = force_fallback
102
- self.backend_name = model_id
103
- self.notes: List[str] = []
104
- self._fallback = HashingEmbeddingBackend()
105
- self._tokenizer = None
106
- self._model = None
107
- self._load_error = ""
108
- if force_fallback:
109
- self.backend_name = self._fallback.backend_name
110
- self.notes = list(self._fallback.notes)
111
-
112
- def _ensure_loaded(self) -> None:
113
- if self.force_fallback or self._model is not None or self._load_error:
114
- return
115
- if AutoTokenizer is None or AutoModel is None:
116
- self._load_error = "transformers is not installed."
117
- else:
118
- try:
119
- self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
120
- self._model = AutoModel.from_pretrained(self.model_id)
121
- self._model.eval()
122
- self.notes.append(f"Loaded pretrained encoder `{self.model_id}` for inference.")
123
- except Exception as exc:
124
- self._load_error = f"{type(exc).__name__}: {exc}"
125
-
126
- if self._load_error:
127
- self.backend_name = self._fallback.backend_name
128
- self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {self._load_error}"]
129
-
130
- def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
131
- self._ensure_loaded()
132
- if self._model is None or self._tokenizer is None:
133
- return self._fallback.embed_texts(texts)
134
-
135
- encoded = self._tokenizer(
136
- list(texts),
137
- padding=True,
138
- truncation=True,
139
- max_length=MODEL_MAX_LENGTH,
140
- return_tensors="pt",
141
- )
142
- with torch.no_grad():
143
- outputs = self._model(**encoded)
144
- hidden_state = outputs.last_hidden_state
145
- mask = encoded["attention_mask"].unsqueeze(-1)
146
- pooled = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
147
- return F.normalize(pooled, dim=1)
148
-
149
-
150
- def _sanitize_text(value: str) -> str:
151
- text = (value or "").strip()
152
- return text[:4000]
153
-
154
-
155
- def _safe_softmax(scores: dict[IssueLabel, float]) -> dict[str, float]:
156
- tensor = torch.tensor([scores[label] for label in LABELS], dtype=torch.float32)
157
- probabilities = torch.softmax(tensor * 4.0, dim=0)
158
- return {label: round(float(probabilities[index]), 4) for index, label in enumerate(LABELS)}
159
-
160
-
161
- def _loop_depth(code: str) -> int:
162
- try:
163
- tree = ast.parse(code)
164
- except SyntaxError:
165
- return 0
166
- visitor = _LoopDepthVisitor()
167
- visitor.visit(tree)
168
- return visitor.max_depth
169
-
170
-
171
- def _repair_risk(label: IssueLabel, confidence: float, signal_count: int) -> str:
172
- base = {"syntax": 0.25, "logic": 0.55, "performance": 0.7}[label]
173
- if confidence < 0.55:
174
- base += 0.12
175
- if signal_count >= 4:
176
- base += 0.08
177
- if base < 0.4:
178
- return "low"
179
- if base < 0.72:
180
- return "medium"
181
- return "high"
182
-
183
-
184
- def _clamp_unit(value: float) -> float:
185
- return round(max(0.0, min(1.0, float(value))), 4)
186
-
187
-
188
- def _lint_score(code: str) -> float:
189
- stripped_lines = [line.rstrip("\n") for line in code.splitlines()]
190
- if not stripped_lines:
191
- return 0.2
192
-
193
- score = 1.0
194
- if any(len(line) > 88 for line in stripped_lines):
195
- score -= 0.15
196
- if any(line.rstrip() != line for line in stripped_lines):
197
- score -= 0.1
198
- if any("\t" in line for line in stripped_lines):
199
- score -= 0.1
200
- try:
201
- tree = ast.parse(code)
202
- functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
203
- if functions and not ast.get_docstring(functions[0]):
204
- score -= 0.08
205
- except SyntaxError:
206
- score -= 0.45
207
- return _clamp_unit(score)
208
-
209
-
210
- def _complexity_penalty(code: str) -> float:
211
- try:
212
- tree = ast.parse(code)
213
- except SyntaxError:
214
- return 0.95
215
- branch_nodes = sum(isinstance(node, (ast.If, ast.For, ast.While, ast.Try, ast.Match)) for node in ast.walk(tree))
216
- loop_depth = _loop_depth(code)
217
- penalty = 0.1 + min(branch_nodes, 8) * 0.07 + min(loop_depth, 4) * 0.12
218
- return _clamp_unit(penalty)
219
-
220
-
221
- class CodeTriageEngine:
222
- """Combine static signals with PyTorch embeddings to classify code issues."""
223
-
224
- def __init__(
225
- self,
226
- *,
227
- backend: TransformersEmbeddingBackend | HashingEmbeddingBackend | None = None,
228
- prototypes: Sequence[TriagePrototype] | None = None,
229
- examples: Sequence[TriageExample] | None = None,
230
- ) -> None:
231
- self.backend = backend or TransformersEmbeddingBackend()
232
- self.prototypes = list(prototypes or build_prototypes())
233
- self.examples = list(examples or build_examples())
234
- self._prototype_matrix: torch.Tensor | None = None
235
- self._reference_code_matrix: torch.Tensor | None = None
236
-
237
- def example_map(self) -> dict[str, TriageExample]:
238
- """Return UI examples keyed by task id."""
239
-
240
- return {example.key: example for example in self.examples}
241
-
242
- def _build_document(self, code: str, traceback_text: str) -> str:
243
- trace = _sanitize_text(traceback_text) or "No traceback supplied."
244
- snippet = _sanitize_text(code) or "# No code supplied."
245
- return f"Candidate code:\n{snippet}\n\nObserved failure:\n{trace}\n"
246
-
247
- def _build_review_document(self, code: str, traceback_text: str, context_window: str) -> str:
248
- context = _sanitize_text(context_window) or "No additional context window supplied."
249
- return (
250
- f"{self._build_document(code, traceback_text)}\n"
251
- f"Context window:\n{context}\n"
252
- )
253
-
254
- def _prototype_embeddings(self) -> torch.Tensor:
255
- if self._prototype_matrix is None:
256
- reference_texts = [prototype.reference_text for prototype in self.prototypes]
257
- self._prototype_matrix = self.backend.embed_texts(reference_texts)
258
- return self._prototype_matrix
259
-
260
- def _reference_code_embeddings(self) -> torch.Tensor:
261
- if self._reference_code_matrix is None:
262
- reference_codes = [prototype.reference_code for prototype in self.prototypes]
263
- self._reference_code_matrix = self.backend.embed_texts(reference_codes)
264
- return self._reference_code_matrix
265
-
266
- def _extract_signals(self, code: str, traceback_text: str) -> tuple[list[TriageSignal], dict[IssueLabel, float], list[str]]:
267
- trace = (traceback_text or "").lower()
268
- heuristic_scores: dict[IssueLabel, float] = {label: 0.15 for label in LABELS}
269
- signals: list[TriageSignal] = []
270
- notes: list[str] = []
271
-
272
- try:
273
- ast.parse(code)
274
- signals.append(
275
- TriageSignal(
276
- name="syntax_parse",
277
- value="passes",
278
- impact="syntax",
279
- weight=0.1,
280
- evidence="Python AST parsing succeeded.",
281
- )
282
- )
283
- heuristic_scores["logic"] += 0.05
284
- except SyntaxError as exc:
285
- evidence = f"{exc.msg} at line {exc.lineno}"
286
- signals.append(
287
- TriageSignal(
288
- name="syntax_parse",
289
- value="fails",
290
- impact="syntax",
291
- weight=0.95,
292
- evidence=evidence,
293
- )
294
- )
295
- heuristic_scores["syntax"] += 0.85
296
- notes.append(f"Parser failure detected: {evidence}")
297
-
298
- if any(token in trace for token in ("syntaxerror", "indentationerror", "expected ':'")):
299
- signals.append(
300
- TriageSignal(
301
- name="traceback_keyword",
302
- value="syntaxerror",
303
- impact="syntax",
304
- weight=0.8,
305
- evidence="Traceback contains a parser error.",
306
- )
307
- )
308
- heuristic_scores["syntax"] += 0.55
309
-
310
- if any(token in trace for token in ("assertionerror", "expected:", "actual:", "boundary", "missing", "incorrect")):
311
- signals.append(
312
- TriageSignal(
313
- name="test_failure_signal",
314
- value="assertion-style failure",
315
- impact="logic",
316
- weight=0.7,
317
- evidence="Failure text points to behavioral mismatch instead of parser issues.",
318
- )
319
- )
320
- heuristic_scores["logic"] += 0.55
321
-
322
- if any(token in trace for token in ("timeout", "benchmark", "slow", "latency", "performance", "profiler")):
323
- signals.append(
324
- TriageSignal(
325
- name="performance_trace",
326
- value="latency regression",
327
- impact="performance",
328
- weight=0.85,
329
- evidence="Traceback mentions benchmark or latency pressure.",
330
- )
331
- )
332
- heuristic_scores["performance"] += 0.7
333
-
334
- loop_depth = _loop_depth(code)
335
- if loop_depth >= 2:
336
- signals.append(
337
- TriageSignal(
338
- name="loop_depth",
339
- value=str(loop_depth),
340
- impact="performance",
341
- weight=0.65,
342
- evidence="Nested iteration increases runtime risk on larger fixtures.",
343
- )
344
- )
345
- heuristic_scores["performance"] += 0.35
346
-
347
- if "Counter(" in code or "defaultdict(" in code or "set(" in code:
348
- heuristic_scores["performance"] += 0.05
349
-
350
- if "return sessions" in code and "sessions.append" not in code:
351
- signals.append(
352
- TriageSignal(
353
- name="state_update_gap",
354
- value="possible missing final append",
355
- impact="logic",
356
- weight=0.45,
357
- evidence="A collection is returned without an obvious final state flush.",
358
- )
359
- )
360
- heuristic_scores["logic"] += 0.18
361
-
362
- return signals, heuristic_scores, notes
363
-
364
- def _nearest_match(self, embedding: torch.Tensor) -> tuple[TriagePrototype, float, dict[str, float]]:
365
- similarities = torch.matmul(embedding, self._prototype_embeddings().T)[0]
366
- indexed_scores = {
367
- self.prototypes[index].task_id: round(float((similarities[index] + 1.0) / 2.0), 4)
368
- for index in range(len(self.prototypes))
369
- }
370
- best_index = int(torch.argmax(similarities).item())
371
- best_prototype = self.prototypes[best_index]
372
- best_similarity = float((similarities[best_index] + 1.0) / 2.0)
373
- return best_prototype, best_similarity, indexed_scores
374
-
375
- def _repair_plan(self, label: IssueLabel, matched: TriagePrototype, context_window: str) -> list[str]:
376
- context = _sanitize_text(context_window)
377
- step_one = {
378
- "syntax": "Step 1 - Syntax checking and bug fixes: resolve the parser break before touching behavior, then align the function with the expected contract.",
379
- "logic": "Step 1 - Syntax checking and bug fixes: confirm the code parses cleanly, then patch the failing branch or state update causing the incorrect result.",
380
- "performance": "Step 1 - Syntax checking and bug fixes: keep the implementation correct first, then isolate the slow section without changing external behavior.",
381
- }[label]
382
- step_two = (
383
- "Step 2 - Edge case handling: verify empty input, boundary values, missing fields, and final-state flush behavior "
384
- f"against the known pattern `{matched.title}`."
385
- )
386
- step_three = (
387
- "Step 3 - Scalability of code: remove repeated full scans, prefer linear-time data structures, "
388
- "and benchmark the path on a production-like fixture."
389
- )
390
- if context:
391
- step_two = f"{step_two} Context window to preserve: {context}"
392
- return [step_one, step_two, step_three]
393
-
394
- def _reference_quality_score(self, code: str, matched: TriagePrototype) -> float:
395
- candidate = self.backend.embed_texts([_sanitize_text(code) or "# empty"])
396
- match_index = next(index for index, prototype in enumerate(self.prototypes) if prototype.task_id == matched.task_id)
397
- reference = self._reference_code_embeddings()[match_index : match_index + 1]
398
- score = float(torch.matmul(candidate, reference.T)[0][0].item())
399
- return _clamp_unit((score + 1.0) / 2.0)
400
-
401
- def triage(self, code: str, traceback_text: str = "", context_window: str = "") -> TriageResult:
402
- """Run the full triage pipeline on code plus optional failure context."""
403
-
404
- started = time.perf_counter()
405
- document = self._build_review_document(code, traceback_text, context_window)
406
- signals, heuristic_scores, notes = self._extract_signals(code, traceback_text)
407
-
408
- candidate_embedding = self.backend.embed_texts([document])
409
- matched, matched_similarity, prototype_scores = self._nearest_match(candidate_embedding)
410
-
411
- label_similarity = {label: 0.18 for label in LABELS}
412
- for prototype in self.prototypes:
413
- label_similarity[prototype.label] = max(
414
- label_similarity[prototype.label],
415
- prototype_scores[prototype.task_id],
416
- )
417
-
418
- combined_scores = {
419
- label: 0.72 * label_similarity[label] + 0.28 * heuristic_scores[label]
420
- for label in LABELS
421
- }
422
- confidence_scores = _safe_softmax(combined_scores)
423
- issue_label = max(LABELS, key=lambda label: confidence_scores[label])
424
- top_confidence = confidence_scores[issue_label]
425
-
426
- top_signal = signals[0].evidence if signals else "Model similarity dominated the decision."
427
- ml_quality_score = self._reference_quality_score(code, matched)
428
- lint_score = _lint_score(code)
429
- complexity_penalty = _complexity_penalty(code)
430
- reward_score = _clamp_unit((0.5 * ml_quality_score) + (0.3 * lint_score) - (0.2 * complexity_penalty))
431
- summary = (
432
- f"Detected a {issue_label} issue with {top_confidence:.0%} confidence. "
433
- f"The closest known failure pattern is `{matched.title}`, which indicates {matched.summary.lower()}. "
434
- f"Predicted quality score is {ml_quality_score:.0%} with an RL-ready reward of {reward_score:.0%}."
435
- )
436
- suggested_next_action = {
437
- "syntax": "Fix the parser error first, then rerun validation before changing behavior.",
438
- "logic": "Step through the smallest failing case and confirm the final branch/update behavior.",
439
- "performance": "Replace repeated full-list scans with a linear-time aggregation strategy, then benchmark it.",
440
- }[issue_label]
441
-
442
- return TriageResult(
443
- issue_label=issue_label,
444
- confidence_scores=confidence_scores,
445
- repair_risk=_repair_risk(issue_label, top_confidence, len(signals)),
446
- ml_quality_score=ml_quality_score,
447
- lint_score=lint_score,
448
- complexity_penalty=complexity_penalty,
449
- reward_score=reward_score,
450
- summary=summary,
451
- matched_pattern=PrototypeMatch(
452
- task_id=matched.task_id,
453
- title=matched.title,
454
- label=matched.label,
455
- similarity=round(matched_similarity, 4),
456
- summary=matched.summary,
457
- rationale=top_signal,
458
- ),
459
- repair_plan=self._repair_plan(issue_label, matched, context_window),
460
- suggested_next_action=suggested_next_action,
461
- extracted_signals=signals,
462
- model_backend=self.backend.backend_name,
463
- model_id=self.backend.model_id,
464
- inference_notes=list(self.backend.notes) + notes,
465
- analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2),
466
- )
467
-
468
-
469
- @lru_cache(maxsize=1)
470
- def get_default_engine() -> CodeTriageEngine:
471
- """Return a cached triage engine for the running process."""
472
-
473
- return CodeTriageEngine()
 
1
+ """PyTorch-backed triage pipeline for TorchReview Copilot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import hashlib
7
+ import os
8
+ import re
9
+ import time
10
+ from functools import lru_cache
11
+ from typing import List, Sequence
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ try:
17
+ from transformers import AutoModel, AutoTokenizer
18
+ except Exception:
19
+ AutoModel = None # type: ignore[assignment]
20
+ AutoTokenizer = None # type: ignore[assignment]
21
+
22
+ try:
23
+ from .triage_catalog import build_examples, build_prototypes
24
+ from .triage_models import (
25
+ IssueLabel,
26
+ PrototypeMatch,
27
+ TriageExample,
28
+ TriagePrototype,
29
+ TriageResult,
30
+ TriageSignal,
31
+ )
32
+ except ImportError:
33
+ from triage_catalog import build_examples, build_prototypes
34
+ from triage_models import (
35
+ IssueLabel,
36
+ PrototypeMatch,
37
+ TriageExample,
38
+ TriagePrototype,
39
+ TriageResult,
40
+ TriageSignal,
41
+ )
42
+
43
+
44
+ MODEL_ID = os.getenv("TRIAGE_MODEL_ID", "huggingface/CodeBERTa-small-v1")
45
+ MODEL_MAX_LENGTH = int(os.getenv("TRIAGE_MODEL_MAX_LENGTH", "256"))
46
+ LABELS: tuple[IssueLabel, ...] = ("syntax", "logic", "performance")
47
+
48
+
49
+ class _LoopDepthVisitor(ast.NodeVisitor):
50
+ """Track the maximum loop nesting depth in a code snippet."""
51
+
52
+ def __init__(self) -> None:
53
+ self.depth = 0
54
+ self.max_depth = 0
55
+
56
+ def _visit_loop(self, node: ast.AST) -> None:
57
+ self.depth += 1
58
+ self.max_depth = max(self.max_depth, self.depth)
59
+ self.generic_visit(node)
60
+ self.depth -= 1
61
+
62
+ def visit_For(self, node: ast.For) -> None: # noqa: N802
63
+ self._visit_loop(node)
64
+
65
+ def visit_While(self, node: ast.While) -> None: # noqa: N802
66
+ self._visit_loop(node)
67
+
68
+ def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802
69
+ self._visit_loop(node)
70
+
71
+
72
+ class HashingEmbeddingBackend:
73
+ """Deterministic torch-native fallback when pretrained weights are unavailable."""
74
+
75
+ def __init__(self, dimensions: int = 96) -> None:
76
+ self.dimensions = dimensions
77
+ self.model_id = "hashed-token-fallback"
78
+ self.backend_name = "hashed-token-fallback"
79
+ self.notes = ["Using hashed torch embeddings because pretrained weights are unavailable."]
80
+
81
+ def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
82
+ rows = torch.zeros((len(texts), self.dimensions), dtype=torch.float32)
83
+ for row_index, text in enumerate(texts):
84
+ tokens = re.findall(r"[A-Za-z_]+|\d+|==|!=|<=|>=|\S", text.lower())[:512]
85
+ if not tokens:
86
+ rows[row_index, 0] = 1.0
87
+ continue
88
+ for token in tokens:
89
+ digest = hashlib.md5(token.encode("utf-8")).hexdigest()
90
+ bucket = int(digest[:8], 16) % self.dimensions
91
+ sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0
92
+ rows[row_index, bucket] += sign
93
+ return F.normalize(rows + 1e-6, dim=1)
94
+
95
+
96
+ class TransformersEmbeddingBackend:
97
+ """Mean-pool CodeBERTa embeddings via torch + transformers."""
98
+
99
+ def __init__(self, model_id: str = MODEL_ID, force_fallback: bool = False) -> None:
100
+ self.model_id = model_id
101
+ self.force_fallback = force_fallback
102
+ self.backend_name = model_id
103
+ self.notes: List[str] = []
104
+ self._fallback = HashingEmbeddingBackend()
105
+ self._tokenizer = None
106
+ self._model = None
107
+ self._load_error = ""
108
+ if force_fallback:
109
+ self.backend_name = self._fallback.backend_name
110
+ self.notes = list(self._fallback.notes)
111
+
112
+ def _ensure_loaded(self) -> None:
113
+ if self.force_fallback or self._model is not None or self._load_error:
114
+ return
115
+ if AutoTokenizer is None or AutoModel is None:
116
+ self._load_error = "transformers is not installed."
117
+ else:
118
+ try:
119
+ self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
120
+ self._model = AutoModel.from_pretrained(self.model_id)
121
+ self._model.eval()
122
+ self.notes.append(f"Loaded pretrained encoder `{self.model_id}` for inference.")
123
+ except Exception as exc:
124
+ self._load_error = f"{type(exc).__name__}: {exc}"
125
+
126
+ if self._load_error:
127
+ self.backend_name = self._fallback.backend_name
128
+ self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {self._load_error}"]
129
+
130
+ def embed_texts(self, texts: Sequence[str]) -> torch.Tensor:
131
+ self._ensure_loaded()
132
+ if self._model is None or self._tokenizer is None:
133
+ return self._fallback.embed_texts(texts)
134
+
135
+ encoded = self._tokenizer(
136
+ list(texts),
137
+ padding=True,
138
+ truncation=True,
139
+ max_length=MODEL_MAX_LENGTH,
140
+ return_tensors="pt",
141
+ )
142
+ with torch.no_grad():
143
+ outputs = self._model(**encoded)
144
+ hidden_state = outputs.last_hidden_state
145
+ mask = encoded["attention_mask"].unsqueeze(-1)
146
+ pooled = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
147
+ return F.normalize(pooled, dim=1)
148
+
149
+
150
+ def _sanitize_text(value: str) -> str:
151
+ text = (value or "").strip()
152
+ return text[:4000]
153
+
154
+
155
+ def _safe_softmax(scores: dict[IssueLabel, float]) -> dict[str, float]:
156
+ tensor = torch.tensor([scores[label] for label in LABELS], dtype=torch.float32)
157
+ probabilities = torch.softmax(tensor * 4.0, dim=0)
158
+ return {label: round(float(probabilities[index]), 4) for index, label in enumerate(LABELS)}
159
+
160
+
161
+ def _loop_depth(code: str) -> int:
162
+ try:
163
+ tree = ast.parse(code)
164
+ except SyntaxError:
165
+ return 0
166
+ visitor = _LoopDepthVisitor()
167
+ visitor.visit(tree)
168
+ return visitor.max_depth
169
+
170
+
171
+ def _repair_risk(label: IssueLabel, confidence: float, signal_count: int) -> str:
172
+ base = {"syntax": 0.25, "logic": 0.55, "performance": 0.7}[label]
173
+ if confidence < 0.55:
174
+ base += 0.12
175
+ if signal_count >= 4:
176
+ base += 0.08
177
+ if base < 0.4:
178
+ return "low"
179
+ if base < 0.72:
180
+ return "medium"
181
+ return "high"
182
+
183
+
184
+ def _clamp_unit(value: float) -> float:
185
+ return round(max(0.0, min(1.0, float(value))), 4)
186
+
187
+
188
+ def _lint_score(code: str) -> float:
189
+ stripped_lines = [line.rstrip("\n") for line in code.splitlines()]
190
+ if not stripped_lines:
191
+ return 0.2
192
+
193
+ score = 1.0
194
+ if any(len(line) > 88 for line in stripped_lines):
195
+ score -= 0.15
196
+ if any(line.rstrip() != line for line in stripped_lines):
197
+ score -= 0.1
198
+ if any("\t" in line for line in stripped_lines):
199
+ score -= 0.1
200
+ try:
201
+ tree = ast.parse(code)
202
+ functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
203
+ if functions and not ast.get_docstring(functions[0]):
204
+ score -= 0.08
205
+ except SyntaxError:
206
+ score -= 0.45
207
+ return _clamp_unit(score)
208
+
209
+
210
+ def _complexity_penalty(code: str) -> float:
211
+ try:
212
+ tree = ast.parse(code)
213
+ except SyntaxError:
214
+ return 0.95
215
+ branch_nodes = sum(isinstance(node, (ast.If, ast.For, ast.While, ast.Try, ast.Match)) for node in ast.walk(tree))
216
+ loop_depth = _loop_depth(code)
217
+ penalty = 0.1 + min(branch_nodes, 8) * 0.07 + min(loop_depth, 4) * 0.12
218
+ return _clamp_unit(penalty)
219
+
220
+
221
+ class CodeTriageEngine:
222
+ """Combine static signals with PyTorch embeddings to classify code issues."""
223
+
224
+ def __init__(
225
+ self,
226
+ *,
227
+ backend: TransformersEmbeddingBackend | HashingEmbeddingBackend | None = None,
228
+ prototypes: Sequence[TriagePrototype] | None = None,
229
+ examples: Sequence[TriageExample] | None = None,
230
+ ) -> None:
231
+ self.backend = backend or TransformersEmbeddingBackend()
232
+ self.prototypes = list(prototypes or build_prototypes())
233
+ self.examples = list(examples or build_examples())
234
+ self._prototype_matrix: torch.Tensor | None = None
235
+ self._reference_code_matrix: torch.Tensor | None = None
236
+
237
+ def example_map(self) -> dict[str, TriageExample]:
238
+ """Return UI examples keyed by task id."""
239
+
240
+ return {example.key: example for example in self.examples}
241
+
242
+ def _build_document(self, code: str, traceback_text: str) -> str:
243
+ trace = _sanitize_text(traceback_text) or "No traceback supplied."
244
+ snippet = _sanitize_text(code) or "# No code supplied."
245
+ return f"Candidate code:\n{snippet}\n\nObserved failure:\n{trace}\n"
246
+
247
+ def _build_review_document(self, code: str, traceback_text: str, context_window: str) -> str:
248
+ context = _sanitize_text(context_window) or "No additional context window supplied."
249
+ return (
250
+ f"{self._build_document(code, traceback_text)}\n"
251
+ f"Context window:\n{context}\n"
252
+ )
253
+
254
+ def _prototype_embeddings(self) -> torch.Tensor:
255
+ if self._prototype_matrix is None:
256
+ reference_texts = [prototype.reference_text for prototype in self.prototypes]
257
+ self._prototype_matrix = self.backend.embed_texts(reference_texts)
258
+ return self._prototype_matrix
259
+
260
+ def _reference_code_embeddings(self) -> torch.Tensor:
261
+ if self._reference_code_matrix is None:
262
+ reference_codes = [prototype.reference_code for prototype in self.prototypes]
263
+ self._reference_code_matrix = self.backend.embed_texts(reference_codes)
264
+ return self._reference_code_matrix
265
+
266
+ def _extract_signals(self, code: str, traceback_text: str) -> tuple[list[TriageSignal], dict[IssueLabel, float], list[str]]:
267
+ trace = (traceback_text or "").lower()
268
+ heuristic_scores: dict[IssueLabel, float] = {label: 0.15 for label in LABELS}
269
+ signals: list[TriageSignal] = []
270
+ notes: list[str] = []
271
+
272
+ try:
273
+ ast.parse(code)
274
+ signals.append(
275
+ TriageSignal(
276
+ name="syntax_parse",
277
+ value="passes",
278
+ impact="syntax",
279
+ weight=0.1,
280
+ evidence="Python AST parsing succeeded.",
281
+ )
282
+ )
283
+ heuristic_scores["logic"] += 0.05
284
+ except SyntaxError as exc:
285
+ evidence = f"{exc.msg} at line {exc.lineno}"
286
+ signals.append(
287
+ TriageSignal(
288
+ name="syntax_parse",
289
+ value="fails",
290
+ impact="syntax",
291
+ weight=0.95,
292
+ evidence=evidence,
293
+ )
294
+ )
295
+ heuristic_scores["syntax"] += 0.85
296
+ notes.append(f"Parser failure detected: {evidence}")
297
+
298
+ if any(token in trace for token in ("syntaxerror", "indentationerror", "expected ':'")):
299
+ signals.append(
300
+ TriageSignal(
301
+ name="traceback_keyword",
302
+ value="syntaxerror",
303
+ impact="syntax",
304
+ weight=0.8,
305
+ evidence="Traceback contains a parser error.",
306
+ )
307
+ )
308
+ heuristic_scores["syntax"] += 0.55
309
+
310
+ if any(token in trace for token in ("assertionerror", "expected:", "actual:", "boundary", "missing", "incorrect")):
311
+ signals.append(
312
+ TriageSignal(
313
+ name="test_failure_signal",
314
+ value="assertion-style failure",
315
+ impact="logic",
316
+ weight=0.7,
317
+ evidence="Failure text points to behavioral mismatch instead of parser issues.",
318
+ )
319
+ )
320
+ heuristic_scores["logic"] += 0.55
321
+
322
+ if any(token in trace for token in ("timeout", "benchmark", "slow", "latency", "performance", "profiler")):
323
+ signals.append(
324
+ TriageSignal(
325
+ name="performance_trace",
326
+ value="latency regression",
327
+ impact="performance",
328
+ weight=0.85,
329
+ evidence="Traceback mentions benchmark or latency pressure.",
330
+ )
331
+ )
332
+ heuristic_scores["performance"] += 0.7
333
+
334
+ loop_depth = _loop_depth(code)
335
+ if loop_depth >= 2:
336
+ signals.append(
337
+ TriageSignal(
338
+ name="loop_depth",
339
+ value=str(loop_depth),
340
+ impact="performance",
341
+ weight=0.65,
342
+ evidence="Nested iteration increases runtime risk on larger fixtures.",
343
+ )
344
+ )
345
+ heuristic_scores["performance"] += 0.35
346
+
347
+ if "Counter(" in code or "defaultdict(" in code or "set(" in code:
348
+ heuristic_scores["performance"] += 0.05
349
+
350
+ if "return sessions" in code and "sessions.append" not in code:
351
+ signals.append(
352
+ TriageSignal(
353
+ name="state_update_gap",
354
+ value="possible missing final append",
355
+ impact="logic",
356
+ weight=0.45,
357
+ evidence="A collection is returned without an obvious final state flush.",
358
+ )
359
+ )
360
+ heuristic_scores["logic"] += 0.18
361
+
362
+ return signals, heuristic_scores, notes
363
+
364
+ def _nearest_match(self, embedding: torch.Tensor) -> tuple[TriagePrototype, float, dict[str, float]]:
365
+ similarities = torch.matmul(embedding, self._prototype_embeddings().T)[0]
366
+ indexed_scores = {
367
+ self.prototypes[index].task_id: round(float((similarities[index] + 1.0) / 2.0), 4)
368
+ for index in range(len(self.prototypes))
369
+ }
370
+ best_index = int(torch.argmax(similarities).item())
371
+ best_prototype = self.prototypes[best_index]
372
+ best_similarity = float((similarities[best_index] + 1.0) / 2.0)
373
+ return best_prototype, best_similarity, indexed_scores
374
+
375
+ def _repair_plan(self, label: IssueLabel, matched: TriagePrototype, context_window: str) -> list[str]:
376
+ context = _sanitize_text(context_window)
377
+ step_one = {
378
+ "syntax": "Step 1 - Syntax checking and bug fixes: resolve the parser break before touching behavior, then align the function with the expected contract.",
379
+ "logic": "Step 1 - Syntax checking and bug fixes: confirm the code parses cleanly, then patch the failing branch or state update causing the incorrect result.",
380
+ "performance": "Step 1 - Syntax checking and bug fixes: keep the implementation correct first, then isolate the slow section without changing external behavior.",
381
+ }[label]
382
+ step_two = (
383
+ "Step 2 - Edge case handling: verify empty input, boundary values, missing fields, and final-state flush behavior "
384
+ f"against the known pattern `{matched.title}`."
385
+ )
386
+ step_three = (
387
+ "Step 3 - Scalability of code: remove repeated full scans, prefer linear-time data structures, "
388
+ "and benchmark the path on a production-like fixture."
389
+ )
390
+ if context:
391
+ step_two = f"{step_two} Context window to preserve: {context}"
392
+ return [step_one, step_two, step_three]
393
+
394
+ def _reference_quality_score(self, code: str, matched: TriagePrototype) -> float:
395
+ candidate = self.backend.embed_texts([_sanitize_text(code) or "# empty"])
396
+ match_index = next(index for index, prototype in enumerate(self.prototypes) if prototype.task_id == matched.task_id)
397
+ reference = self._reference_code_embeddings()[match_index : match_index + 1]
398
+ score = float(torch.matmul(candidate, reference.T)[0][0].item())
399
+ return _clamp_unit((score + 1.0) / 2.0)
400
+
401
+ def triage(self, code: str, traceback_text: str = "", context_window: str = "") -> TriageResult:
402
+ """Run the full triage pipeline on code plus optional failure context."""
403
+
404
+ started = time.perf_counter()
405
+ document = self._build_review_document(code, traceback_text, context_window)
406
+ signals, heuristic_scores, notes = self._extract_signals(code, traceback_text)
407
+
408
+ candidate_embedding = self.backend.embed_texts([document])
409
+ matched, matched_similarity, prototype_scores = self._nearest_match(candidate_embedding)
410
+
411
+ label_similarity = {label: 0.18 for label in LABELS}
412
+ for prototype in self.prototypes:
413
+ label_similarity[prototype.label] = max(
414
+ label_similarity[prototype.label],
415
+ prototype_scores[prototype.task_id],
416
+ )
417
+
418
+ combined_scores = {
419
+ label: 0.72 * label_similarity[label] + 0.28 * heuristic_scores[label]
420
+ for label in LABELS
421
+ }
422
+ confidence_scores = _safe_softmax(combined_scores)
423
+ issue_label = max(LABELS, key=lambda label: confidence_scores[label])
424
+ top_confidence = confidence_scores[issue_label]
425
+
426
+ top_signal = signals[0].evidence if signals else "Model similarity dominated the decision."
427
+ ml_quality_score = self._reference_quality_score(code, matched)
428
+ lint_score = _lint_score(code)
429
+ complexity_penalty = _complexity_penalty(code)
430
+ reward_score = _clamp_unit((0.5 * ml_quality_score) + (0.3 * lint_score) - (0.2 * complexity_penalty))
431
+ summary = (
432
+ f"Detected a {issue_label} issue with {top_confidence:.0%} confidence. "
433
+ f"The closest known failure pattern is `{matched.title}`, which indicates {matched.summary.lower()}. "
434
+ f"Predicted quality score is {ml_quality_score:.0%} with an RL-ready reward of {reward_score:.0%}."
435
+ )
436
+ suggested_next_action = {
437
+ "syntax": "Fix the parser error first, then rerun validation before changing behavior.",
438
+ "logic": "Step through the smallest failing case and confirm the final branch/update behavior.",
439
+ "performance": "Replace repeated full-list scans with a linear-time aggregation strategy, then benchmark it.",
440
+ }[issue_label]
441
+
442
+ return TriageResult(
443
+ issue_label=issue_label,
444
+ confidence_scores=confidence_scores,
445
+ repair_risk=_repair_risk(issue_label, top_confidence, len(signals)),
446
+ ml_quality_score=ml_quality_score,
447
+ lint_score=lint_score,
448
+ complexity_penalty=complexity_penalty,
449
+ reward_score=reward_score,
450
+ summary=summary,
451
+ matched_pattern=PrototypeMatch(
452
+ task_id=matched.task_id,
453
+ title=matched.title,
454
+ label=matched.label,
455
+ similarity=round(matched_similarity, 4),
456
+ summary=matched.summary,
457
+ rationale=top_signal,
458
+ ),
459
+ repair_plan=self._repair_plan(issue_label, matched, context_window),
460
+ suggested_next_action=suggested_next_action,
461
+ extracted_signals=signals,
462
+ model_backend=self.backend.backend_name,
463
+ model_id=self.backend.model_id,
464
+ inference_notes=list(self.backend.notes) + notes,
465
+ analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2),
466
+ )
467
+
468
+
469
+ @lru_cache(maxsize=1)
470
+ def get_default_engine() -> CodeTriageEngine:
471
+ """Return a cached triage engine for the running process."""
472
+
473
+ return CodeTriageEngine()
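
Review note: the reward computed above is a weighted blend of three unit-range scores, clamped back into [0, 1]. The sketch below reproduces that arithmetic in isolation. The body of `clamp_unit` is an assumption, since `_clamp_unit` is called in this diff but its definition is not shown here, and the sample inputs are made up for illustration.

```python
def clamp_unit(value: float) -> float:
    """Clamp a value into the closed interval [0.0, 1.0] (assumed behavior of _clamp_unit)."""
    return max(0.0, min(1.0, value))


def reward_score(ml_quality: float, lint: float, complexity_penalty: float) -> float:
    """Blend the scores with the weights used in the diff: 0.5, 0.3 and -0.2."""
    return clamp_unit(0.5 * ml_quality + 0.3 * lint - 0.2 * complexity_penalty)


# Hypothetical inputs: decent quality, clean lint, moderate complexity.
print(f"{reward_score(0.8, 0.9, 0.5):.2f}")  # 0.40 + 0.27 - 0.10 -> 0.57
# A maximal penalty with zero quality and lint would go negative without the clamp.
print(reward_score(0.0, 0.0, 1.0))  # -> 0.0
```

Because the penalty term is subtracted, the unclamped value ranges from -0.2 to 0.8, so with these weights only the lower bound of the clamp can ever take effect.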
triage_catalog.py CHANGED
@@ -1,134 +1,134 @@
1
- """Curated prototypes and example inputs for TorchReview Copilot."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Dict, List
6
-
7
- try:
8
- from .triage_models import IssueLabel, TriageExample, TriagePrototype
9
- from .tasks import list_tasks
10
- except ImportError:
11
- from triage_models import IssueLabel, TriageExample, TriagePrototype
12
- from tasks import list_tasks
13
-
14
-
15
- TASK_KIND_TO_LABEL: Dict[str, IssueLabel] = {
16
- "syntax_fix": "syntax",
17
- "bug_fix": "logic",
18
- "optimization": "performance",
19
- }
20
-
21
- TRACEBACK_BY_TASK_ID: Dict[str, str] = {
22
- "syntax_fix_invoice_totals": (
23
- "Traceback (most recent call last):\n"
24
- " File \"services/billing/reconciliation.py\", line 3\n"
25
- " for record in records\n"
26
- " ^\n"
27
- "SyntaxError: expected ':'"
28
- ),
29
- "bug_fix_session_windows": (
30
- "AssertionError: collapse_sessions([{'minute': 1}, {'minute': 3}, {'minute': 8}], 4)\n"
31
- "Expected: [(1, 3), (8, 8)]\n"
32
- "Actual: [(1, 8)]\n"
33
- "Boundary handling merges the final session instead of starting a new one."
34
- ),
35
- "optimization_rank_active_users": (
36
- "BenchmarkWarning: rank_active_users exceeded the 450ms budget on a nightly export fixture.\n"
37
- "Profiler hint: repeated scans over the full event list and nested loops dominate runtime."
38
- ),
39
- }
40
-
41
- SUMMARY_BY_TASK_ID: Dict[str, str] = {
42
- "syntax_fix_invoice_totals": "Broken parser state in a billing helper blocks reconciliation jobs.",
43
- "bug_fix_session_windows": "Session-boundary logic fails on inclusive idle-timeout edges.",
44
- "optimization_rank_active_users": "A nightly ranking job is correct on small fixtures but too slow at production scale.",
45
- }
46
-
47
- CONTEXT_BY_TASK_ID: Dict[str, str] = {
48
- "syntax_fix_invoice_totals": (
49
- "Context window: this helper runs in an end-of-day billing reconciliation job. "
50
- "Keep the public function signature intact and restore correct totals for mixed integer/string inputs."
51
- ),
52
- "bug_fix_session_windows": (
53
- "Context window: this function groups sorted product analytics events into sessions for retention dashboards. "
54
- "Boundary behavior must stay deterministic because downstream reports depend on it."
55
- ),
56
- "optimization_rank_active_users": (
57
- "Context window: this pipeline feeds a nightly export on a small CPU instance. "
58
- "Maintain identical output ordering while improving scalability on larger event volumes."
59
- ),
60
- }
61
-
62
-
63
- def _prototype_text(
64
- task_id: str,
65
- title: str,
66
- description: str,
67
- repo_summary: str,
68
- goal: str,
69
- visible_tests: List[str],
70
- starter_code: str,
71
- traceback_text: str,
72
- ) -> str:
73
- visible = "\n".join(f"- {item}" for item in visible_tests) or "- none"
74
- return (
75
- f"Title: {title}\n"
76
- f"Problem: {description}\n"
77
- f"Repo context: {repo_summary}\n"
78
- f"Goal: {goal}\n"
79
- f"Observed failure:\n{traceback_text}\n"
80
- f"Visible checks:\n{visible}\n"
81
- f"Candidate code:\n{starter_code}\n"
82
- f"Task id: {task_id}\n"
83
- )
84
-
85
-
86
- def build_examples() -> List[TriageExample]:
87
- """Create stable UI examples from the task catalog."""
88
-
89
- examples: List[TriageExample] = []
90
- for task in list_tasks():
91
- label = TASK_KIND_TO_LABEL[task.task_kind]
92
- examples.append(
93
- TriageExample(
94
- key=task.task_id,
95
- title=task.title,
96
- label=label,
97
- summary=SUMMARY_BY_TASK_ID[task.task_id],
98
- code=task.starter_code,
99
- traceback_text=TRACEBACK_BY_TASK_ID[task.task_id],
100
- context_window=CONTEXT_BY_TASK_ID[task.task_id],
101
- task_id=task.task_id,
102
- )
103
- )
104
- return examples
105
-
106
-
107
- def build_prototypes() -> List[TriagePrototype]:
108
- """Build canonical triage prototypes from the OpenEnv tasks."""
109
-
110
- prototypes: List[TriagePrototype] = []
111
- for task in list_tasks():
112
- traceback_text = TRACEBACK_BY_TASK_ID[task.task_id]
113
- prototypes.append(
114
- TriagePrototype(
115
- task_id=task.task_id,
116
- title=task.title,
117
- label=TASK_KIND_TO_LABEL[task.task_kind],
118
- summary=SUMMARY_BY_TASK_ID[task.task_id],
119
- reference_text=_prototype_text(
120
- task.task_id,
121
- task.title,
122
- task.task_description,
123
- task.repo_summary,
124
- task.goal,
125
- list(task.visible_tests),
126
- task.reference_code,
127
- traceback_text,
128
- ),
129
- starter_code=task.starter_code,
130
- reference_code=task.reference_code,
131
- traceback_text=traceback_text,
132
- )
133
- )
134
- return prototypes
 
1
+ """Curated prototypes and example inputs for TorchReview Copilot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List
6
+
7
+ try:
8
+ from .triage_models import IssueLabel, TriageExample, TriagePrototype
9
+ from .tasks import list_tasks
10
+ except ImportError:
11
+ from triage_models import IssueLabel, TriageExample, TriagePrototype
12
+ from tasks import list_tasks
13
+
14
+
15
+ TASK_KIND_TO_LABEL: Dict[str, IssueLabel] = {
16
+ "syntax_fix": "syntax",
17
+ "bug_fix": "logic",
18
+ "optimization": "performance",
19
+ }
20
+
21
+ TRACEBACK_BY_TASK_ID: Dict[str, str] = {
22
+ "syntax_fix_invoice_totals": (
23
+ "Traceback (most recent call last):\n"
24
+ " File \"services/billing/reconciliation.py\", line 3\n"
25
+ " for record in records\n"
26
+ " ^\n"
27
+ "SyntaxError: expected ':'"
28
+ ),
29
+ "bug_fix_session_windows": (
30
+ "AssertionError: collapse_sessions([{'minute': 1}, {'minute': 3}, {'minute': 8}], 4)\n"
31
+ "Expected: [(1, 3), (8, 8)]\n"
32
+ "Actual: [(1, 8)]\n"
33
+ "Boundary handling merges the final session instead of starting a new one."
34
+ ),
35
+ "optimization_rank_active_users": (
36
+ "BenchmarkWarning: rank_active_users exceeded the 450ms budget on a nightly export fixture.\n"
37
+ "Profiler hint: repeated scans over the full event list and nested loops dominate runtime."
38
+ ),
39
+ }
40
+
41
+ SUMMARY_BY_TASK_ID: Dict[str, str] = {
42
+ "syntax_fix_invoice_totals": "Broken parser state in a billing helper blocks reconciliation jobs.",
43
+ "bug_fix_session_windows": "Session-boundary logic fails on inclusive idle-timeout edges.",
44
+ "optimization_rank_active_users": "A nightly ranking job is correct on small fixtures but too slow at production scale.",
45
+ }
46
+
47
+ CONTEXT_BY_TASK_ID: Dict[str, str] = {
48
+ "syntax_fix_invoice_totals": (
49
+ "Context window: this helper runs in an end-of-day billing reconciliation job. "
50
+ "Keep the public function signature intact and restore correct totals for mixed integer/string inputs."
51
+ ),
52
+ "bug_fix_session_windows": (
53
+ "Context window: this function groups sorted product analytics events into sessions for retention dashboards. "
54
+ "Boundary behavior must stay deterministic because downstream reports depend on it."
55
+ ),
56
+ "optimization_rank_active_users": (
57
+ "Context window: this pipeline feeds a nightly export on a small CPU instance. "
58
+ "Maintain identical output ordering while improving scalability on larger event volumes."
59
+ ),
60
+ }
61
+
62
+
63
+ def _prototype_text(
64
+ task_id: str,
65
+ title: str,
66
+ description: str,
67
+ repo_summary: str,
68
+ goal: str,
69
+ visible_tests: List[str],
70
+ starter_code: str,
71
+ traceback_text: str,
72
+ ) -> str:
73
+ visible = "\n".join(f"- {item}" for item in visible_tests) or "- none"
74
+ return (
75
+ f"Title: {title}\n"
76
+ f"Problem: {description}\n"
77
+ f"Repo context: {repo_summary}\n"
78
+ f"Goal: {goal}\n"
79
+ f"Observed failure:\n{traceback_text}\n"
80
+ f"Visible checks:\n{visible}\n"
81
+ f"Candidate code:\n{starter_code}\n"
82
+ f"Task id: {task_id}\n"
83
+ )
84
+
85
+
86
+ def build_examples() -> List[TriageExample]:
87
+ """Create stable UI examples from the task catalog."""
88
+
89
+ examples: List[TriageExample] = []
90
+ for task in list_tasks():
91
+ label = TASK_KIND_TO_LABEL[task.task_kind]
92
+ examples.append(
93
+ TriageExample(
94
+ key=task.task_id,
95
+ title=task.title,
96
+ label=label,
97
+ summary=SUMMARY_BY_TASK_ID[task.task_id],
98
+ code=task.starter_code,
99
+ traceback_text=TRACEBACK_BY_TASK_ID[task.task_id],
100
+ context_window=CONTEXT_BY_TASK_ID[task.task_id],
101
+ task_id=task.task_id,
102
+ )
103
+ )
104
+ return examples
105
+
106
+
107
+ def build_prototypes() -> List[TriagePrototype]:
108
+ """Build canonical triage prototypes from the OpenEnv tasks."""
109
+
110
+ prototypes: List[TriagePrototype] = []
111
+ for task in list_tasks():
112
+ traceback_text = TRACEBACK_BY_TASK_ID[task.task_id]
113
+ prototypes.append(
114
+ TriagePrototype(
115
+ task_id=task.task_id,
116
+ title=task.title,
117
+ label=TASK_KIND_TO_LABEL[task.task_kind],
118
+ summary=SUMMARY_BY_TASK_ID[task.task_id],
119
+ reference_text=_prototype_text(
120
+ task.task_id,
121
+ task.title,
122
+ task.task_description,
123
+ task.repo_summary,
124
+ task.goal,
125
+ list(task.visible_tests),
126
+ task.reference_code,
127
+ traceback_text,
128
+ ),
129
+ starter_code=task.starter_code,
130
+ reference_code=task.reference_code,
131
+ traceback_text=traceback_text,
132
+ )
133
+ )
134
+ return prototypes
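
Review note: `_prototype_text` flattens one task into a single labeled text blob that the engine can embed. The sketch below mirrors a subset of its fields to show the shape of that blob and the `- none` fallback for an empty test list. The name `render_prototype_text` and all sample values are hypothetical and do not come from the task catalog.

```python
from typing import List


def render_prototype_text(
    title: str,
    description: str,
    visible_tests: List[str],
    candidate_code: str,
    traceback_text: str,
) -> str:
    """Flatten a task into one text blob, mirroring a subset of _prototype_text."""
    # An empty test list joins to "", which is falsy, so the fallback "- none" is used.
    visible = "\n".join(f"- {item}" for item in visible_tests) or "- none"
    return (
        f"Title: {title}\n"
        f"Problem: {description}\n"
        f"Observed failure:\n{traceback_text}\n"
        f"Visible checks:\n{visible}\n"
        f"Candidate code:\n{candidate_code}\n"
    )


# Hypothetical sample values for illustration only.
text = render_prototype_text(
    title="Fix a missing colon",
    description="A loop header is missing its colon.",
    visible_tests=[],
    candidate_code="for record in records\n    total += record",
    traceback_text="SyntaxError: expected ':'",
)
print(text)
```

Keeping each field on its own labeled line means the embedded text stays readable when it is logged or inspected during debugging.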
triage_models.py CHANGED
@@ -1,79 +1,79 @@
1
- """Typed models for TorchReview Copilot outputs and examples."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Dict, List, Literal
6
-
7
- from pydantic import BaseModel, Field
8
-
9
-
10
- IssueLabel = Literal["syntax", "logic", "performance"]
11
- RiskLevel = Literal["low", "medium", "high"]
12
-
13
-
14
- class TriageSignal(BaseModel):
15
- """One extracted signal used during issue classification."""
16
-
17
- name: str
18
- value: str
19
- impact: Literal["syntax", "logic", "performance", "mixed"] = "mixed"
20
- weight: float = Field(..., ge=0.0, le=1.0)
21
- evidence: str = ""
22
-
23
-
24
- class PrototypeMatch(BaseModel):
25
- """Nearest known bug pattern from the built-in task catalog."""
26
-
27
- task_id: str
28
- title: str
29
- label: IssueLabel
30
- similarity: float = Field(..., ge=0.0, le=1.0)
31
- summary: str
32
- rationale: str
33
-
34
-
35
- class TriageExample(BaseModel):
36
- """Example payload exposed in the demo UI."""
37
-
38
- key: str
39
- title: str
40
- label: IssueLabel
41
- summary: str
42
- code: str
43
- traceback_text: str
44
- context_window: str
45
- task_id: str
46
-
47
-
48
- class TriagePrototype(BaseModel):
49
- """Canonical issue-pattern representation embedded by the triage engine."""
50
-
51
- task_id: str
52
- title: str
53
- label: IssueLabel
54
- summary: str
55
- reference_text: str
56
- starter_code: str
57
- reference_code: str
58
- traceback_text: str
59
-
60
-
61
- class TriageResult(BaseModel):
62
- """Structured output produced by the triage pipeline."""
63
-
64
- issue_label: IssueLabel
65
- confidence_scores: Dict[str, float]
66
- repair_risk: RiskLevel
67
- ml_quality_score: float = Field(..., ge=0.0, le=1.0)
68
- lint_score: float = Field(..., ge=0.0, le=1.0)
69
- complexity_penalty: float = Field(..., ge=0.0, le=1.0)
70
- reward_score: float = Field(..., ge=0.0, le=1.0)
71
- summary: str
72
- matched_pattern: PrototypeMatch
73
- repair_plan: List[str]
74
- suggested_next_action: str
75
- extracted_signals: List[TriageSignal] = Field(default_factory=list)
76
- model_backend: str
77
- model_id: str
78
- inference_notes: List[str] = Field(default_factory=list)
79
- analysis_time_ms: float = Field(..., ge=0.0)
 
1
+ """Typed models for TorchReview Copilot outputs and examples."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Literal
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
+ IssueLabel = Literal["syntax", "logic", "performance"]
11
+ RiskLevel = Literal["low", "medium", "high"]
12
+
13
+
14
+ class TriageSignal(BaseModel):
15
+ """One extracted signal used during issue classification."""
16
+
17
+ name: str
18
+ value: str
19
+ impact: Literal["syntax", "logic", "performance", "mixed"] = "mixed"
20
+ weight: float = Field(..., ge=0.0, le=1.0)
21
+ evidence: str = ""
22
+
23
+
24
+ class PrototypeMatch(BaseModel):
25
+ """Nearest known bug pattern from the built-in task catalog."""
26
+
27
+ task_id: str
28
+ title: str
29
+ label: IssueLabel
30
+ similarity: float = Field(..., ge=0.0, le=1.0)
31
+ summary: str
32
+ rationale: str
33
+
34
+
35
+ class TriageExample(BaseModel):
36
+ """Example payload exposed in the demo UI."""
37
+
38
+ key: str
39
+ title: str
40
+ label: IssueLabel
41
+ summary: str
42
+ code: str
43
+ traceback_text: str
44
+ context_window: str
45
+ task_id: str
46
+
47
+
48
+ class TriagePrototype(BaseModel):
49
+ """Canonical issue-pattern representation embedded by the triage engine."""
50
+
51
+ task_id: str
52
+ title: str
53
+ label: IssueLabel
54
+ summary: str
55
+ reference_text: str
56
+ starter_code: str
57
+ reference_code: str
58
+ traceback_text: str
59
+
60
+
61
+ class TriageResult(BaseModel):
62
+ """Structured output produced by the triage pipeline."""
63
+
64
+ issue_label: IssueLabel
65
+ confidence_scores: Dict[str, float]
66
+ repair_risk: RiskLevel
67
+ ml_quality_score: float = Field(..., ge=0.0, le=1.0)
68
+ lint_score: float = Field(..., ge=0.0, le=1.0)
69
+ complexity_penalty: float = Field(..., ge=0.0, le=1.0)
70
+ reward_score: float = Field(..., ge=0.0, le=1.0)
71
+ summary: str
72
+ matched_pattern: PrototypeMatch
73
+ repair_plan: List[str]
74
+ suggested_next_action: str
75
+ extracted_signals: List[TriageSignal] = Field(default_factory=list)
76
+ model_backend: str
77
+ model_id: str
78
+ inference_notes: List[str] = Field(default_factory=list)
79
+ analysis_time_ms: float = Field(..., ge=0.0)
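
Review note: the score fields above are declared with `Field(..., ge=0.0, le=1.0)`, so pydantic rejects out-of-range values when a `TriageResult` is built. The sketch below is a standard-library stand-in that mimics that bound check with a dataclass. It is not pydantic, and `UnitScore` is a hypothetical name used only for illustration; pydantic itself raises a `ValidationError` rather than the plain `ValueError` used here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UnitScore:
    """Hypothetical stand-in for a float field constrained to [0.0, 1.0]."""

    name: str
    value: float

    def __post_init__(self) -> None:
        # Same inclusive bounds as ge=0.0, le=1.0 in the models above.
        if not 0.0 <= self.value <= 1.0:
            raise ValueError(f"{self.name} must be within [0.0, 1.0], got {self.value}")


ok = UnitScore("reward_score", 0.57)  # accepted: inside the unit interval

try:
    UnitScore("reward_score", 1.2)  # rejected: above the upper bound
except ValueError as exc:
    print(exc)
```

This is why the engine clamps `reward_score` before constructing the result: an unclamped negative value would fail validation instead of being reported.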
utils/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
- """Utility helpers for AST parsing and complexity scoring."""
2
-
3
- from .ast_parser import parse_code_structure
4
- from .complexity import estimate_complexity
5
-
6
- __all__ = ["parse_code_structure", "estimate_complexity"]
 
1
+ """Utility helpers for AST parsing and complexity scoring."""
2
+
3
+ from .ast_parser import parse_code_structure
4
+ from .complexity import estimate_complexity
5
+
6
+ __all__ = ["parse_code_structure", "estimate_complexity"]