| Feature/Case | πŸ€— + Standard Attention (Baseline) | πŸ€— + Flash Attention 1 | πŸ€— + Flash Attention 2 | πŸ€— + Unsloth |
|--------------|------------------------------------|------------------------|------------------------|--------------|
| **Dataset** | databricks/databricks-dolly-15k | databricks/databricks-dolly-15k | databricks/databricks-dolly-15k | databricks/databricks-dolly-15k |
| **Model** | NousResearch/Llama-2-7b-hf | NousResearch/Llama-2-7b-hf | NousResearch/Llama-2-7b-hf | unsloth/llama-2-7b |
| **Training Optimization Techniques** | QLoRA, Packing | QLoRA, Flash Attention 1, Packing | QLoRA, Flash Attention 2, Packing | QLoRA, Unsloth, Packing |
| **Extra Dependencies (Flash Attention / Unsloth)** | N/A | `!pip install -U optimum` | `!pip install -U flash-attn` | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/766606bf-d986-4bff-b12b-ec8ae19b5d61)|
|**Model Loading**|`model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=True, device_map="auto")`|`model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=True, device_map="auto")`|`model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=True, device_map="auto", use_flash_attention_2=True)`|![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/e9e9515f-e5cc-4e06-9908-c509a8d7ed1a)|
|LoRA|![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/e4de78f8-317d-4332-b732-8e52d969d665)| ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/18857fca-a02b-4106-bbb1-e7ca9ff50e73)|![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/b1d192bb-c7ea-40fc-9794-68e48aec9881)|![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/3ad324d1-0e84-4ee9-bd20-b2c0869e1748)|
| **Model Training Setup** | `trainer.train()` | ![image](https://github.com/DrishtiShrrrma/llama2-7b-flash-atn2-packing-unsloth-neftune-analysis/assets/129742046/5c5741d3-d7a2-4294-9d2c-c100eaf3e884)| `trainer.train()` | `trainer.train()` |
| **Trainable Params** | 67,108,864 | 67,108,864 | 67,108,864 | 67,108,864 |
| **Total Params** | 3,567,521,792 | 3,567,521,792 | 3,567,521,792 | 3,567,521,792 |
| **Trainable Percentage (%)** | 1.881 | 1.881 | 1.881 | 1.881 |
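The model-loading cells in the table can be sketched end to end as below. The `BitsAndBytesConfig` values are assumptions (a typical 4-bit NF4 QLoRA setup); the notebook's exact quantization config is not reproduced in the table, so treat this as a config sketch rather than the author's verbatim code.

```python
# Sketch of the QLoRA model loading compared in the table.
# ASSUMPTION: the bnb_4bit_* values below are a common QLoRA default,
# not confirmed by the table itself.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "NousResearch/Llama-2-7b-hf"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                        # QLoRA: 4-bit base weights
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=True,
    device_map="auto",
    use_flash_attention_2=True,  # only in the Flash Attention 2 column; omit for the baseline
)
```

Note that the only difference between the baseline/FA1 loading call and the FA2 call in the table is the `use_flash_attention_2=True` flag; FA1 is instead enabled through `optimum`, and Unsloth replaces this loading path entirely.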
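The trainable percentage in the last row follows directly from the two parameter counts above it, which is why it is identical across all four setups:

```python
# Derive the "Trainable Percentage (%)" row from the parameter counts in the table.
trainable_params = 67_108_864       # identical in all four columns
total_params = 3_567_521_792        # 4-bit quantized total reported by PEFT

pct = 100 * trainable_params / total_params
print(round(pct, 3))  # → 1.881
```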