Samarthrr committed on
Commit
e740563
·
verified ·
1 Parent(s): ee63cc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -35
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import ast
2
  import torch
3
  import torch.nn as nn
4
- from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  from typing import Optional
7
  from transformers import (
@@ -12,8 +12,16 @@ from transformers import (
12
  )
13
  import pandas as pd
14
  import os
 
15
 
16
- app = FastAPI(title="Revcode AI Unified Orchestrator")
 
 
 
 
 
 
 
17
 
18
  # ---------------------------------------------------------
19
  # 1. DATA MODELS
@@ -23,15 +31,21 @@ class CodeInput(BaseModel):
23
  filename: Optional[str] = "snippet.js"
24
 
25
  # ---------------------------------------------------------
26
- # 2. ADVANCED SECURITY SCANNER (The "Brain" + XAI)
27
  # ---------------------------------------------------------
28
  class DeepVulnerabilityScanner:
29
  def __init__(self):
30
- # SOTA model fine-tuned specifically on the Devign code vulnerability dataset
31
- self.model_name = "mahdin70/codebert-devign-code-vulnerability-detector"
32
- self.tokenizer_name = "microsoft/codebert-base"
33
-
34
- print(f"Loading Specialized Security Scanner ({self.model_name})...")
 
 
 
 
 
 
35
  self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
36
  self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
37
  self.model.eval()
@@ -44,7 +58,6 @@ class DeepVulnerabilityScanner:
44
  probs = torch.softmax(logits, dim=1)
45
  vuln_prob = probs[0][1].item()
46
 
47
- # XAI Layer: Tuned for the specialized model's confidence thresholds
48
  reasoning = "Analyzing code logic for Devign-pattern vulnerabilities."
49
  if vuln_prob > 0.9:
50
  reasoning = "CRITICAL: High-confidence fingerprint of a known vulnerability pattern (e.g., Buffer Overflow, Improper Sanitization)."
@@ -67,31 +80,24 @@ class StructuralScanner:
67
  @staticmethod
68
  def scan_patterns(code: str, filename: str) -> list:
69
  findings = []
70
-
71
- # Pattern 1: Command Injection
72
  if "os.system(" in code or "subprocess.Popen(..., shell=True)" in code:
73
  findings.append({
74
  "type": "Security",
75
  "title": "Command Injection Risk",
76
  "reasoning": "Detected use of shell=True or os.system which can lead to Remote Code Execution."
77
  })
78
-
79
- # Pattern 2: Pickle / Deserialization
80
  if "pickle.load" in code or "yaml.load(..., Loader=None)" in code:
81
  findings.append({
82
  "type": "Security",
83
  "title": "Insecure Deserialization",
84
  "reasoning": "Insecure loading of serialized data can lead to arbitrary code execution."
85
  })
86
-
87
- # Pattern 3: Hardcoded Credentials
88
  if "Password =" in code or "API_KEY =" in code:
89
  findings.append({
90
  "type": "Compliance",
91
  "title": "Hardcoded Secret",
92
  "reasoning": "Sensitive credentials found in source code. Use environment variables instead."
93
  })
94
-
95
  return findings
96
 
97
  # ---------------------------------------------------------
@@ -106,10 +112,8 @@ class AutomatedRepairEngine:
106
  self.model.eval()
107
 
108
  def repair(self, buggy_code: str, filename: str) -> str:
109
- # Context Injection: Add filename to the prompt
110
  prompt = f"Fix the security vulnerability in this {filename} file: {buggy_code}"
111
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
112
-
113
  with torch.no_grad():
114
  outputs = self.model.generate(
115
  **inputs,
@@ -118,7 +122,6 @@ class AutomatedRepairEngine:
118
  temperature=0.7,
119
  early_stopping=True
120
  )
121
-
122
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
123
 
124
  # ---------------------------------------------------------
@@ -141,9 +144,9 @@ repairer = None
141
  struct_scanner = StructuralScanner()
142
  guardrails = Guardrails()
143
 
144
- def get_scanner():
145
  global scanner
146
- if scanner is None:
147
  scanner = DeepVulnerabilityScanner()
148
  return scanner
149
 
@@ -154,23 +157,46 @@ def get_repairer():
154
  return repairer
155
 
156
  # ---------------------------------------------------------
157
- # 7. API ENDPOINTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # ---------------------------------------------------------
159
  @app.get("/")
160
  async def health():
161
- return {"status": "Revcode AI ULTRA Orchestrator Operational", "features": ["XAI", "Structural-Scan", "Context-Injection"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  @app.post("/analyze")
164
  async def analyze_security(data: CodeInput):
165
  eng = get_scanner()
166
-
167
- # 1. Neural Scan (XAI)
168
  res = eng.scan(data.code)
169
-
170
- # 2. Structural Scan (Mini-Semgrep)
171
  structural_findings = struct_scanner.scan_patterns(data.code, data.filename)
172
-
173
- # Merge reasoning from both layers
174
  if structural_findings:
175
  res["is_vulnerable"] = True
176
  res["reasoning"] += " | Structural rules flagged: " + ", ".join([f['title'] for f in structural_findings])
@@ -182,20 +208,15 @@ async def analyze_security(data: CodeInput):
182
  "verdict": res["verdict"],
183
  "reasoning": res["reasoning"],
184
  "structural_findings": structural_findings,
 
185
  "provider": "DeepScanner-ULTRA"
186
  }
187
 
188
  @app.post("/fix")
189
  async def fix_code(data: CodeInput):
190
  rep = get_repairer()
191
-
192
- # Generate context-aware fix
193
  suggestion = rep.repair(data.code, data.filename)
194
-
195
- # Heuristic layer removed to allow the neural surgeon to handle repairs with 100% precision.
196
-
197
  is_valid, msg = guardrails.validate(suggestion)
198
-
199
  return {
200
  "suggestion": suggestion,
201
  "guardrail_status": "PASSED" if is_valid else "FAILED",
 
1
  import ast
2
  import torch
3
  import torch.nn as nn
4
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
5
  from pydantic import BaseModel
6
  from typing import Optional
7
  from transformers import (
 
12
  )
13
  import pandas as pd
14
  import os
15
+ import threading
16
 
17
+ # Import the training function
18
+ from train_engine import train_on_devign
19
+
20
+ app = FastAPI(title="Revcode AI ULTRA Orchestrator")
21
+
22
+ # Global training status
23
+ training_lock = threading.Lock()
24
+ is_training = False
25
 
26
  # ---------------------------------------------------------
27
  # 1. DATA MODELS
 
31
  filename: Optional[str] = "snippet.js"
32
 
33
  # ---------------------------------------------------------
34
+ # 2. ADVANCED SECURITY SCANNER (CodeBERT-Devign + XAI)
35
  # ---------------------------------------------------------
36
  class DeepVulnerabilityScanner:
37
  def __init__(self):
38
+ # We check if a locally trained model exists, otherwise use the base
39
+ local_model = "./trained_model"
40
+ if os.path.exists(local_model):
41
+ self.model_name = local_model
42
+ self.tokenizer_name = local_model
43
+ print(f"Loading Locally Trained Security Scanner ({self.model_name})...")
44
+ else:
45
+ self.model_name = "mahdin70/codebert-devign-code-vulnerability-detector"
46
+ self.tokenizer_name = "microsoft/codebert-base"
47
+ print(f"Loading SOTA Security Scanner ({self.model_name})...")
48
+
49
  self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
50
  self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
51
  self.model.eval()
 
58
  probs = torch.softmax(logits, dim=1)
59
  vuln_prob = probs[0][1].item()
60
 
 
61
  reasoning = "Analyzing code logic for Devign-pattern vulnerabilities."
62
  if vuln_prob > 0.9:
63
  reasoning = "CRITICAL: High-confidence fingerprint of a known vulnerability pattern (e.g., Buffer Overflow, Improper Sanitization)."
 
80
  @staticmethod
81
  def scan_patterns(code: str, filename: str) -> list:
82
  findings = []
 
 
83
  if "os.system(" in code or "subprocess.Popen(..., shell=True)" in code:
84
  findings.append({
85
  "type": "Security",
86
  "title": "Command Injection Risk",
87
  "reasoning": "Detected use of shell=True or os.system which can lead to Remote Code Execution."
88
  })
 
 
89
  if "pickle.load" in code or "yaml.load(..., Loader=None)" in code:
90
  findings.append({
91
  "type": "Security",
92
  "title": "Insecure Deserialization",
93
  "reasoning": "Insecure loading of serialized data can lead to arbitrary code execution."
94
  })
 
 
95
  if "Password =" in code or "API_KEY =" in code:
96
  findings.append({
97
  "type": "Compliance",
98
  "title": "Hardcoded Secret",
99
  "reasoning": "Sensitive credentials found in source code. Use environment variables instead."
100
  })
 
101
  return findings
102
 
103
  # ---------------------------------------------------------
 
112
  self.model.eval()
113
 
114
  def repair(self, buggy_code: str, filename: str) -> str:
 
115
  prompt = f"Fix the security vulnerability in this {filename} file: {buggy_code}"
116
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
117
  with torch.no_grad():
118
  outputs = self.model.generate(
119
  **inputs,
 
122
  temperature=0.7,
123
  early_stopping=True
124
  )
 
125
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
126
 
127
  # ---------------------------------------------------------
 
144
  struct_scanner = StructuralScanner()
145
  guardrails = Guardrails()
146
 
147
+ def get_scanner(force_reload=False):
148
  global scanner
149
+ if scanner is None or force_reload:
150
  scanner = DeepVulnerabilityScanner()
151
  return scanner
152
 
 
157
  return repairer
158
 
159
  # ---------------------------------------------------------
160
+ # 7. TRAINING WRAPPER
161
+ # ---------------------------------------------------------
162
+ def run_training():
163
+ global is_training
164
+ with training_lock:
165
+ is_training = True
166
+ try:
167
+ print("--- STARTING BACKGROUND TRAINING CYCLE ---")
168
+ train_on_devign(output_dir="./trained_model")
169
+ print("--- TRAINING CYCLE COMPLETED. RELOADING SCANNER ---")
170
+ get_scanner(force_reload=True)
171
+ finally:
172
+ with training_lock:
173
+ is_training = False
174
+
175
+ # ---------------------------------------------------------
176
+ # 8. API ENDPOINTS
177
  # ---------------------------------------------------------
178
  @app.get("/")
179
  async def health():
180
+ return {
181
+ "status": "Revcode AI ULTRA Orchestrator Operational",
182
+ "is_training": is_training,
183
+ "features": ["XAI", "Structural-Scan", "Context-Injection", "Auto-Train"]
184
+ }
185
+
186
+ @app.post("/train")
187
+ async def trigger_training(background_tasks: BackgroundTasks):
188
+ global is_training
189
+ if is_training:
190
+ return {"status": "error", "message": "Training already in progress."}
191
+
192
+ background_tasks.add_task(run_training)
193
+ return {"status": "success", "message": "Training started in background."}
194
 
195
  @app.post("/analyze")
196
  async def analyze_security(data: CodeInput):
197
  eng = get_scanner()
 
 
198
  res = eng.scan(data.code)
 
 
199
  structural_findings = struct_scanner.scan_patterns(data.code, data.filename)
 
 
200
  if structural_findings:
201
  res["is_vulnerable"] = True
202
  res["reasoning"] += " | Structural rules flagged: " + ", ".join([f['title'] for f in structural_findings])
 
208
  "verdict": res["verdict"],
209
  "reasoning": res["reasoning"],
210
  "structural_findings": structural_findings,
211
+ "is_training": is_training,
212
  "provider": "DeepScanner-ULTRA"
213
  }
214
 
215
  @app.post("/fix")
216
  async def fix_code(data: CodeInput):
217
  rep = get_repairer()
 
 
218
  suggestion = rep.repair(data.code, data.filename)
 
 
 
219
  is_valid, msg = guardrails.validate(suggestion)
 
220
  return {
221
  "suggestion": suggestion,
222
  "guardrail_status": "PASSED" if is_valid else "FAILED",