---
license: llama2
pipeline_tag: text-generation
tags:
- code
- automated program repair
---
# CodeLlama-70B_for_NTR
We fine-tuned [CodeLlama-70B](https://huggingface.co/codellama/CodeLlama-70b-hf) on the [Transfer_dataset](https://drive.google.com/drive/folders/1F1BPfTxHDGX-OCBthudCbu_6Qvcg_fbP?usp=drive_link) under the [NTR](https://sites.google.com/view/neuraltemplaterepair) (Neural Template Repair) framework for automated program repair (APR) research.
## Model Use
To use this model, please make sure to install transformers, peft, bitsandbytes, and accelerate.
```bash
pip install transformers
pip install peft
pip install bitsandbytes
pip install accelerate
```
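Before running the scripts below, it can help to confirm that all four packages are importable in the active environment. The snippet is a small sketch using only the standard library; `missing_packages` and `REQUIRED` are illustrative names, not part of this repository.

```python
import importlib.util

# Packages listed above; for these four the import name matches the pip name.
REQUIRED = ["transformers", "peft", "bitsandbytes", "accelerate"]


def missing_packages(names):
    """Return the names from `names` that cannot be found as importable top-level modules."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```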
Then run the following script to merge the fine-tuned adapter into the CodeLlama-70B base model.
```bash
bash merge.sh
```
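The contents of `merge.sh` are not shown in this card. As a rough sketch only, assuming the adapter was trained with PEFT and saved with `save_pretrained`, merging usually means loading the unquantized base model, attaching the adapter, and folding its weights in with `merge_and_unload()`. The function name `merge_adapter` and the adapter directory you pass to it are hypothetical.

```python
def merge_adapter(base_model_id, adapter_dir, output_dir):
    """Merge a PEFT adapter into its base model and save a standalone checkpoint.

    Hypothetical helper: a sketch of a typical PEFT merge, not the contents of merge.sh.
    """
    # Heavy dependencies are imported here so this file can be imported without them.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the base weights in bfloat16 (not 4-bit): merging needs unquantized weights.
    # Without device_map the model loads on CPU, which for 70B parameters at
    # 2 bytes each means roughly 140 GB of RAM.
    base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
    # Attach the fine-tuned adapter, then fold its weights into the base layers.
    merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
    merged.save_pretrained(output_dir)
    # Save the tokenizer alongside so the output directory can be loaded on its own.
    AutoTokenizer.from_pretrained(base_model_id).save_pretrained(output_dir)
    return output_dir
```

For example, `merge_adapter("codellama/CodeLlama-70b-hf", "<adapter_dir>", "CodeLlama-70B_for_NTR/Epoch_1/-merged")` would write the merged weights to the path loaded in the next step, where `<adapter_dir>` is a placeholder for wherever the adapter was saved.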
Finally, load the merged model to generate patches for buggy code.
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# directory of the merged model written by merge.sh
model_path = "CodeLlama-70B_for_NTR/Epoch_1/-merged"

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 4-bit NF4 quantization with double quantization; matmuls are computed in bfloat16
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=nf4_config,
    device_map="auto",  # spread layers across the available devices
)
model.eval()

# an example bug-fix pair: the buggy statement sits between // bug_start and // bug_end
buggy_code = '''
public MultiplePiePlot(CategoryDataset dataset){
    super();
    // bug_start
    this.dataset=dataset;
    // bug_end
    PiePlot piePlot=new PiePlot(null);
    this.pieChart=new JFreeChart(piePlot);
    this.pieChart.removeLegend();
    this.dataExtractOrder=TableOrder.BY_COLUMN;
    this.pieChart.setBackgroundPaint(null);
    TextTitle seriesTitle=new TextTitle("Series Title",new Font("SansSerif",Font.BOLD,12));
    seriesTitle.setPosition(RectangleEdge.BOTTOM);
    this.pieChart.setTitle(seriesTitle);
    this.aggregatedItemsKey="Other";
    this.aggregatedItemsPaint=Color.lightGray;
    this.sectionPaints=new HashMap();
}
'''
repair_template = "OtherTemplate"
# the reference fix, shown for comparison only; it is not fed to the model
fixed_code = '''
// fix_start
setDataset(dataset);
// fix_end
'''

# model inference
B_INST, E_INST = "[INST]", "[/INST]"
input_text = tokenizer.bos_token + B_INST + '\n[bug_function]\n' + buggy_code + '\n[fix_template]\n' + repair_template + '\n[fix_code]\n' + E_INST
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
# beam search with 10 beams; all 10 beams are returned as candidate patches
generated_ids = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    num_beams=10,
    num_return_sequences=10,
    early_stopping=True,
    pad_token_id=eos_id,
    eos_token_id=eos_id,
)
for generated_id in generated_ids:
    generated_text = tokenizer.decode(generated_id, skip_special_tokens=False)
    # keep only the text generated after [/INST] and drop the EOS token
    patch = generated_text.split(E_INST)[1]
    patch = patch.replace(tokenizer.eos_token, '')
    print(patch)
```
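The prompt format and post-processing in the script above can be factored into plain string helpers, which makes them easy to check without loading the model. This is a sketch: `build_prompt`, `extract_patch`, and `apply_patch` are hypothetical names, and `apply_patch` assumes that generated patches use the same `// fix_start` / `// fix_end` markers as the `fixed_code` example and that each marker appears exactly once.

```python
B_INST, E_INST = "[INST]", "[/INST]"


def build_prompt(bos_token, buggy_code, repair_template):
    """Assemble the prompt in the same order as the inference script above."""
    return (bos_token + B_INST
            + "\n[bug_function]\n" + buggy_code
            + "\n[fix_template]\n" + repair_template
            + "\n[fix_code]\n" + E_INST)


def extract_patch(generated_text, eos_token):
    """Return the text produced after [/INST], with the EOS token removed."""
    return generated_text.split(E_INST, 1)[1].replace(eos_token, "")


def _marker_index(lines, marker):
    # Index of the first line whose stripped content equals the marker;
    # raises StopIteration if the marker is missing.
    return next(i for i, line in enumerate(lines) if line.strip() == marker)


def apply_patch(buggy_code, patch):
    """Replace the lines between bug_start/bug_end with those between fix_start/fix_end."""
    lines = buggy_code.splitlines()
    start = _marker_index(lines, "// bug_start")
    end = _marker_index(lines, "// bug_end")
    fix_lines = patch.strip().splitlines()
    fix_start = _marker_index(fix_lines, "// fix_start")
    fix_end = _marker_index(fix_lines, "// fix_end")
    return "\n".join(lines[:start] + fix_lines[fix_start + 1:fix_end] + lines[end + 1:])
```

Inside the generation loop, `apply_patch(buggy_code, extract_patch(tokenizer.decode(generated_id, skip_special_tokens=False), tokenizer.eos_token))` would yield each candidate as a complete patched function, ready to compile or test.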
## Model Details
*Note: Use of this model is governed by the Meta license. Meta developed and publicly released the Code Llama family of large language models (LLMs).*