CodeCommunity committed on
Commit
2781db5
·
verified ·
1 Parent(s): f5f0249

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +166 -40
app/main.py CHANGED
@@ -1,104 +1,142 @@
1
- # main.py - Final Fixed Version
2
- from fastapi import FastAPI, HTTPException
3
- from pydantic import BaseModel
4
- from typing import List, Optional
5
  import re
6
  import logging
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- from app.services.reviewer_service import AIReviewerService
9
- from app.predictor import classifier, guide_generator
 
10
 
11
  # 1. Setup Logging
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
- # 2. Initialize FastAPI and Services
16
  app = FastAPI(title="GitGud AI Service")
17
- reviewer_service = AIReviewerService()
18
 
19
- # 3. Data Models (Order matters: ReviewRequest needs FileRequest)
 
 
 
 
20
  class FileRequest(BaseModel):
21
  fileName: str
22
  content: Optional[str] = None
 
23
 
24
- class ReviewRequest(BaseModel):
25
  files: List[FileRequest]
26
 
27
  class GuideRequest(BaseModel):
28
  repoName: str
29
  filePaths: List[str]
30
 
 
 
 
 
 
 
 
 
 
 
31
  # 4. Endpoints
32
 
33
  @app.get("/")
34
  def health_check():
35
- """Checks server status and GPU availability."""
36
  return {
37
  "status": "online",
38
  "model": "microsoft/codebert-base",
39
- "device": classifier.device,
 
40
  }
41
 
 
 
 
 
 
 
42
  @app.post("/classify")
43
  async def classify_file(request: FileRequest):
44
- """Classifies file into architectural layers."""
45
  try:
46
  result = classifier.predict(request.fileName, request.content)
 
 
 
 
 
 
 
47
  return {
48
- "fileName": request.fileName,
49
  "layer": result["label"],
50
  "confidence": result["confidence"],
51
  "embedding": result["embedding"]
52
  }
53
  except Exception as e:
54
  logger.error(f"Classify failed: {e}")
 
55
  raise HTTPException(status_code=500, detail=str(e))
56
 
57
- @app.post("/generate-guide")
58
- async def generate_guide(request: GuideRequest):
59
- """Generates markdown guides for repositories."""
60
  try:
61
- markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)
62
- return {"markdown": markdown}
63
- except Exception as e:
64
- logger.error(f"Guide generation failed: {e}")
65
- raise HTTPException(status_code=500, detail=str(e))
 
 
66
 
67
- @app.post("/review")
68
- async def review_code(request: ReviewRequest):
69
- """Detects security and logic issues in batches of files."""
70
- try:
71
- # Call the batch review logic from your service
72
- results = reviewer_service.review_batch_code(request.files)
73
- return {"reviews": results}
74
  except Exception as e:
75
- logger.error(f"Review endpoint failed: {e}")
76
  raise HTTPException(status_code=500, detail=str(e))
77
 
78
  @app.post("/repo-dashboard-stats")
79
- async def get_dashboard_stats(request: ReviewRequest):
 
80
  try:
81
- raw_reviews = reviewer_service.review_batch_code(request.files)
82
 
83
  # 1. Security Count
84
  total_vulns = sum(len(r.get("vulnerabilities", [])) for r in raw_reviews)
85
 
86
- # 2. Performance Ratio (Maintainability)
87
- # We use a default of 8 if the AI misses a file to avoid 0% scores
88
  scores = [r.get("metrics", {}).get("maintainability", 8) for r in raw_reviews]
89
  avg_maintainability = (sum(scores) / len(scores)) * 10 if scores else 0
90
 
91
- # 3. API Sniffing
92
  found_apis = []
93
  for f in request.files:
94
  if f.content:
95
- # Regex looks for common route decorators or methods
96
  matches = re.findall(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', f.content.lower())
97
  for match in matches:
98
  found_apis.append(f"/{match}")
99
 
100
  # 4. Repo Health Calculation
101
- # Every security issue drops health by 10 points
102
  health_score = max(10, 100 - (total_vulns * 10))
103
 
104
  return {
@@ -112,8 +150,96 @@ async def get_dashboard_stats(request: ReviewRequest):
112
  logger.error(f"Dashboard stats failed: {e}")
113
  raise HTTPException(status_code=500, detail="Failed to aggregate repository stats")
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # 5. Application Entry Point
116
  if __name__ == "__main__":
117
- import uvicorn
118
- # Port 7860 is mandatory for Hugging Face Spaces
119
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
 
 
 
2
  import re
3
  import logging
4
+ import traceback
5
+ import time
6
+ from typing import List, Optional, Dict
7
+ from dotenv import load_dotenv
8
+
9
+ from fastapi import FastAPI, HTTPException
10
+ from pydantic import BaseModel
11
+ import uvicorn
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
 
16
+ from app.predictor import classifier, guide_generator, reviewer
17
+ # Note: AIReviewerService from the first version is typically
18
+ # the underlying service for the 'reviewer' object in the second.
19
 
20
# 1. Logging configuration — module-level logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 2. FastAPI application instance.
app = FastAPI(title="GitGud AI Service")
main = app  # Alias kept for compatibility with callers importing `main`.

# In-memory embedding cache shared across endpoints.
# Shape: { repo_name: { file_path: embedding_vector } }
REPO_CACHE: Dict[str, Dict[str, List[float]]] = {}
31
+
32
+ # 3. Data Models
33
# 3. Data Models
# (No class docstrings here on purpose: pydantic model docstrings surface as
# OpenAPI schema descriptions, so comments keep generated docs unchanged.)

class FileRequest(BaseModel):
    # One file to classify/review; repoName enables embedding caching.
    fileName: str
    content: Optional[str] = None
    repoName: Optional[str] = None

class BatchReviewRequest(BaseModel):
    # A batch of files for /review-batch-code and /repo-dashboard-stats.
    files: List[FileRequest]

class GuideRequest(BaseModel):
    # Inputs for markdown guide generation.
    repoName: str
    filePaths: List[str]

class SearchRequest(BaseModel):
    # Natural-language query over embeddings (explicit or cached by repoName).
    query: str
    embeddings: Optional[Dict[str, List[float]]] = None  # Path -> Embedding
    repoName: Optional[str] = None

class ChatRequest(BaseModel):
    # RAG chat request; context items are { "fileName": str, "content": str }.
    query: str
    context: List[Dict[str, str]]
    repoName: str
+ repoName: str
54
+
55
  # 4. Endpoints
56
 
57
  @app.get("/")
58
  def health_check():
59
+ """Checks server status, GPU availability, and cached data."""
60
  return {
61
  "status": "online",
62
  "model": "microsoft/codebert-base",
63
+ "device": getattr(classifier, "device", "cpu"),
64
+ "cached_repos": list(REPO_CACHE.keys()),
65
  }
66
 
67
+ @app.get("/usage")
68
+ def get_usage():
69
+ """Returns AI Service usage statistics."""
70
+ from app.core.model_loader import llm_engine
71
+ return llm_engine.get_usage_stats()
72
+
73
@app.post("/classify")
async def classify_file(request: FileRequest):
    """Classify a file into an architectural layer.

    Side effect: when ``request.repoName`` is set, the file's embedding is
    stored in the module-level REPO_CACHE for later semantic search.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        result = classifier.predict(request.fileName, request.content)

        # Cache the embedding so /semantic-search can reuse it.
        # setdefault replaces the duplicated "if repo not in cache: cache[repo] = {}" boilerplate.
        if request.repoName:
            REPO_CACHE.setdefault(request.repoName, {})[request.fileName] = result["embedding"]

        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "confidence": result["confidence"],
            "embedding": result["embedding"],
        }
    except Exception as e:
        # logger.exception routes the traceback through the logging config
        # instead of splitting it between logger.error and a bare
        # traceback.print_exc() to stderr.
        logger.exception("Classify failed")
        raise HTTPException(status_code=500, detail=str(e))
95
 
96
+ @app.post("/review-batch-code")
97
+ async def review_batch_code(request: BatchReviewRequest):
98
+ """Batch review with detailed metrics and suggestions."""
99
  try:
100
+ reviews = reviewer.service.review_batch_code(request.files)
101
+ total_files = len(reviews)
102
+ total_vulns = sum(len(r.get("vulnerabilities", [])) for r in reviews)
103
+
104
+ # Calculate Average Maintainability
105
+ m_scores = [r.get("metrics", {}).get("maintainability", 8) for r in reviews]
106
+ avg_maint = sum(m_scores) / max(total_files, 1)
107
 
108
+ return {
109
+ "totalFiles": total_files,
110
+ "totalVulnerabilities": total_vulns,
111
+ "averageMaintainability": round(avg_maint, 1),
112
+ "results": reviews,
113
+ }
 
114
  except Exception as e:
115
+ traceback.print_exc()
116
  raise HTTPException(status_code=500, detail=str(e))
117
 
118
  @app.post("/repo-dashboard-stats")
119
+ async def get_dashboard_stats(request: BatchReviewRequest):
120
+ """Aggregated stats for frontend dashboards including health and API sniffing."""
121
  try:
122
+ raw_reviews = reviewer.service.review_batch_code(request.files)
123
 
124
  # 1. Security Count
125
  total_vulns = sum(len(r.get("vulnerabilities", [])) for r in raw_reviews)
126
 
127
+ # 2. Performance/Maintainability Ratio
 
128
  scores = [r.get("metrics", {}).get("maintainability", 8) for r in raw_reviews]
129
  avg_maintainability = (sum(scores) / len(scores)) * 10 if scores else 0
130
 
131
+ # 3. API Sniffing (Regex)
132
  found_apis = []
133
  for f in request.files:
134
  if f.content:
 
135
  matches = re.findall(r'(?:get|post|put|delete|patch)\([\'"]\/(.*?)[\'"]', f.content.lower())
136
  for match in matches:
137
  found_apis.append(f"/{match}")
138
 
139
  # 4. Repo Health Calculation
 
140
  health_score = max(10, 100 - (total_vulns * 10))
141
 
142
  return {
 
150
  logger.error(f"Dashboard stats failed: {e}")
151
  raise HTTPException(status_code=500, detail="Failed to aggregate repository stats")
152
 
153
@app.post("/analyze-file")
async def analyze_file(request: FileRequest):
    """Deep analysis: layer classification plus generated summary and tags.

    Side effect: when ``request.repoName`` is set, the file's embedding is
    cached in REPO_CACHE (same contract as /classify).

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        result = classifier.predict(request.fileName, request.content)
        summary = classifier.generate_file_summary(request.content, request.fileName)
        tags = classifier.extract_tags(request.content, request.fileName)

        # setdefault collapses the duplicated cache-init boilerplate shared with /classify.
        if request.repoName:
            REPO_CACHE.setdefault(request.repoName, {})[request.fileName] = result["embedding"]

        return {
            "fileName": request.fileName,
            "layer": result["label"],
            "summary": summary,
            "tags": tags,
            "embedding": result["embedding"],
        }
    except Exception as e:
        # Was traceback.print_exc() only; logger.exception keeps the traceback
        # in the configured log stream.
        logger.exception("Analyze file failed")
        raise HTTPException(status_code=500, detail=str(e))
176
+
177
@app.post("/semantic-search")
async def semantic_search(request: SearchRequest):
    """Search code using natural language and vector similarity."""
    try:
        # Prefer embeddings supplied in the request body; otherwise fall back
        # to whatever /classify or /analyze-file cached for this repo.
        vectors = request.embeddings
        if not vectors and request.repoName:
            vectors = REPO_CACHE.get(request.repoName)

        if not vectors:
            return {"results": []}

        return {"results": classifier.semantic_search(request.query, vectors)}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
193
+
194
@app.post("/chat")
async def chat(request: ChatRequest):
    """RAG-based chat using provided file context.

    Builds a prompt from the supplied code snippets plus the user's question
    and delegates generation to the shared LLM engine.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    start_time = time.time()
    # Lazy %-style args avoid eager f-string formatting in log calls.
    logger.info("Received Chat Request for %s", request.repoName)

    try:
        from app.core.model_loader import llm_engine

        # str.join replaces the quadratic `context_str += ...` loop.
        context_str = "".join(
            f"--- FILE: {item['fileName']} ---\n{item['content']}\n\n"
            for item in request.context
        )

        has_context = len(request.context) > 0
        prompt = f"""
You are "GitGud AI", an expert software architect.
Repository: "{request.repoName}"

INSTRUCTIONS:
1. Use the provided CONTEXT to answer.
2. If context is missing, state: "I am using general knowledge as I don't have specific snippets for this."
3. Use markdown for code.

CONTEXT:
{context_str if has_context else "[(NO CODE SNIPPETS PROVIDED)]"}

USER QUESTION:
{request.query}
"""
        response = llm_engine.generate_text(prompt)

        logger.info("Chat response generated in %.2fs", time.time() - start_time)
        return {"response": response}
    except Exception as e:
        # Was traceback.print_exc() only; logger.exception logs message + traceback.
        logger.exception("Chat failed")
        raise HTTPException(status_code=500, detail=str(e))
230
+
231
@app.post("/generate-guide")
async def generate_guide(request: GuideRequest):
    """Generate markdown documentation for the repository.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        markdown = guide_generator.generate_markdown(request.repoName, request.filePaths)
        return {"markdown": markdown}
    except Exception as e:
        # Was traceback.print_exc() only — inconsistent with the other endpoints,
        # and the failure never reached the logger. logger.exception fixes both.
        logger.exception("Guide generation failed")
        raise HTTPException(status_code=500, detail=str(e))
240
+
241
# 5. Application Entry Point
if __name__ == "__main__":
    # Default port 7860 keeps Hugging Face Spaces happy; override with $PORT
    # (e.g. 8000 for local development).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))