from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel
from typing import List
import os
from clearml import Model, Task
import torch
from configs import add_args
from models import build_or_load_gen_model
import argparse
from argparse import Namespace
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
# Maximum token length for model inputs
MAX_SOURCE_LENGTH = 512

# Load ClearML endpoints & credentials from environment variables
CLEARML_API_HOST = os.environ["CLEARML_API_HOST"]
CLEARML_WEB_HOST = os.environ["CLEARML_WEB_HOST"]
CLEARML_FILES_HOST = os.environ["CLEARML_FILES_HOST"]
CLEARML_ACCESS_KEY = os.environ["CLEARML_API_ACCESS_KEY"]
CLEARML_SECRET_KEY = os.environ["CLEARML_API_SECRET_KEY"]

# Apply the credentials to the ClearML SDK
Task.set_credentials(
    api_host=CLEARML_API_HOST,
    web_host=CLEARML_WEB_HOST,
    files_host=CLEARML_FILES_HOST,
    key=CLEARML_ACCESS_KEY,
    secret=CLEARML_SECRET_KEY,
)
# Truncate to the fixed length, wrap with BOS/EOS, and pad with the pad token
def pad_assert(tokenizer, source_ids):
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids
# Encode code content and comment into model input ids
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Pad/truncate to the fixed length (same logic as pad_assert above)
    source_ids = pad_assert(tokenizer, source_ids)
    return source_ids
# Load base model architecture and tokenizer from Hugging Face
BASE_MODEL_NAME = "microsoft/codereviewer"

args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
)

print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")
# Download the fine-tuned adapter weights via ClearML using the injected credentials
task = Task.get_task(task_id="2d65a9e213ea49a9b37e1cc89a2b7ff0")
extracted_adapter_dir = task.artifacts["lora-adapter"].get_local_copy()  # directory containing the adapter
actual_weights_file_path = os.path.join(extracted_adapter_dir, "pytorch_model.bin")  # path to the weights file
print(f"Fine-tuned adapter weights downloaded and extracted to directory: {extracted_adapter_dir}")
print(f"Loading fine-tuned adapter weights from file: {actual_weights_file_path}")
# Create LoRA configuration matching the fine-tuned checkpoint
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)

# Load adapter-only weights and merge them into the base model
adapter_state = torch.load(actual_weights_file_path, map_location="cpu")
peft_model.load_state_dict(adapter_state, strict=False)
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")

model.to("cpu")
model.eval()
print("Model ready for inference.")
app = FastAPI()

# In-memory stores for the most recent payload and inference result
last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}

class FileContent(BaseModel):
    filename: str
    content: str

class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]

class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]
@app.get("/")
def root():
    return {"message": "FastAPI PR comment service is running"}
@app.post("/pr-comment")  # NOTE: route path is an assumption; the original registration was not shown
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Return the received payload as JSON and point the client at /show
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})
@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
    for file in last_payload["files"]:
        html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
    return html
@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    # tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    # print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1,
    )
    pred = preds[0].cpu().numpy()
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Replace <add> markers with newlines and strip whitespace
    pred_nl = "\n".join([seg.strip() for seg in pred_nl.split("<add>") if seg.strip()])
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
@app.get("/show-infer", response_class=HTMLResponse)  # NOTE: route path is an assumption
def show_infer_result():
    html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
    return html
if __name__ == "__main__":
    # Place any CLI/training logic here if needed.
    # This block is NOT executed when running with uvicorn.
    pass
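# --- Usage sketch (illustrative) ---
# Assuming this module is saved as app.py, the service can be started with
# uvicorn and exercised against the /infer endpoint defined above. The host,
# port, and example payload below are placeholders.
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/infer",
#       json={
#           "comment": "Please handle the empty input case",
#           "files": [{"filename": "utils.py", "content": "def head(xs):\n    return xs[0]\n"}],
#       },
#   )
#   print(resp.json()["generated_code"])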