This repo contains LoRA adapter weights.

Model Description

Results

Prompt Approach    GSM8k    MATH
Zero-Shot CoT      75.81    -

Training procedure

The following bitsandbytes quantization config was used during training:

  • quant_method: bitsandbytes
  • load_in_8bit: False
  • load_in_4bit: True
  • bnb_4bit_quant_type: nf4
  • bnb_4bit_use_double_quant: True
  • bnb_4bit_compute_dtype: float16
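
The settings above correspond to keyword arguments of `transformers.BitsAndBytesConfig`. The following is a minimal sketch, assuming `transformers` and `bitsandbytes` are installed; the names `BNB_KWARGS` and `make_bnb_config` are illustrative and do not come from the original training script.

```python
# Keyword arguments mirroring the quantization settings listed above.
# quant_method is not passed explicitly: BitsAndBytesConfig implies bitsandbytes.
BNB_KWARGS = {
    "load_in_8bit": False,
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    # A string is accepted here and resolved to torch.float16.
    "bnb_4bit_compute_dtype": "float16",
}


def make_bnb_config():
    """Build the config object (hypothetical helper; imports lazily)."""
    from transformers import BitsAndBytesConfig

    return BitsAndBytesConfig(**BNB_KWARGS)
```

The resulting object would typically be passed as `quantization_config=make_bnb_config()` to `AutoModelForCausalLM.from_pretrained` when loading the base model for QLoRA training.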

LoraConfig params:

  • r: 128
  • lora_alpha: 256 (lora_r * 2)
  • lora_dropout: 0.05
  • bias: "none"
  • task_type: "CAUSAL_LM"
  • target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
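
These parameters can be expressed as a `peft.LoraConfig`. The sketch below assumes the `peft` package is installed; `LORA_R`, `LORA_KWARGS`, `LORA_SCALING`, and `make_lora_config` are illustrative names, not taken from the original training script.

```python
LORA_R = 128

# Keyword arguments mirroring the LoraConfig params listed above.
LORA_KWARGS = {
    "r": LORA_R,
    "lora_alpha": LORA_R * 2,  # evaluates to 256
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
}

# Standard LoRA scales the adapter update by lora_alpha / r.
LORA_SCALING = LORA_KWARGS["lora_alpha"] / LORA_KWARGS["r"]


def make_lora_config():
    """Build the config object (hypothetical helper; imports lazily)."""
    from peft import LoraConfig

    return LoraConfig(**LORA_KWARGS)
```

With r = 128 and lora_alpha = 256, the adapter update is scaled by a factor of 2 under the standard LoRA scaling rule.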

The hyperparameters for the LoRA fine-tuning are listed below:

  • epochs: 3
  • learning_rate: 5e-5
  • batch_size: 256
  • max_grad_norm: 1.0
  • weight_decay: 0.001
  • lr_scheduler_type: "cosine"
  • warmup_ratio: 0.03
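
To illustrate how learning_rate, lr_scheduler_type, and warmup_ratio interact, here is a rough sketch of a cosine schedule with linear warmup in plain Python. The total number of optimizer steps is not stated in this card, so `total_steps=1000` is a hypothetical value, and `lr_at` is an illustrative helper rather than the trainer's actual implementation.

```python
import math


def lr_at(step, total_steps, peak_lr=5e-5, warmup_ratio=0.03):
    """Learning rate at a given optimizer step (illustrative helper)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear warmup from 0 up to the peak learning rate.
        return peak_lr * step / max(1, warmup_steps)
    # Cosine decay from the peak learning rate down to 0.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total_steps = 1000  # hypothetical; depends on dataset size and batch size
    for step in (0, 15, 30, 500, 1000):
        print(step, lr_at(step, total_steps))
```

Under these assumptions the rate climbs linearly during the first 3% of steps, peaks at 5e-5, and then decays along a half cosine toward zero by the final step.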

Dataset

The math_QA dataset combines MetaMathQA, MathInstruct, and some internal data. Refer to math_QAaugP.

Model Usage

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from peft import PeftModel

model_path = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype = torch.float16,
    device_map = {"": 0},
)

# Load LoRA and merge
model = PeftModel.from_pretrained(model, "adityasihag/math_QA-Mistral-7B-QLoRA-adapter")
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

question = """Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x."""

sample_input = f"""Question: {question} \n Answer: """

sample_input_tokenised = tokenizer(sample_input, return_tensors = "pt").to("cuda")

generated_ids = model.generate(
                    **sample_input_tokenised,
                    max_new_tokens = 1024,
                    do_sample = True,  # temperature only takes effect when sampling is enabled
                    temperature = 0.3
                )
output = tokenizer.decode(generated_ids[0], skip_special_tokens = True)
print(output)

Sample Input:
Question: Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x. \n Answer: 

Model Output:
Given the linear equation 3(x+2)-x=x+9. 
First, distribute the 3 in the brackets to get 3x + 6 - x = x + 9. 
Simplify the equation to get 2x + 6 = x + 9. 
Next, transpose x from the right side to the left side and from the left side to the right side to get x = 9 - 6. 
Finally, solve for x to get x = 3.

Prompt Template:

Question: <question>
Answer: 
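
A small helper can keep inference prompts consistent with this template. The sketch below reproduces the exact string built for `sample_input` in the usage code, including the spaces around the newline; `build_prompt` is an illustrative name, not part of the released code.

```python
def build_prompt(question: str) -> str:
    """Format a question the same way as sample_input in the usage code."""
    return f"Question: {question} \n Answer: "


if __name__ == "__main__":
    print(build_prompt("Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x."))
```

The returned string can be passed directly to the tokenizer in place of `sample_input`.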

Comparing math_QA models with other SFT LLM models

Model                       GSM8k Pass@1    MATH Pass@1
LLaMA-2-7B                  14.6            2.5
gemma-2b                    17.7            -
LLaMA-2-13B                 28.7            3.9
LLaMA-2-34B                 42.2            6.24
math_QA-gemma-2B            43.66           -
gemma-7b                    46.4            -
WizardMath-7B               54.9            10.7
Mistral-7B                  35.4            -
WizardMath-13B              63.9            14.0
MetaMath-7B                 66.5            19.8
MetaMath-13B                72.3            22.4
math_QA-Mistral-7B          75.81           -
Arithmo2-Mistral-7B         76.4            27.2
MetaMath-Mistral-7B         77.7            28.2
DeepSeekMath-Instruct-7B    82.9            46.8
GPT4                        92.0            52.9