Spaces:
Sleeping
Sleeping
import json
import re
import time

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# ----------------------------
# 1. Configuration
# ----------------------------
# Hugging Face checkpoint used to generate the review comments.
MODEL_NAME = "Salesforce/codegen-350M-mono"
# Prefer a GPU when one is present; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------
# 2. FastAPI App Initialization
# ----------------------------
# Single application instance; routes and startup hooks attach to it below.
app = FastAPI(
    title="AI Code Review Service",
    description="An API to get AI-powered code reviews for pull request diffs.",
    version="1.0.0",
)
# ----------------------------
# 3. AI Model Loading
# ----------------------------
# Both globals stay None until load_model() populates them.
model = None
tokenizer = None
def load_model():
    """Populate the module-level ``model`` and ``tokenizer`` globals.

    Idempotent: a second call is a no-op once the model is in memory.
    NOTE(review): the model is pinned to CPU in float32 even though the
    DEVICE constant may say "cuda" — DEVICE is currently unused here.
    Confirm whether GPU placement was intended before changing it.
    """
    global model, tokenizer
    if model is not None:
        return
    print(f"Loading model: {MODEL_NAME} on device: {DEVICE}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        device_map="cpu",
    )
    print("Model loaded successfully.")
@app.on_event("startup")
async def startup_event():
    """Load the AI model when the server starts.

    BUG FIX: this coroutine was defined but never registered with
    FastAPI, so load_model() was never invoked at startup and the model
    stayed unloaded. The @app.on_event("startup") decorator wires it in.
    """
    print("Server starting up...")
    load_model()
# ----------------------------
# 4. API Request/Response Models
# ----------------------------
class ReviewRequest(BaseModel):
    """Request payload: the raw pull-request diff to review."""

    diff: str
class ReviewComment(BaseModel):
    """One review remark, anchored to a file path and line number."""

    file_path: str
    line_number: int
    comment_text: str
class ReviewResponse(BaseModel):
    """Response payload: every comment produced for the submitted diff."""

    comments: list[ReviewComment]
# ----------------------------
# 5. The AI Review Logic
# ----------------------------
def run_ai_inference(diff: str) -> str:
    """Generate a one-line natural-language review comment for *diff*.

    Only the first 800 characters of the diff are fed to the model.
    Returns the first prose-looking line of the generation, or a generic
    fallback suggestion when the model produced nothing usable.

    Raises:
        RuntimeError: if load_model() has not been run yet.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model is not loaded.")
    prompt = (
        "Below is a Python function. Please provide a code review comment with suggestions for improvement, in natural language. "
        "Do not repeat the code.\n"
        f"{diff[:800]}\n"
        "Review comment:"
    )
    # BUG FIX: padding="max_length" (the original) padded this single
    # sequence out to 1024 tokens — pure wasted compute — and raises a
    # ValueError for CodeGen's tokenizer, which defines no pad token.
    # Truncation alone is sufficient for a batch of one.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    # Keep inputs on the same device as the model (no-op on CPU, but
    # robust if the model is ever placed on a GPU).
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    # CodeGen defines an EOS token; fall back to the pad token only if
    # EOS were ever missing.
    end_token_id = (
        tokenizer.eos_token_id
        if tokenizer.eos_token_id is not None
        else tokenizer.pad_token_id
    )
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            num_return_sequences=1,
            pad_token_id=end_token_id,
            eos_token_id=end_token_id,
            use_cache=True,
        )
    # Slice off the echoed prompt: keep only the newly generated tokens.
    response_text = tokenizer.decode(
        outputs[0][input_ids.shape[1]:], skip_special_tokens=True
    )
    review_lines = [line.strip() for line in response_text.strip().split('\n') if line.strip()]
    # Drop lines that look like code rather than prose commentary.
    comment_lines = [
        line
        for line in review_lines
        if not line.startswith("def ")
        and not line.startswith("class ")
        and not line.endswith(":")
        and not line.startswith("#")
    ]
    # A 350M model often emits no usable prose; fall back to generic advice.
    return comment_lines[0] if comment_lines else "Consider adding a docstring and input validation."
def parse_ai_response(response_text: str) -> list[ReviewComment]:
    """Wrap the model's free-text review in a single ReviewComment.

    The small CodeGen model cannot reliably emit structured output, so
    the whole stripped response is attached as one comment pinned to
    line 1 of a placeholder file path. (DOC FIX: the original docstring
    claimed this parsed a JSON array, which the code never did.)
    """
    return [
        ReviewComment(
            file_path="code_reviewed.py",
            line_number=1,
            comment_text=response_text.strip(),
        )
    ]
# ----------------------------
# 6. The API Endpoint
# ----------------------------
@app.post("/review", response_model=ReviewResponse)
async def get_code_review(request: ReviewRequest):
    """Return AI-generated review comments for the submitted diff.

    BUG FIX: this handler was never registered as a route, so the
    service exposed no review endpoint; @app.post wires it to /review.
    The function-local ``import time`` was hoisted to the file's import
    block.

    Raises:
        HTTPException: 400 for an empty diff, 500 for internal failures.
    """
    if not request.diff:
        raise HTTPException(status_code=400, detail="Diff content cannot be empty.")
    start_time = time.time()
    print(f"Starting review request at {start_time}")
    try:
        print("Running AI inference...")
        ai_response_text = run_ai_inference(request.diff)
        print(f"AI inference completed in {time.time() - start_time:.2f} seconds")
        print("Parsing AI response...")
        parsed_comments = parse_ai_response(ai_response_text)
        print(f"Total processing time: {time.time() - start_time:.2f} seconds")
        return ReviewResponse(comments=parsed_comments)
    except Exception as e:
        # Log and convert to an opaque 500 so internals are not leaked.
        print(f"An unexpected error occurred after {time.time() - start_time:.2f} seconds: {e}")
        raise HTTPException(status_code=500, detail="An internal error occurred while processing the review.")
# ----------------------------
# 7. Health Check Endpoint
# ----------------------------
@app.get("/health")
async def health_check():
    """Liveness probe: report service status and whether the model is loaded.

    BUG FIX: the handler was never registered as a route; @app.get
    exposes it at /health.
    """
    return {"status": "ok", "model_loaded": model is not None}