CodeLlama-70B_for_NTR
We fine-tuned CodeLlama-70B on the Transfer_dataset under the NTR framework for automated program repair (APR) research.
Model Use
To use this model, please make sure to install transformers, peft, bitsandbytes, and accelerate.
# Install all four required packages in a single pip invocation.
pip install transformers peft bitsandbytes accelerate
Then, run the following script to merge the fine-tuned adapter into the CodeLlama base model:
# Merge the adapter into the base model (merge.sh ships with this repo; presumably it writes
# the "-merged" checkpoint loaded below — TODO confirm the output path).
bash merge.sh
Finally, load the merged model and use it to generate candidate patches for buggy code:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
# Load the tokenizer and the merged model. merge.sh has already folded the
# fine-tuned adapter into the base weights, so no LoRA wrapper or k-bit
# *training* preparation is needed for inference.
MODEL_PATH = "CodeLlama-70B_for_NTR/Epoch_1/-merged"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# 4-bit NF4 quantization with double quantization so the 70B model fits in
# GPU memory; computation runs in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=nf4_config,
    device_map="auto",
)

# An example bug-fix pair. The buggy function marks the faulty statement(s)
# with "// bug_start" / "// bug_end". fixed_code is the reference fix, shown
# for comparison only; it is not fed to the model.
buggy_code = """
public MultiplePiePlot(CategoryDataset dataset){
super();
// bug_start
this.dataset=dataset;
// bug_end
PiePlot piePlot=new PiePlot(null);
this.pieChart=new JFreeChart(piePlot);
this.pieChart.removeLegend();
this.dataExtractOrder=TableOrder.BY_COLUMN;
this.pieChart.setBackgroundPaint(null);
TextTitle seriesTitle=new TextTitle("Series Title",new Font("SansSerif",Font.BOLD,12));
seriesTitle.setPosition(RectangleEdge.BOTTOM);
this.pieChart.setTitle(seriesTitle);
this.aggregatedItemsKey="Other";
this.aggregatedItemsPaint=Color.lightGray;
this.sectionPaints=new HashMap();
}
"""
repair_template = "OtherTemplate"
fixed_code = """
// fix_start
setDataset(dataset);
// fix_end
"""

# Model inference: build the instruction prompt
# [INST] [bug_function] ... [fix_template] ... [fix_code] [/INST]
B_INST, E_INST = "[INST]", "[/INST]"
# NOTE(review): the CodeLlama tokenizer presumably prepends its own BOS by
# default (add_special_tokens=True), so this explicit bos_token may yield two
# BOS tokens. Kept as-is in case fine-tuning prompts were built the same way;
# confirm against the training code.
input_text = (
    tokenizer.bos_token + B_INST
    + "\n[bug_function]\n" + buggy_code
    + "\n[fix_template]\n" + repair_template
    + "\n[fix_code]\n" + E_INST
)
# Place inputs on the model's first device instead of hard-coding GPU 0.
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
eos_id = tokenizer.eos_token_id

# Beam search returning all 10 beams as candidate patches.
generated_ids = model.generate(
    **inputs,  # input_ids and attention_mask
    max_new_tokens=256,
    num_beams=10,
    num_return_sequences=10,
    early_stopping=True,
    pad_token_id=eos_id,
    eos_token_id=eos_id,
)
for generated_id in generated_ids:
    generated_text = tokenizer.decode(generated_id, skip_special_tokens=False)
    # The patch is everything after the first [/INST]. partition() keeps the
    # whole remainder even if the patch itself contains "[/INST]".
    patch = generated_text.partition(E_INST)[2]
    # Shorter beams are padded with EOS (pad_token_id=eos_id), so strip every EOS.
    patch = patch.replace(tokenizer.eos_token, "")
    print(patch)
Model Details
*Note: Use of this model is governed by the Meta license. Meta developed and publicly released the Code Llama family of large language models (LLMs).*