"""Inference helpers for the AI interview assistant.

Loads a LoRA-fine-tuned CodeGen model once at import time and exposes
hint, feedback, and follow-up question generation functions.
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .utils import format_prompt_for_hint, format_prompt_for_followup, format_prompt_for_feedback
from peft import PeftModel, PeftConfig
# Hugging Face hub id of the base causal-LM checkpoint.
MODEL_NAME = "Salesforce/codegen-350M-mono"
# Use a relative path for the LoRA model, assuming it's in the project root
LORA_PATH = "./fine-tuned-model"

# Initialize the tokenizer (shared by every generate_* function below).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load the base model.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",   # take the dtype from the checkpoint config
    device_map="auto",    # let accelerate place weights on GPU/CPU as available
    trust_remote_code=True
)

# Load and apply the LoRA adapter on top of the base weights.
model = PeftModel.from_pretrained(
    base_model,
    LORA_PATH,
    torch_dtype="auto",
    device_map="auto"
)

# Merge LoRA weights with base model for better inference performance:
# after merging there is no adapter indirection on the forward pass.
model = model.merge_and_unload()
def generate_hint(code_snippet, task_description, mode='concise'):
    """Generate a one-line hint for the user's code.

    Args:
        code_snippet: The user's current code as a string.
        task_description: Text describing the task the user is solving.
        mode: Prompt style forwarded to the prompt builder ('concise' default).

    Returns:
        The first non-empty line of the model's hint, or "" if the output
        contained nothing usable.
    """
    prompt = format_prompt_for_hint(task_description, code_snippet, mode)
    print("--- PROMPT SENT TO MODEL ---")
    print(prompt)
    print("-----------------------------")
    # truncation=True keeps oversized snippets from exceeding the model's
    # context window, which would make generate() raise.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # CodeGen defines no pad token; passing eos explicitly avoids the
    # "Setting pad_token_id to eos_token_id" warning on every call.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT ---")
    print(decoded_output)
    print("------------------------")
    # Preferred path: first non-empty line after the 'HINT:' marker that the
    # prompt template asks the model to emit.
    if "HINT:" in decoded_output:
        hint = decoded_output.split("HINT:", 1)[-1].strip()
        for line in hint.splitlines():
            if line.strip():
                return line.strip()
    # Fallback: first non-empty line that does not look like echoed prompt.
    for line in (l.strip() for l in decoded_output.splitlines()):
        if line and "Task Description" not in line and "User's Code" not in line and "AI-HR Assistant" not in line:
            return line
    return ""
def generate_feedback(code_snippet, task_description, mode='concise'):
    """Generate feedback text for the user's code.

    Args:
        code_snippet: The user's current code as a string.
        task_description: Text describing the task the user is solving.
        mode: 'concise' (default) or 'detailed'; 'detailed' doubles the
            generation budget to 256 new tokens.

    Returns:
        Everything the model produced after the 'FEEDBACK:' marker (may be
        multi-line), or a single fallback line, or "" if nothing usable.
    """
    prompt = format_prompt_for_feedback(task_description, code_snippet, mode)
    print("--- PROMPT SENT TO MODEL (FEEDBACK) ---")
    print(prompt)
    print("---------------------------------------")
    # truncation=True keeps oversized snippets from exceeding the model's
    # context window, which would make generate() raise.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # CodeGen defines no pad token; passing eos explicitly avoids the
    # "Setting pad_token_id to eos_token_id" warning on every call.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256 if mode == 'detailed' else 128,
        do_sample=True,
        temperature=0.75,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT (FEEDBACK) ---")
    print(decoded_output)
    print("-----------------------------------")
    # Preferred path: everything after the 'FEEDBACK:' marker that the prompt
    # template asks the model to emit.
    if "FEEDBACK:" in decoded_output:
        return decoded_output.split("FEEDBACK:", 1)[-1].strip()
    # Fallback: first non-empty line that does not look like echoed prompt.
    for line in (l.strip() for l in decoded_output.splitlines()):
        if line and "Task Description" not in line and "User's Code" not in line and "AI-HR Assistant" not in line:
            return line
    return ""
def generate_follow_up(task_description, code_snippet):
    """Generate a follow-up question about the user's code.

    Args:
        task_description: Text describing the task the user is solving.
        code_snippet: The user's current code as a string.

    Returns:
        The first non-empty line the model generated after the prompt,
        or "" if nothing usable was produced.
    """
    prompt = format_prompt_for_followup(task_description, code_snippet)
    print("--- PROMPT SENT TO MODEL ---")
    print(prompt)
    print("-----------------------------")
    # truncation=True keeps oversized snippets from exceeding the model's
    # context window, which would make generate() raise.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # CodeGen defines no pad token; passing eos explicitly avoids the
    # "Setting pad_token_id to eos_token_id" warning on every call.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT ---")
    print(decoded_output)
    print("------------------------")
    # Slice off the prompt by token count rather than substring matching:
    # the tokenize/decode round trip does not always reproduce the prompt
    # byte-for-byte, so `prompt in decoded_output` can silently fail and
    # drop us into the weaker fallback below.
    prompt_token_count = inputs["input_ids"].shape[1]
    continuation = tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True)
    for line in continuation.splitlines():
        if line.strip():
            return line.strip()
    # Fallback: first non-empty line that does not look like echoed prompt.
    for line in (l.strip() for l in decoded_output.splitlines()):
        if line and "Task Description" not in line and "User's Code" not in line and "AI-HR Assistant" not in line:
            return line
    return ""