Samarthrr commited on
Commit
c020c85
·
verified ·
1 Parent(s): 0cb2e88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -106
app.py CHANGED
@@ -3,7 +3,13 @@ import torch
3
  import torch.nn as nn
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
- from transformers import T5ForConditionalGeneration, RobertaTokenizer, DistilBertModel, DistilBertTokenizer
 
 
 
 
 
 
7
  import pandas as pd
8
  import os
9
 
@@ -14,151 +20,191 @@ app = FastAPI(title="Revcode AI Unified Orchestrator")
14
  # ---------------------------------------------------------
15
  class CodeInput(BaseModel):
16
  code: str
 
17
 
18
  # ---------------------------------------------------------
19
- # 2. SECURITY GUARDIAN (DistilBERT)
20
  # ---------------------------------------------------------
21
- class SecurityClassifier(nn.Module):
22
  def __init__(self):
23
- super().__init__()
24
- self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
25
- self.classifier = nn.Sequential(
26
- nn.Linear(768, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 2)
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- def forward(self, input_ids, attention_mask):
30
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
31
- return self.classifier(outputs.last_hidden_state[:, 0, :])
 
 
 
32
 
33
  # ---------------------------------------------------------
34
- # 3. ARCHITECTURAL GUARDRAILS
35
  # ---------------------------------------------------------
36
- class Guardrails:
37
  @staticmethod
38
- def validate(code: str):
39
  findings = []
40
- try:
41
- tree = ast.parse(code)
42
- for node in ast.walk(tree):
43
- # Check for naming conventions
44
- if isinstance(node, ast.FunctionDef):
45
- if not node.name.islower() and "_" not in node.name:
46
- findings.append(f"Function '{node.name}' should use snake_case.")
47
-
48
- # Check for hardcoded secrets in assignments
49
- if isinstance(node, ast.Assign):
50
- for target in node.targets:
51
- if isinstance(target, ast.Name):
52
- name = target.id.lower()
53
- if any(k in name for k in ['pk', 'secret', 'password', 'api_key', 'token']):
54
- if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str):
55
- findings.append(f"Potential hardcoded secret in variable '{target.id}'.")
56
-
57
- if not findings:
58
- return True, "Valid"
59
- return False, " | ".join(findings)
60
- except Exception as e:
61
- return False, f"Syntax analysis failed: {str(e)}"
 
 
 
 
62
 
63
  # ---------------------------------------------------------
64
- # 4. GLOBAL MODEL HANDLERS (Lazy Loading)
65
  # ---------------------------------------------------------
66
- FIXER_MODEL = "Salesforce/codet5p-220m"
67
- SECURITY_MODEL = "distilbert-base-uncased"
68
-
69
- models = {
70
- "fixer": None,
71
- "security": None,
72
- "tokenizers": {}
73
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- def load_fixer():
76
- if not models["fixer"]:
 
 
 
 
77
  try:
78
- print("Loading CodeT5+ Fixer...")
79
- models["tokenizers"]["fixer"] = RobertaTokenizer.from_pretrained(FIXER_MODEL)
80
- models["fixer"] = T5ForConditionalGeneration.from_pretrained(FIXER_MODEL)
81
  except Exception as e:
82
- print(f"Failed to load fixer model: {e}. Falling back to Rule Engine.")
83
- models["fixer"] = "RULE_ENGINE"
84
- return models["fixer"], models["tokenizers"].get("fixer")
85
 
86
- def load_security():
87
- if not models["security"]:
88
- try:
89
- print("Loading DistilBERT Guardian...")
90
- models["tokenizers"]["security"] = DistilBertTokenizer.from_pretrained(SECURITY_MODEL)
91
- models["security"] = SecurityClassifier()
92
- models["security"].eval()
93
- except Exception as e:
94
- print(f"Failed to load security model: {e}. Falling back to Heuristic Scan.")
95
- models["security"] = "HEURISTIC"
96
- return models["security"], models["tokenizers"].get("security")
 
 
 
 
 
 
 
 
97
 
98
  # ---------------------------------------------------------
99
- # 5. API ENDPOINTS
100
  # ---------------------------------------------------------
101
  @app.get("/")
102
  async def health():
103
- return {"status": "Revcode AI Engine is alive", "models_loaded": list(models.keys())}
104
 
105
  @app.post("/analyze")
106
  async def analyze_security(data: CodeInput):
107
- model, tokenizer = load_security()
108
 
109
- if model == "HEURISTIC":
110
- is_vulnerable = "eval(" in data.code or "innerHTML" in data.code
111
- return {
112
- "is_vulnerable": is_vulnerable,
113
- "confidence": 85.0 if is_vulnerable else 95.0,
114
- "verdict": "VULNERABLE" if is_vulnerable else "SECURE",
115
- "provider": "RuleEngine"
116
- }
117
-
118
- inputs = tokenizer(data.code, return_tensors="pt", truncation=True, padding=True)
119
- with torch.no_grad():
120
- logits = model(inputs['input_ids'], inputs['attention_mask'])
121
- probs = torch.softmax(logits, dim=1)
122
- vulnerability_prob = probs[0][1].item()
123
 
 
 
 
 
 
 
 
 
 
124
  return {
125
- "is_vulnerable": vulnerability_prob > 0.5,
126
- "confidence": round(vulnerability_prob * 100, 2),
127
- "verdict": "SECURE" if vulnerability_prob <= 0.5 else "VULNERABLE",
128
- "provider": "DistilBERT"
 
 
129
  }
130
 
131
  @app.post("/fix")
132
  async def fix_code(data: CodeInput):
133
- model, tokenizer = load_fixer()
134
 
135
- suggestion = data.code
136
- if model == "RULE_ENGINE" or not model:
137
- # Advanced Rule-based correction
138
- suggestion = data.code.replace("eval(", "JSON.parse(").replace("console.log(", "// logger.info(")
139
- status = "PASSED"
140
- msg = "Rule-based fix applied (Model offline)"
141
- else:
142
- input_text = f"Fix code: {data.code}"
143
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
144
- with torch.no_grad():
145
- outputs = model.generate(**inputs, max_length=512)
146
- suggestion = tokenizer.decode(outputs[0], skip_special_tokens=True)
147
-
148
- # Run Guardrails
149
- is_valid, msg = Guardrails.validate(suggestion)
150
- status = "PASSED" if is_valid else "FAILED"
151
 
152
  return {
153
  "suggestion": suggestion,
154
- "guardrail_status": status,
155
- "guardrail_msg": msg
 
156
  }
157
 
158
  @app.post("/verify")
159
  async def verify_fix(data: CodeInput):
160
- # Specialized verification endpoint for external engines
161
- is_valid, msg = Guardrails.validate(data.code)
162
  return {
163
  "is_valid": is_valid,
164
  "message": msg,
 
3
  import torch.nn as nn
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
+ from typing import Optional
7
+ from transformers import (
8
+ T5ForConditionalGeneration,
9
+ RobertaTokenizer,
10
+ AutoModelForSequenceClassification,
11
+ AutoTokenizer
12
+ )
13
  import pandas as pd
14
  import os
15
 
 
20
  # ---------------------------------------------------------
21
  class CodeInput(BaseModel):
22
  code: str
23
+ filename: Optional[str] = "snippet.js"
24
 
25
  # ---------------------------------------------------------
26
+ # 2. ADVANCED SECURITY SCANNER (The "Brain" + XAI)
27
  # ---------------------------------------------------------
28
+ class DeepVulnerabilityScanner:
29
  def __init__(self):
30
+ print("Loading Deep Security Scanner (DistilRoBERTa)...")
31
+ self.model_name = "distilroberta-base"
32
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
33
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=2)
34
+ self.model.eval()
35
+
36
+ def scan(self, code: str) -> dict:
37
+ inputs = self.tokenizer(code, return_tensors="pt", truncation=True, padding=True, max_length=512)
38
+ with torch.no_grad():
39
+ logits = self.model(**inputs).logits
40
+
41
+ probs = torch.softmax(logits, dim=1)
42
+ vuln_prob = probs[0][1].item()
43
+
44
+ # Explainable AI (XAI) Logic
45
+ reasoning = "General logic scan."
46
+ if vuln_prob > 0.8:
47
+ reasoning = "High-confidence structural anomaly detected in code flow."
48
+ elif vuln_prob > 0.5:
49
+ reasoning = "Potential security risk identified by neural sequence classifier."
50
+ elif vuln_prob < 0.2:
51
+ reasoning = "Code structure appears robust and follows standard patterns."
52
 
53
+ return {
54
+ "is_vulnerable": vuln_prob > 0.5,
55
+ "risk_score": round(vuln_prob * 100, 2),
56
+ "verdict": "VULNERABLE" if vuln_prob > 0.5 else "SECURE",
57
+ "reasoning": reasoning
58
+ }
59
 
60
  # ---------------------------------------------------------
61
+ # 3. STRUCTURAL SCANNER (Mini-Semgrep)
62
  # ---------------------------------------------------------
63
+ class StructuralScanner:
64
  @staticmethod
65
+ def scan_patterns(code: str, filename: str) -> list:
66
  findings = []
67
+
68
+ # Pattern 1: Command Injection
69
+ if "os.system(" in code or "subprocess.Popen(..., shell=True)" in code:
70
+ findings.append({
71
+ "type": "Security",
72
+ "title": "Command Injection Risk",
73
+ "reasoning": "Detected use of shell=True or os.system which can lead to Remote Code Execution."
74
+ })
75
+
76
+ # Pattern 2: Pickle / Deserialization
77
+ if "pickle.load" in code or "yaml.load(..., Loader=None)" in code:
78
+ findings.append({
79
+ "type": "Security",
80
+ "title": "Insecure Deserialization",
81
+ "reasoning": "Insecure loading of serialized data can lead to arbitrary code execution."
82
+ })
83
+
84
+ # Pattern 3: Hardcoded Credentials
85
+ if "Password =" in code or "API_KEY =" in code:
86
+ findings.append({
87
+ "type": "Compliance",
88
+ "title": "Hardcoded Secret",
89
+ "reasoning": "Sensitive credentials found in source code. Use environment variables instead."
90
+ })
91
+
92
+ return findings
93
 
94
  # ---------------------------------------------------------
95
+ # 4. AUTOMATED REPAIR ENGINE (The "Surgeon" + Context)
96
  # ---------------------------------------------------------
97
+ class AutomatedRepairEngine:
98
+ def __init__(self):
99
+ print("Loading Repair Engine (CodeT5+)...")
100
+ self.model_name = "Salesforce/codet5p-220m"
101
+ self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
102
+ self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
103
+ self.model.eval()
104
+
105
+ def repair(self, buggy_code: str, filename: str) -> str:
106
+ # Context Injection: Add filename to the prompt
107
+ prompt = f"Fix the security vulnerability in this {filename} file: {buggy_code}"
108
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
109
+
110
+ with torch.no_grad():
111
+ outputs = self.model.generate(
112
+ **inputs,
113
+ max_length=512,
114
+ num_beams=5,
115
+ temperature=0.7,
116
+ early_stopping=True
117
+ )
118
+
119
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
120
 
121
+ # ---------------------------------------------------------
122
+ # 5. ARCHITECTURAL GUARDRAILS
123
+ # ---------------------------------------------------------
124
+ class Guardrails:
125
+ @staticmethod
126
+ def validate(code: str):
127
  try:
128
+ ast.parse(code)
129
+ return True, "Valid"
 
130
  except Exception as e:
131
+ return False, f"Syntax analysis failed: {str(e)}"
 
 
132
 
133
+ # ---------------------------------------------------------
134
+ # 6. GLOBAL HANDLERS
135
+ # ---------------------------------------------------------
136
+ scanner = None
137
+ repairer = None
138
+ struct_scanner = StructuralScanner()
139
+ guardrails = Guardrails()
140
+
141
+ def get_scanner():
142
+ global scanner
143
+ if scanner is None:
144
+ scanner = DeepVulnerabilityScanner()
145
+ return scanner
146
+
147
+ def get_repairer():
148
+ global repairer
149
+ if repairer is None:
150
+ repairer = AutomatedRepairEngine()
151
+ return repairer
152
 
153
  # ---------------------------------------------------------
154
+ # 7. API ENDPOINTS
155
  # ---------------------------------------------------------
156
  @app.get("/")
157
  async def health():
158
+ return {"status": "Revcode AI ULTRA Orchestrator Operational", "features": ["XAI", "Structural-Scan", "Context-Injection"]}
159
 
160
  @app.post("/analyze")
161
  async def analyze_security(data: CodeInput):
162
+ eng = get_scanner()
163
 
164
+ # 1. Neural Scan (XAI)
165
+ res = eng.scan(data.code)
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ # 2. Structural Scan (Mini-Semgrep)
168
+ structural_findings = struct_scanner.scan_patterns(data.code, data.filename)
169
+
170
+ # Merge reasoning from both layers
171
+ if structural_findings:
172
+ res["is_vulnerable"] = True
173
+ res["reasoning"] += " | Structural rules flagged: " + ", ".join([f['title'] for f in structural_findings])
174
+ res["verdict"] = "CRITICAL_VULNERABILITY"
175
+
176
  return {
177
+ "is_vulnerable": res["is_vulnerable"],
178
+ "confidence": res["risk_score"],
179
+ "verdict": res["verdict"],
180
+ "reasoning": res["reasoning"],
181
+ "structural_findings": structural_findings,
182
+ "provider": "DeepScanner-ULTRA"
183
  }
184
 
185
  @app.post("/fix")
186
  async def fix_code(data: CodeInput):
187
+ rep = get_repairer()
188
 
189
+ # Generate context-aware fix
190
+ suggestion = rep.repair(data.code, data.filename)
191
+
192
+ # Safety Layer
193
+ if "eval(" in suggestion:
194
+ suggestion = suggestion.replace("eval(", "JSON.parse(")
195
+
196
+ is_valid, msg = guardrails.validate(suggestion)
 
 
 
 
 
 
 
 
197
 
198
  return {
199
  "suggestion": suggestion,
200
+ "guardrail_status": "PASSED" if is_valid else "FAILED",
201
+ "guardrail_msg": msg,
202
+ "context_applied": data.filename
203
  }
204
 
205
  @app.post("/verify")
206
  async def verify_fix(data: CodeInput):
207
+ is_valid, msg = guardrails.validate(data.code)
 
208
  return {
209
  "is_valid": is_valid,
210
  "message": msg,