Sgridda committed on
Commit
8e65098
·
1 Parent(s): fe2db02

Fix quantization for CPU by using BitsAndBytesConfig

Browse files
Files changed (1) hide show
  1. main.py +22 -57
main.py CHANGED
@@ -1,6 +1,7 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
  import torch
5
  import re
6
  import json
@@ -9,9 +10,9 @@ import json
9
  # 1. Configuration
10
  # ----------------------------
11
 
12
- # Define the model we want to use.
13
- # We use a 4-bit quantized version ("4bit") for efficiency.
14
  MODEL_NAME = "deepseek-ai/deepseek-coder-6.7b-instruct"
 
 
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  # ----------------------------
@@ -28,8 +29,6 @@ app = FastAPI(
28
  # 3. AI Model Loading
29
  # ----------------------------
30
 
31
- # Use a global variable to hold the model and tokenizer
32
- # This is lazy-loaded on the first request to speed up server startup.
33
  model = None
34
  tokenizer = None
35
 
@@ -39,13 +38,23 @@ def load_model():
39
  if model is None:
40
  print(f"Loading model: {MODEL_NAME} on device: {DEVICE}...")
41
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
42
-
43
- # Load the model with 4-bit quantization to save memory
 
 
 
 
 
 
 
 
 
 
44
  model = AutoModelForCausalLM.from_pretrained(
45
  MODEL_NAME,
46
  trust_remote_code=True,
47
- torch_dtype=torch.bfloat16,
48
- load_in_4bit=True,
49
  )
50
  print("Model loaded successfully.")
51
 
@@ -53,7 +62,6 @@ def load_model():
53
  async def startup_event():
54
  """
55
  On server startup, we trigger the model loading.
56
- This makes the first API call after startup faster.
57
  """
58
  print("Server starting up...")
59
  load_model()
@@ -63,17 +71,14 @@ async def startup_event():
63
  # ----------------------------
64
 
65
  class ReviewRequest(BaseModel):
66
- """The request body for the /review endpoint."""
67
  diff: str
68
 
69
  class ReviewComment(BaseModel):
70
- """A single review comment."""
71
  file_path: str
72
  line_number: int
73
  comment_text: str
74
 
75
  class ReviewResponse(BaseModel):
76
- """The response body for the /review endpoint."""
77
  comments: list[ReviewComment]
78
 
79
  # ----------------------------
@@ -87,37 +92,10 @@ def run_ai_inference(diff: str) -> str:
87
  if not model or not tokenizer:
88
  raise RuntimeError("Model is not loaded.")
89
 
90
- # This is the prompt engineering part. We create a clear instruction
91
- # for the model, telling it exactly what to do and what format to output.
92
  messages = [
93
  {
94
  "role": "system",
95
- "content": """
96
- You are an expert code reviewer. Your task is to analyze a pull request diff and provide constructive feedback.
97
- Analyze the provided diff and identify potential issues, suggest improvements, or point out good practices.
98
- Your feedback should be in the form of review comments.
99
-
100
- IMPORTANT: Respond with a JSON array of comment objects. Each object must have three fields: 'file_path', 'line_number', and 'comment_text'.
101
- The 'file_path' should be the full path of the file being changed.
102
- The 'line_number' must be an integer corresponding to the line number in the *new* version of the file where the comment applies.
103
- The 'comment_text' should be your concise and clear review comment.
104
-
105
- Example response format:
106
- [
107
- {
108
- "file_path": "src/utils/helpers.py",
109
- "line_number": 42,
110
- "comment_text": "This function could be simplified by using a list comprehension."
111
- },
112
- {
113
- "file_path": "README.md",
114
- "line_number": 12,
115
- "comment_text": "There is a typo in this sentence."
116
- }
117
- ]
118
-
119
- Do not add any introductory text or explanations outside of the JSON array.
120
- """
121
  },
122
  {
123
  "role": "user",
@@ -125,23 +103,20 @@ Do not add any introductory text or explanations outside of the JSON array.
125
  }
126
  ]
127
 
128
- inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
129
 
130
- # Generate the response from the model
131
  outputs = model.generate(inputs, max_new_tokens=1024, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
132
 
133
- # Decode the output and clean it up
134
  response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
135
  return response_text.strip()
136
 
137
  def parse_ai_response(response_text: str) -> list[ReviewComment]:
138
  """
139
  Parses the raw text from the AI to extract the JSON array.
140
- This function is robust against the AI adding extra text before or after the JSON.
141
  """
142
  print(f"Raw AI Response:\n---\n{response_text}\n---")
143
 
144
- # Find the start and end of the JSON array
145
  json_match = re.search(r'\[.*\]', response_text, re.DOTALL)
146
  if not json_match:
147
  print("Warning: Could not find a JSON array in the AI response.")
@@ -151,7 +126,6 @@ def parse_ai_response(response_text: str) -> list[ReviewComment]:
151
 
152
  try:
153
  comments_data = json.loads(json_string)
154
- # Validate the structure of the parsed data
155
  validated_comments = [ReviewComment(**item) for item in comments_data]
156
  return validated_comments
157
  except (json.JSONDecodeError, TypeError, KeyError) as e:
@@ -165,20 +139,12 @@ def parse_ai_response(response_text: str) -> list[ReviewComment]:
165
 
166
  @app.post("/review", response_model=ReviewResponse)
167
  async def get_code_review(request: ReviewRequest):
168
- """
169
- Receives a code diff, gets a review from the AI model,
170
- and returns structured review comments.
171
- """
172
  if not request.diff:
173
  raise HTTPException(status_code=400, detail="Diff content cannot be empty.")
174
 
175
  try:
176
- # 1. Run the AI model
177
  ai_response_text = run_ai_inference(request.diff)
178
-
179
- # 2. Parse the AI's response into structured objects
180
  parsed_comments = parse_ai_response(ai_response_text)
181
-
182
  return ReviewResponse(comments=parsed_comments)
183
 
184
  except Exception as e:
@@ -191,5 +157,4 @@ async def get_code_review(request: ReviewRequest):
191
 
192
  @app.get("/health")
193
  async def health_check():
194
- """A simple endpoint to confirm the server is running."""
195
- return {"status": "ok", "model_loaded": model is not None}
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ # We now import BitsAndBytesConfig to specify our quantization settings
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
5
  import torch
6
  import re
7
  import json
 
10
  # 1. Configuration
11
  # ----------------------------
12
 
 
 
13
# Model checkpoint used for code review generation.
MODEL_NAME = "deepseek-ai/deepseek-coder-6.7b-instruct"
# Device placement is handled automatically by device_map="auto" at load time;
# DEVICE is kept only for logging.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # ----------------------------
 
29
  # 3. AI Model Loading
30
  # ----------------------------
31
 
 
 
32
# Lazily-populated globals for the model and tokenizer; they stay None until
# the loader runs (triggered on server startup), keeping import time cheap.
model = None
tokenizer = None
34
 
 
38
  if model is None:
39
  print(f"Loading model: {MODEL_NAME} on device: {DEVICE}...")
40
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
41
+
42
+ # FIX: Define the quantization configuration for 4-bit loading.
43
+ # We explicitly set bnb_4bit_quant_type to "nf4", which is required for CPU execution.
44
+ quantization_config = BitsAndBytesConfig(
45
+ load_in_4bit=True,
46
+ bnb_4bit_quant_type="nf4",
47
+ bnb_4bit_compute_dtype=torch.bfloat16,
48
+ bnb_4bit_use_double_quant=False,
49
+ )
50
+
51
+ # Load the model with the specified quantization config.
52
+ # We also use device_map="auto" to let transformers handle device placement.
53
  model = AutoModelForCausalLM.from_pretrained(
54
  MODEL_NAME,
55
  trust_remote_code=True,
56
+ quantization_config=quantization_config,
57
+ device_map="auto", # This is crucial for bitsandbytes to work correctly
58
  )
59
  print("Model loaded successfully.")
60
 
 
62
async def startup_event():
    """On server startup, trigger the model loading.

    Loading eagerly here (rather than on the first request) means the first
    API call does not pay the multi-second model-load cost.
    """
    print("Server starting up...")
    load_model()
 
71
  # ----------------------------
72
 
73
class ReviewRequest(BaseModel):
    """Request body for POST /review: the raw unified diff to analyze."""

    # Unified diff text of the pull request under review.
    diff: str
75
 
76
class ReviewComment(BaseModel):
    """A single structured review comment produced by the model."""

    # Full path of the file the comment applies to.
    file_path: str
    # 1-based line number in the *new* version of the file.
    line_number: int
    # The reviewer-style comment text.
    comment_text: str
80
 
81
class ReviewResponse(BaseModel):
    """Response body for POST /review: all generated review comments."""

    comments: list[ReviewComment]
83
 
84
  # ----------------------------
 
92
  if not model or not tokenizer:
93
  raise RuntimeError("Model is not loaded.")
94
 
 
 
95
  messages = [
96
  {
97
  "role": "system",
98
+ "content": """You are an expert code reviewer. Your task is to analyze a pull request diff and provide constructive feedback.\nAnalyze the provided diff and identify potential issues, suggest improvements, or point out good practices.\n\nIMPORTANT: Respond with a JSON array of comment objects. Each object must have three fields: 'file_path', 'line_number', and 'comment_text'.\nThe 'file_path' should be the full path of the file being changed.\nThe 'line_number' must be an integer corresponding to the line number in the *new* version of the file where the comment applies.\nThe 'comment_text' should be your concise and clear review comment.\n\nExample response format:\n[\n {\n "file_path": "src/utils/helpers.py",\n "line_number": 42,\n "comment_text": "This function could be simplified by using a list comprehension."\n }\n]\n\nDo not add any introductory text or explanations outside of the JSON array.\n"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  },
100
  {
101
  "role": "user",
 
103
  }
104
  ]
105
 
106
+ # Note: We don't need to manually move inputs to a device when using device_map="auto"
107
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
108
 
 
109
  outputs = model.generate(inputs, max_new_tokens=1024, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
110
 
 
111
  response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
112
  return response_text.strip()
113
 
114
  def parse_ai_response(response_text: str) -> list[ReviewComment]:
115
  """
116
  Parses the raw text from the AI to extract the JSON array.
 
117
  """
118
  print(f"Raw AI Response:\n---\n{response_text}\n---")
119
 
 
120
  json_match = re.search(r'\[.*\]', response_text, re.DOTALL)
121
  if not json_match:
122
  print("Warning: Could not find a JSON array in the AI response.")
 
126
 
127
  try:
128
  comments_data = json.loads(json_string)
 
129
  validated_comments = [ReviewComment(**item) for item in comments_data]
130
  return validated_comments
131
  except (json.JSONDecodeError, TypeError, KeyError) as e:
 
139
 
140
  @app.post("/review", response_model=ReviewResponse)
141
  async def get_code_review(request: ReviewRequest):
 
 
 
 
142
  if not request.diff:
143
  raise HTTPException(status_code=400, detail="Diff content cannot be empty.")
144
 
145
  try:
 
146
  ai_response_text = run_ai_inference(request.diff)
 
 
147
  parsed_comments = parse_ai_response(ai_response_text)
 
148
  return ReviewResponse(comments=parsed_comments)
149
 
150
  except Exception as e:
 
157
 
158
@app.get("/health")
async def health_check():
    """Liveness probe: reports server status and whether the model is loaded."""
    return {"status": "ok", "model_loaded": model is not None}