# AICodeReviewer / app.py
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel
from typing import List
import os
from clearml import Model, Task
import torch
from configs import add_args
from models import build_or_load_gen_model
import argparse
from argparse import Namespace
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
# maximum token length for inputs
MAX_SOURCE_LENGTH = 512
# Load ClearML endpoints & credentials from the environment
CLEARML_API_HOST = os.environ["CLEARML_API_HOST"]
CLEARML_WEB_HOST = os.environ["CLEARML_WEB_HOST"]
CLEARML_FILES_HOST = os.environ["CLEARML_FILES_HOST"]
CLEARML_ACCESS_KEY = os.environ["CLEARML_API_ACCESS_KEY"]
CLEARML_SECRET_KEY = os.environ["CLEARML_API_SECRET_KEY"]
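# Note: os.environ[...] raises KeyError at startup if any of these variables are missing,
# so the service fails fast when the ClearML credentials are not configured.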
# Apply to SDK
Task.set_credentials(
    api_host=CLEARML_API_HOST,
    web_host=CLEARML_WEB_HOST,
    files_host=CLEARML_FILES_HOST,
    key=CLEARML_ACCESS_KEY,
    secret=CLEARML_SECRET_KEY,
)
def pad_assert(tokenizer, source_ids):
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids
# Encode code content and comment into model input
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Truncate, wrap with BOS/EOS, and pad to the fixed length (same logic as pad_assert)
    return pad_assert(tokenizer, source_ids)
# Load base model architecture and tokenizer from HuggingFace
BASE_MODEL_NAME = "microsoft/codereviewer"
args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")
# Download the fine-tuned LoRA adapter weights from ClearML using the credentials set above
task = Task.get_task(task_id="2d65a9e213ea49a9b37e1cc89a2b7ff0")
extracted_adapter_dir = task.artifacts["lora-adapter"].get_local_copy() # This is the directory path
actual_weights_file_path = os.path.join(extracted_adapter_dir, "pytorch_model.bin") # Path to the actual model file
print(f"Fine-tuned adapter weights downloaded and extracted to directory: {extracted_adapter_dir}")
print(f"Loading fine-tuned adapter weights from file: {actual_weights_file_path}")
# Create LoRA configuration matching the fine-tuned checkpoint
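# (In the T5 architecture that microsoft/codereviewer is based on, q/k/v/o are the attention
# projections and wi/wo the feed-forward layers; r and alpha are assumed to match the values
# used during fine-tuning.)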
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)
# Load adapter-only weights and merge into base
adapter_state = torch.load(actual_weights_file_path, map_location="cpu")
peft_model.load_state_dict(adapter_state, strict=False)
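# strict=False: the checkpoint contains only the LoRA adapter tensors, so keys missing
# from it keep their base-model values instead of raising an error.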
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")
model.to("cpu")
model.eval()
print("Model ready for inference.")
app = FastAPI()
last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}
class FileContent(BaseModel):
    filename: str
    content: str
class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]
class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]
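# Illustrative JSON body accepted by /pr-comments and /infer (values are examples only):
# {
#   "comment": "Please handle the empty-list case",
#   "files": [{"filename": "utils.py", "content": "def head(xs):\n    return xs[0]\n"}]
# }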
@app.get("/")
def root():
return {"message": "FastAPI PR comment service is running"}
@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Return the received payload as JSON, with a pointer to /show for viewing it
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})
@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
for file in last_payload["files"]:
html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
return html
@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    # tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    # print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1,
    )
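    # The first ids emitted by generate() are special tokens (decoder start / BOS), hence pred[2:] below.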
    pred = preds[0].cpu().numpy()
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Replace <add> markers with newlines and strip whitespace
    pred_nl = "\n".join([seg.strip() for seg in pred_nl.split("<add>") if seg.strip()])
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
return html
if __name__ == "__main__":
    # Place any CLI/training logic here if needed.
    # This block is NOT executed when the app is served with uvicorn.
    pass
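# Typical launch command (assumption: uvicorn is installed and this file is named app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860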