shyam-incedoinc's picture
Create README.md
195c8ae
# Model description

This model is fine-tuned on 400+ test scripts written in Java with the `Cucumber` and `Selenium` frameworks.
The base model is `codellama/CodeLlama-7b-hf`, and the training dataset is available at `shyam-incedoinc/qa-finetune-dataset`.
Training metrics are available in the metrics section.
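# Training Parameters

The model was fine-tuned with QLoRA (LoRA adapters on a 4-bit quantized base model) using the following training arguments and LoRA configuration:

```python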
from peft import LoraConfig
from transformers import TrainingArguments

# Hyperparameters used for the QLoRA fine-tuning run.
# NOTE(review): output_dir is not shown in this excerpt, and older transformers
# releases require it. Supply one (e.g. output_dir="./results") before running.
training_arguments = TrainingArguments(
    num_train_epochs=25,
    # Effective batch size per device = 2 * 1 (no gradient accumulation).
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    # Trade extra compute for lower activation memory.
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=25,
    # save_steps is not set: checkpoints are written once per epoch.
    save_strategy="epoch",
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    # max_steps is not set: training length is governed by num_train_epochs.
    group_by_length=False,
    lr_scheduler_type="cosine",
    disable_tqdm=False,
    report_to="tensorboard",
    seed=42,
)

# LoRA adapter configuration (rank 64, scaling alpha 16) for causal-LM fine-tuning.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
```
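# Inference

Run the code block below to generate a completion for a random validation sample and compare it with the ground truth.
It loads the model in 4-bit precision, so it needs a CUDA GPU and the `torch`, `transformers`, `datasets`, `accelerate`, and `bitsandbytes` packages.

```python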
from random import randrange

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

hf_model_repo = "shyam-incedoinc/codellama-7b-hf-peft-qlora-finetuned-qa"
hf_data_repo = "shyam-incedoinc/qa-finetune-dataset"
# Samples are split on this marker. The text up to and including the marker is
# the prompt; the text after it is the reference (ground-truth) output.
output_marker = "### Output:\n"

# Get the tokenizer
tokenizer = AutoTokenizer.from_pretrained(hf_model_repo)

# Load the model in 4-bit. Passing `load_in_4bit=True` directly to
# `from_pretrained` is deprecated in favour of a BitsAndBytesConfig.
# NOTE(review): if this repo holds only a LoRA adapter, `peft` must also be installed.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    hf_model_repo,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Load the validation split from the hub and pick a random sample
valid_dataset = load_dataset(hf_data_repo, split="validation")
sample = valid_dataset[randrange(len(valid_dataset))]["text"]
prompt_body, marker, groundtruth = sample.partition(output_marker)
if not marker:
    raise ValueError(f"Sample does not contain the {output_marker!r} marker")
prompt = prompt_body + marker

# Generate a response. Inputs go to whatever device `device_map` chose instead
# of assuming CUDA, and the attention mask is passed along with the input ids.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=1024,
    do_sample=True,
    top_p=0.9,
    temperature=0.6,
)

# Decode only the newly generated tokens so they line up with the ground truth
prompt_length = inputs["input_ids"].shape[1]
generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Print the result
print(f"Generated response:\n{generated}")
print(f"Ground Truth:\n{groundtruth}")
```