from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

from .utils import format_prompt_for_hint, format_prompt_for_followup, format_prompt_for_feedback

# Base model and the directory holding the fine-tuned LoRA adapter.
MODEL_NAME = "Salesforce/codegen-350M-mono"
LORA_PATH = "./fine-tuned-model"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # CodeGen's GPT-2-style tokenizer ships without a pad token; reuse EOS.
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)

# Attach the LoRA adapter, then merge its weights into the base model so
# inference runs on a single plain transformers model. The adapter inherits
# the base model's dtype and device placement.
model = PeftModel.from_pretrained(base_model, LORA_PATH)
model = model.merge_and_unload()
model.eval()


def _first_clean_line(decoded_output):
    """Return the first non-empty line that is not echoed prompt scaffolding."""
    for line in decoded_output.splitlines():
        line = line.strip()
        if not line:
            continue
        if "Task Description" in line or "User's Code" in line or "AI-HR Assistant" in line:
            continue
        return line
    return ""


def generate_hint(code_snippet, task_description, mode="concise"):
    prompt = format_prompt_for_hint(task_description, code_snippet, mode)
    print("--- PROMPT SENT TO MODEL (HINT) ---")
    print(prompt)
    print("-----------------------------------")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT (HINT) ---")
    print(decoded_output)
    print("-------------------------------")

    # Return the first non-empty line after the "HINT:" marker, if present;
    # otherwise fall back to the first clean generated line.
    if "HINT:" in decoded_output:
        hint = decoded_output.split("HINT:", 1)[-1].strip()
        for line in hint.splitlines():
            if line.strip():
                return line.strip()

    return _first_clean_line(decoded_output)


def generate_feedback(code_snippet, task_description, mode="concise"):
    prompt = format_prompt_for_feedback(task_description, code_snippet, mode)
    print("--- PROMPT SENT TO MODEL (FEEDBACK) ---")
    print(prompt)
    print("---------------------------------------")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256 if mode == "detailed" else 128,
        do_sample=True,
        temperature=0.75,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT (FEEDBACK) ---")
    print(decoded_output)
    print("-----------------------------------")

    # Return everything after the "FEEDBACK:" marker, if present; otherwise
    # fall back to the first clean generated line.
    if "FEEDBACK:" in decoded_output:
        return decoded_output.split("FEEDBACK:", 1)[-1].strip()

    return _first_clean_line(decoded_output)


def generate_follow_up(task_description, code_snippet):
    prompt = format_prompt_for_followup(task_description, code_snippet)
    print("--- PROMPT SENT TO MODEL (FOLLOW-UP) ---")
    print(prompt)
    print("----------------------------------------")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT (FOLLOW-UP) ---")
    print(decoded_output)
    print("------------------------------------")

    # Strip the echoed prompt and return the first non-empty generated line;
    # otherwise fall back to the first clean line of the full output.
    if prompt in decoded_output:
        followup = decoded_output.split(prompt, 1)[-1].strip()
        for line in followup.splitlines():
            if line.strip():
                return line.strip()

    return _first_clean_line(decoded_output)