LLM4APR's picture
Update README.md
57b0f96 verified
|
raw
history blame
2.97 kB

CodeLlama-70B_for_NMT

We fine-tuned CodeLlama-70B on Transfer_dataset under the NMT (neural machine translation) workflow for automated program repair (APR) research.

Model Use

To use this model, please make sure to install transformers, peft, bitsandbytes, and accelerate.

pip install transformers
pip install peft
pip install bitsandbytes
pip install accelerate

Then, please run the following script to merge the adapter into the CodeLlama base model.

bash merge.sh

Finally, you can load the model to generate patches for buggy code.

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch


# Load the tokenizer and the merged checkpoint (quantized to 4-bit NF4).

MERGED_MODEL_PATH = "CodeLlama-70B_for_NMT/Epoch_1/-merged"

tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH, use_auth_token=True)

# 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    MERGED_MODEL_PATH,
    device_map='auto',
    quantization_config=nf4_config,
)

# LoRA settings applied to the attention projection layers.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Prepare the quantized model for k-bit use, then wrap it as a PEFT model.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)


# A bug-fix pair: the buggy Java function (the faulty line is delimited by the
# bug_start / bug_end markers) and the corresponding fix (delimited by the
# fix_start / fix_end markers).
# Both snippets span several lines, so they must be triple-quoted literals;
# a single-quoted string cannot contain raw newlines in Python.

buggy_code = """
/*
 * Evaluate whether the given number n can be written as the sum of exactly 4 positive even numbers
    Example
    is_equal_to_sum_even(4) == False
    is_equal_to_sum_even(6) == False
    is_equal_to_sum_even(8) == True
 */
public class IS_EQUAL_TO_SUM_EVEN {
    public static boolean is_equal_to_sum_even(int n) {
// bug_start
        return ((n * 2 == 1) ^ (n < 8));
// bug_end
    }
}
"""

fixed_code = """
// fix_start
        return ((n % 2 == 0) && (n >= 8));
// fix_end
"""

# model inference

# Prompt layout: BOS token, [INST], the tagged buggy function, the fix tag,
# then [/INST]. The generated patch is whatever follows [/INST].
B_INST, E_INST = "[INST]", "[/INST]"
input_text =  tokenizer.bos_token + B_INST +'\n[bug_function]\n' + buggy_code + '\n[fix_code]\n' + E_INST
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)

# The EOS token id is used both as the stop token and as the padding token.
eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
generated_ids = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    num_beams=10,
    num_return_sequences=10,
    early_stopping=True,
    pad_token_id=eos_id,
    eos_token_id=eos_id
)

# Print each of the returned candidate patches.
for generated_id in generated_ids:
    # Special tokens are kept so the [/INST] marker can be used as a delimiter.
    generated_text = tokenizer.decode(generated_id, skip_special_tokens=False)
    # Keep only the text generated after the prompt's [/INST] marker.
    patch = generated_text.split(E_INST)[1]
    # Strip EOS tokens from the patch itself (the original referenced an
    # undefined name `text` here, which raised NameError).
    patch = patch.replace(tokenizer.eos_token,'')
    print(patch)

Model Details

*Note:* Use of this model is governed by the Meta license. Meta developed and publicly released the Code Llama family of large language models (LLMs).