Sgridda commited on
Commit
b64e7a0
·
1 Parent(s): 5f40b94

added inference

Browse files
Files changed (1) hide show
  1. main.py +67 -47
main.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import re
6
  import json
7
  from fastapi.responses import HTMLResponse
 
8
 
9
 
10
  # ----------------------------
@@ -14,6 +15,10 @@ from fastapi.responses import HTMLResponse
14
  MODEL_NAME = "Salesforce/codegen-350M-mono"
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
 
 
 
 
 
17
  # ----------------------------
18
  # 2. FastAPI App Initialization
19
  # ----------------------------
@@ -75,61 +80,76 @@ class ReviewResponse(BaseModel):
75
 
76
  def run_ai_inference(diff: str) -> str:
77
  """
78
- Runs the AI model to get the review.
79
  """
80
- if not model or not tokenizer:
81
- raise RuntimeError("Model is not loaded.")
82
-
83
- # Simple, direct prompt for codegen-350M-mono
84
- prompt = f"""Code:
85
- {diff[:500]}
86
-
87
- Review: This code could be improved by adding"""
88
- encoded = tokenizer(
89
- prompt,
90
- return_tensors="pt",
91
- max_length=512, # Reduced from 1024 for faster processing
92
- truncation=True,
93
- padding="max_length"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  )
95
- input_ids = encoded["input_ids"]
96
- attention_mask = encoded["attention_mask"]
97
- with torch.no_grad():
98
- outputs = model.generate(
99
- input_ids=input_ids,
100
- attention_mask=attention_mask,
101
- max_new_tokens=32, # Further reduced for speed
102
- do_sample=True,
103
- temperature=0.9,
104
- top_p=0.85,
105
- num_return_sequences=1,
106
- pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
107
- eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
108
- use_cache=True
109
- )
110
- response_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
111
-
112
  # Clean up the response
113
  response_text = response_text.strip()
114
-
115
- # Remove artifacts and clean up
116
- if response_text.startswith("adding"):
117
- response_text = "Adding " + response_text[6:]
118
-
119
- # Take only the first sentence or meaningful phrase
120
- sentences = response_text.split('.')
121
- if sentences and len(sentences[0].strip()) > 10:
122
- review = sentences[0].strip() + "."
 
123
  else:
124
- # Fallback to first meaningful line
125
  lines = [line.strip() for line in response_text.split('\n') if line.strip()]
126
- if lines and len(lines[0]) > 5:
127
- review = lines[0]
128
- if not review.endswith('.'):
129
- review += "."
 
 
 
 
 
130
  else:
131
  review = "Consider adding proper documentation and error handling."
132
-
133
  return review
134
 
135
  def parse_ai_response(response_text: str) -> list[ReviewComment]:
 
5
import json
import os
import re

import requests
from fastapi.responses import HTMLResponse
9
 
10
 
11
  # ----------------------------
 
15
MODEL_NAME = "Salesforce/codegen-350M-mono"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# SECURITY FIX: never commit an API key to source control — the original
# hard-coded the token here. Read it from the environment instead
# (export HF_API_KEY=... before starting the app); run_ai_inference raises
# a clear RuntimeError when it is unset.
HF_API_KEY = os.environ.get("HF_API_KEY", "")
HF_MODEL_NAME = "bigcode/starcoder"  # Replace with the best model for code review
22
  # ----------------------------
23
  # 2. FastAPI App Initialization
24
  # ----------------------------
 
80
 
81
  def run_ai_inference(diff: str) -> str:
82
  """
83
+ Sends the code diff to Hugging Face Inference API to get the review.
84
  """
85
+ if not HF_API_KEY:
86
+ raise RuntimeError("Hugging Face API key is not set.")
87
+
88
+ # Better prompt for meaningful completions
89
+ prompt = f"""# Code Review
90
+
91
+ def example():
92
+ pass
93
+ # Review: This function should include error handling and documentation.
94
+
95
+ {diff[:400]}
96
+ # Review: This code should include"""
97
+
98
+ headers = {
99
+ "Authorization": f"Bearer {HF_API_KEY}",
100
+ "Content-Type": "application/json"
101
+ }
102
+ payload = {
103
+ "inputs": prompt,
104
+ "parameters": {
105
+ "max_new_tokens": 32,
106
+ "temperature": 0.7,
107
+ "top_p": 0.9
108
+ }
109
+ }
110
+
111
+ response = requests.post(
112
+ f"https://api-inference.huggingface.co/models/{HF_MODEL_NAME}",
113
+ headers=headers,
114
+ json=payload
115
  )
116
+
117
+ if response.status_code != 200:
118
+ raise RuntimeError(f"Hugging Face API error: {response.status_code} {response.text}")
119
+
120
+ response_data = response.json()
121
+ if isinstance(response_data, list) and len(response_data) > 0:
122
+ response_text = response_data[0].get("generated_text", "").strip()
123
+ else:
124
+ response_text = "Unable to generate a meaningful review."
125
+
 
 
 
 
 
 
 
126
  # Clean up the response
127
  response_text = response_text.strip()
128
+
129
+ # Handle different completion patterns
130
+ if response_text.startswith("error handling"):
131
+ review = "Consider adding error handling and input validation."
132
+ elif response_text.startswith("documentation"):
133
+ review = "Consider adding documentation and type hints."
134
+ elif response_text.startswith("input validation"):
135
+ review = "Consider adding input validation and error checks."
136
+ elif response_text.startswith("type hints"):
137
+ review = "Consider adding type hints and documentation."
138
  else:
139
+ # Extract meaningful content
140
  lines = [line.strip() for line in response_text.split('\n') if line.strip()]
141
+ if lines and len(lines[0]) > 3:
142
+ first_line = lines[0]
143
+ # Clean up common artifacts
144
+ if first_line.startswith('#'):
145
+ first_line = first_line[1:].strip()
146
+ if len(first_line) > 10:
147
+ review = f"Consider adding {first_line.lower()}."
148
+ else:
149
+ review = "Consider adding proper documentation and error handling."
150
  else:
151
  review = "Consider adding proper documentation and error handling."
152
+
153
  return review
154
 
155
  def parse_ai_response(response_text: str) -> list[ReviewComment]: