| Feature/Case | 🤗 + Standard Attention (Baseline) | 🤗 + Flash Attention 1 | 🤗 + Flash Attention 2 | 🤗 + Unsloth |
|--------------|------------------------------------|------------------------|------------------------|--------------|
| **Dataset** | databricks/databricks-dolly-15k | databricks/databricks-dolly-15k | databricks/databricks-dolly-15k | databricks/databricks-dolly-15k |
| **Model** | NousResearch/Llama-2-7b-hf | NousResearch/Llama-2-7b-hf | NousResearch/Llama-2-7b-hf | unsloth/llama-2-7b |
| **Training Techniques for Model Training Optimization** | QLoRA, Packing | QLoRA, Flash Attention 1, Packing | QLoRA, Flash Attention 2, Packing | QLoRA, Unsloth, Packing |
| **Dependencies for Unsloth and FA** | N/A | `!pip install -U optimum` | `!pip install -U flash-attn` | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/766606bf-d986-4bff-b12b-ec8ae19b5d61) |
| **Model Loading** | `model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=True, device_map="auto")` | `model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=True, device_map="auto")` | `model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=True, device_map="auto", use_flash_attention_2=True)` | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/e9e9515f-e5cc-4e06-9908-c509a8d7ed1a) |
| **LoRA** | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/e4de78f8-317d-4332-b732-8e52d969d665) | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/18857fca-a02b-4106-bbb1-e7ca9ff50e73) | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/b1d192bb-c7ea-40fc-9794-68e48aec9881) | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/3ad324d1-0e84-4ee9-bd20-b2c0869e1748) |
| **Model Training Setup** | `trainer.train()` | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/5c5741d3-d7a2-4294-9d2c-c100eaf3e884) | `trainer.train()` | `trainer.train()` |
| **Trainable Params** | 67,108,864 | 67,108,864 | 67,108,864 | 67,108,864 |
| **Total Params** | 3,567,521,792 | 3,567,521,792 | 3,567,521,792 | 3,567,521,792 |
| **Trainable Percentage (%)** | 1.881 | 1.881 | 1.881 | 1.881 |
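
The Trainable Percentage row follows directly from the two rows above it. As a quick sanity check, the arithmetic can be sketched as below; the helper name `trainable_percentage` is a hypothetical one introduced here for illustration and is not part of any library, while the two counts are copied from the table.

```python
def trainable_percentage(trainable: int, total: int) -> float:
    """Return trainable parameters as a percentage of all parameters."""
    if total <= 0:
        raise ValueError("total must be positive")
    return 100.0 * trainable / total


# Counts taken from the table (identical across all four setups).
TRAINABLE_PARAMS = 67_108_864
TOTAL_PARAMS = 3_567_521_792

pct = trainable_percentage(TRAINABLE_PARAMS, TOTAL_PARAMS)
print(f"{pct:.3f}")  # 1.881
```

Because the two counts are the same in every column, the four setups differ only in attention implementation and training backend, not in the size of the adapter being trained.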