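"""Model loading and text-generation helpers for the AI-HR assistant.

Loads the Salesforce/codegen-350M-mono base model, applies a fine-tuned LoRA
adapter, and exposes helpers that generate hints, feedback, and follow-up
questions for a candidate's code submission.
"""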
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

from .utils import format_prompt_for_hint, format_prompt_for_followup, format_prompt_for_feedback

MODEL_NAME = "Salesforce/codegen-350M-mono"
# Relative path to the LoRA adapter (resolved against the working directory, assumed to be the project root)
LORA_PATH = "./fine-tuned-model"

# Initialize the tokenizer; CodeGen ships without a pad token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)

# Load and apply the LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    LORA_PATH,
    torch_dtype="auto",
    device_map="auto"
)

# Merge LoRA weights with base model for better inference performance
model = model.merge_and_unload()
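
# Shared fallback used by all three generators below: return the first
# non-empty line of the raw output that does not echo the prompt scaffolding.
def _fallback_first_line(decoded_output):
    lines = [line.strip() for line in decoded_output.splitlines() if line.strip()]
    for line in lines:
        if "Task Description" not in line and "User's Code" not in line and "AI-HR Assistant" not in line:
            return line
    return ""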

def generate_hint(code_snippet, task_description, mode='concise'):
    prompt = format_prompt_for_hint(task_description, code_snippet, mode)
    print("--- PROMPT SENT TO MODEL ---")
    print(prompt)
    print("-----------------------------")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id)
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT ---")
    print(decoded_output)
    print("------------------------")
    # Extract only what comes after 'HINT:'
    if "HINT:" in decoded_output:
        hint = decoded_output.split("HINT:", 1)[-1].strip()
        for line in hint.splitlines():
            if line.strip():
                return line.strip()
    # Fallback: return the first non-empty line that does not echo the prompt
    return _fallback_first_line(decoded_output)

def generate_feedback(code_snippet, task_description, mode='concise'):
    prompt = format_prompt_for_feedback(task_description, code_snippet, mode)
    print("--- PROMPT SENT TO MODEL (FEEDBACK) ---")
    print(prompt)
    print("---------------------------------------")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256 if mode == 'detailed' else 128, do_sample=True, temperature=0.75, pad_token_id=tokenizer.eos_token_id)
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT (FEEDBACK) ---")
    print(decoded_output)
    print("-----------------------------------")
    # Extract only what comes after 'FEEDBACK:'
    if "FEEDBACK:" in decoded_output:
        feedback = decoded_output.split("FEEDBACK:", 1)[-1].strip()
        return feedback
    # Fallback: return the first non-empty line that does not echo the prompt
    return _fallback_first_line(decoded_output)

def generate_follow_up(task_description, code_snippet):
    prompt = format_prompt_for_followup(task_description, code_snippet)
    print("--- PROMPT SENT TO MODEL ---")
    print(prompt)
    print("-----------------------------")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id)
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("--- RAW MODEL OUTPUT ---")
    print(decoded_output)
    print("------------------------")
    # Extract only what comes after the prompt
    if prompt in decoded_output:
        followup = decoded_output.split(prompt, 1)[-1].strip()
        for line in followup.splitlines():
            if line.strip():
                return line.strip()
    # Fallback: return the first non-empty line that does not echo the prompt
    return _fallback_first_line(decoded_output)
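
# --- Usage sketch (illustrative only) ---
# A minimal smoke test, assuming this module is imported from its package (the
# relative `.utils` import means it cannot run as a stand-alone script). The
# package path, task description, and code snippet below are hypothetical.
#
#   from backend.model import generate_hint, generate_feedback, generate_follow_up
#
#   task = "Write a function that returns the n-th Fibonacci number."
#   code = "def fib(n):\n    return fib(n - 1) + fib(n - 2)"
#
#   print(generate_hint(code, task, mode='concise'))
#   print(generate_feedback(code, task, mode='detailed'))
#   print(generate_follow_up(task, code))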