{
"cells": [
{
"cell_type": "markdown",
"id": "a45b25c3-08b2-4a7e-b0cd-67293f15c307",
"metadata": {},
"source": [
"# Optimizing Hugging Face Models with Parameter Efficient Fine-Tuning (PEFT)\n",
"\n",
"NeMo 2.0 enables users to perform Supervised Fine-Tuning (SFT) and Parameter-Efficient Fine-Tuning (PEFT) using Hugging Face (HF) Large Language Models (LLMs). It utilizes HF's auto classes to download and load transformer models, and wraps these models as Lightning modules to execute tasks like SFT and PEFT. The goal of this feature is to provide day-0 support for the models available in HF.\n",
"\n",
"[AutoModel](https://huggingface.co/docs/transformers/en/model_doc/auto) is the generic model class that is instantiated as one of the model classes from the library when created with the `from_pretrained()` class method. There are many AutoModel classes in HF, each covering a specific group of transformer model architectures. The base AutoModel class loads the base transformer model that converts embeddings to hidden states. A more specific AutoModel class, such as AutoModelForCausalLM, adds a causal language modeling head on top of the base model.\n",
"\n",
"NeMo 2.0 includes wrapper classes for these HF AutoModel classes, making them runnable in NeMo pretraining, SFT, and PEFT workflows by converting them into Lightning modules. Due to the large number of AutoModel classes, NeMo 2.0 currently includes only the widely used auto classes.\n",
"\n",
"In this notebook, we will demonstrate how to perform PEFT with Hugging Face LLMs to make the models more performant on a specific task. We will focus on the models that can be loaded using HF's `AutoModelForCausalLM` class.\n",
"\n",
"<font color='red'>NOTE:</font> Due to the limitations of the Jupyter Notebook, the example in this notebook works only on a single GPU. However, if you move the code to a script, you can run it on multiple GPUs. If you are interested in running a multi-GPU example using the Jupyter Notebook, please check the SFT example in NeMo-Run."
]
},
{
"cell_type": "markdown",
"id": "63a50bad-f356-4076-8c5c-66b4481029dc",
"metadata": {},
"source": [
"## Step 1: Import Modules and Prepare the Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28e16913-6a08-4ad8-835e-311fbb5af01d",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"import fiddle as fdl\n",
"import lightning.pytorch as pl\n",
"from lightning.pytorch.loggers import WandbLogger\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from nemo import lightning as nl\n",
"from nemo.collections import llm\n",
"from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform\n",
"from nemo.lightning import NeMoLogger"
]
},
{
"cell_type": "markdown",
"id": "5cfe3c7d-9d36-47d2-9107-361025d175a0",
"metadata": {},
"source": [
"We will use the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) dataset, which is a reading comprehension dataset consisting of question-answer pairs. The SquadDataModule in NeMo 2.0 provides the dataloaders for the SQuAD dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fc6a132-688e-4ad3-94ae-557e57ab77cb",
"metadata": {},
"outputs": [],
"source": [
"class SquadDataModuleWithPthDataloader(llm.SquadDataModule):\n",
"    \"\"\"Creates a SQuAD dataset with a PyTorch dataloader\"\"\"\n",
"\n",
"    def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:\n",
"        return DataLoader(\n",
"            dataset,\n",
"            num_workers=self.num_workers,\n",
"            pin_memory=self.pin_memory,\n",
"            persistent_workers=self.persistent_workers,\n",
"            collate_fn=dataset.collate_fn,\n",
"            batch_size=self.micro_batch_size,\n",
"            **kwargs,\n",
"        )\n",
"\n",
"\n",
"def squad(tokenizer, mbs=1, gbs=2) -> pl.LightningDataModule:\n",
"    \"\"\"Instantiates a SquadDataModuleWithPthDataloader and returns it\n",
"\n",
"    Args:\n",
"        tokenizer (AutoTokenizer): the tokenizer to use\n",
"\n",
"    Returns:\n",
"        pl.LightningDataModule: the dataset to train with.\n",
"    \"\"\"\n",
"    return SquadDataModuleWithPthDataloader(\n",
"        tokenizer=tokenizer,\n",
"        seq_length=512,\n",
"        micro_batch_size=mbs,\n",
"        global_batch_size=gbs,\n",
"        num_workers=0,\n",
"        dataset_kwargs={\n",
"            \"sanity_check_dist_workers\": False,\n",
"            \"get_attention_mask_from_fusion\": True,\n",
"        },\n",
"    )"
]
},
{
"cell_type": "markdown",
"id": "a23943ee-ffa1-497d-a395-3e4767271341",
"metadata": {},
"source": [
"## Step 2: Set Parameters and Start the PEFT with a HF Model\n",
"\n",
"Now, we will set the key variables, including the HF model name, the maximum number of steps, and the number of GPUs. Each variable is described below.\n",
"- `model_name`: Name of a pre-trained HF model, or the path to a local HF model.\n",
"- `strategy`: Distributed training strategy, such as `ddp` or `fsdp2`. Any other value, including the empty string, selects single-device training.\n",
"- `max_steps`: Number of training steps.\n",
"- `accelerator`: Accelerator type, such as `gpu`.\n",
"- `num_devices`: Number of GPUs to use for training.\n",
"- `wandb_name`: Name of the wandb run. Set it to `None` to disable wandb logging.\n",
"- `use_torch_jit`: Whether to enable torch JIT.\n",
"- `ckpt_folder`: Path for the checkpoints.\n",
"\n",
"Popular model families, including Llama, GPT, Gemma, Mistral, Phi, and Qwen, are supported. After running this workflow, you can select another HF model and rerun the notebook with that model. Ensure the chosen model fits within your GPU memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3780a047-febb-4d97-a59a-99d8ee036332",
"metadata": {},
"outputs": [],
"source": [
"# To use gated models such as Llama or Gemma, request access on the HF model page and then set HF_TOKEN in the next cell.\n",
"# model_name = \"google/gemma-2b\"  # HF model name. This can be the path of the downloaded model as well.\n",
"model_name = \"meta-llama/Llama-3.2-1B\"  # HF model name. This can be the path of the downloaded model as well.\n",
"strategy = \"\"  # Distributed training (e.g., \"ddp\", \"fsdp2\") requires a non-interactive environment.\n",
"max_steps = 100  # Number of steps in the training loop.\n",
"accelerator = \"gpu\"\n",
"num_devices = 1  # Number of GPUs to run this notebook on. Inside a notebook, only a single GPU is supported.\n",
"wandb_name = None  # Set a run name to enable wandb logging.\n",
"use_torch_jit = False  # Set to True to enable torch JIT.\n",
"ckpt_folder = \"/opt/checkpoints/automodel_experiments/\"  # Path for saving the checkpoint."
]
},
{
"cell_type": "markdown",
"id": "6966670b-2097-47c0-95f2-edaafab0e33f",
"metadata": {},
"source": [
"Some models have gated access. If you are using one of those models, you will need to obtain access first. Then, set your HF Token by running the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "439a3c6a-8718-4b49-acdb-e7f59db38f59",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"HF_TOKEN\"] = '<HF_TOKEN>'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will define the helper function `make_strategy` that creates the model strategy. Importantly, this method also specifies the `checkpoint_io` so that the output is saved in HF-compatible format."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_strategy(strategy, model, devices, num_nodes, adapter_only=False):\n",
"    \"\"\"\n",
"    Creates and returns a distributed training strategy based on the provided strategy type.\n",
"\n",
"    Parameters:\n",
"        strategy (str): The name of the strategy ('ddp', 'fsdp2', or other).\n",
"        model: The model instance, which provides a method to create the HF checkpoint IO.\n",
"        devices (int): Number of devices per node.\n",
"        num_nodes (int): Number of compute nodes.\n",
"        adapter_only (bool, optional): Whether to save only adapter-related parameters in the checkpoint. Default is False.\n",
"\n",
"    Returns:\n",
"        A PyTorch Lightning or custom distributed training strategy.\n",
"    \"\"\"\n",
"\n",
"    if strategy == 'ddp':  # Distributed Data Parallel (DDP) strategy\n",
"        return pl.strategies.DDPStrategy(\n",
"            checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only),\n",
"        )\n",
"\n",
"    elif strategy == 'fsdp2':  # Fully Sharded Data Parallel (FSDP) v2 strategy\n",
"        return nl.FSDP2Strategy(\n",
"            data_parallel_size=devices * num_nodes,  # Defines total data parallel size\n",
"            tensor_parallel_size=1,  # No tensor parallelism\n",
"            checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only),\n",
"        )\n",
"\n",
"    else:  # Default to single device strategy (useful for debugging or single-GPU training)\n",
"        return pl.strategies.SingleDeviceStrategy(\n",
"            device='cuda:0',  # Uses the first available CUDA device\n",
"            checkpoint_io=model.make_checkpoint_io(adapter_only=adapter_only),\n",
"        )"
]
},
{
"cell_type": "markdown",
"id": "7cd65e5e-93fa-4ea0-b89d-2f48431b725c",
"metadata": {},
"source": [
"After setting some parameters, we can start the PEFT training workflow. Although the PEFT workflow with HF models/checkpoints differs slightly from workflows with NeMo models/checkpoints, we still use the same NeMo 2.0 API. The main difference is the model we pass into the fine-tune API.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3578630-05b7-4a8c-8b5d-a7d9e847f17b",
"metadata": {},
"outputs": [],
"source": [
"wandb = WandbLogger(\n",
"    project=\"nemo_automodel\",\n",
"    name=wandb_name,\n",
") if wandb_name is not None else None\n",
"\n",
"callbacks = []\n",
"if use_torch_jit:\n",
"    jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False)\n",
"    callbacks = [JitTransform(jit_config)]\n",
"\n",
"\n",
"model = llm.HFAutoModelForCausalLM(model_name=model_name)\n",
"\n",
"# Create model strategy\n",
"strategy = make_strategy(strategy, model, num_devices, 1)\n",
"\n",
"trainer = nl.Trainer(\n",
"    devices=1,\n",
"    max_steps=max_steps,\n",
"    accelerator=\"gpu\",\n",
"    strategy=strategy,\n",
"    log_every_n_steps=1,\n",
"    limit_val_batches=0.0,\n",
"    num_sanity_val_steps=0,\n",
"    accumulate_grad_batches=1,\n",
"    gradient_clip_val=1.0,\n",
"    use_distributed_sampler=False,\n",
"    logger=wandb,\n",
"    callbacks=callbacks,\n",
"    precision=\"bf16\",\n",
")\n",
"\n",
"llm.api.finetune(\n",
"    model=model,\n",
"    data=squad(llm.HFAutoModelForCausalLM.configure_tokenizer(model_name), gbs=1),\n",
"    trainer=trainer,\n",
"    optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),\n",
"    log=NeMoLogger(log_dir=ckpt_folder, use_datetime_version=False),\n",
"    peft=llm.peft.LoRA(\n",
"        target_modules=['*_proj'],\n",
"        dim=8,\n",
"    ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Note: switching between PEFT <> SFT is frictionless\n",
"\n",
"While the previous code section shows how to run PEFT on a model, it is very easy to switch to SFT.\n",
"To use SFT instead of PEFT, all we have to do is set the `peft` parameter of the `llm.api.finetune` call above to `None`, so that no adapter is defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use the following if you want to run SFT instead.\n",
"llm.api.finetune(\n",
"    model=model,\n",
"    data=squad(llm.HFAutoModelForCausalLM.configure_tokenizer(model_name), gbs=1),\n",
"    trainer=trainer,\n",
"    optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),\n",
"    peft=None,  # no adapter specified -> will run SFT\n",
"    log=None,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "67e6e4d4-8e0c-4507-b386-22c3d63097c1",
"metadata": {},
"source": [
"## Step 3: Generate Output with the HF Pipeline\n",
"\n",
"Once the PEFT training is completed, you can generate text using HF's APIs to check the quality of the outputs.\n",
"The PEFT checkpoint will be saved in the folder defined by the `ckpt_folder` variable. After the first run, the new checkpoint will be saved in a folder named `default/checkpoints/default--None=0.0000-epoch=0-step=0/hf_adapter/`. If you run this notebook more than once, you will see multiple checkpoints in the same location.\n",
"\n",
"Note that the `hf_adapter` directory contains the PEFT adapter in HF-compatible format, which allows users to load the adapter with HF's transformers library or in NeMo.\n",
"\n",
"Let's run the finetuned model to generate some text."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a118868-8c6e-44ad-9b2d-3be3994a093b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"from pathlib import Path\n",
"from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\n",
"from peft import PeftModel\n",
"\n",
"model_name = \"meta-llama/Llama-3.2-1B\"\n",
"ckpt_folder=\"/opt/checkpoints/automodel_experiments/\"\n",
"peft_checkpoint = Path(ckpt_folder) / f\"default/checkpoints/default--None=0.0000-epoch=0-step={max_steps}/hf_adapter/\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"model.load_adapter(peft_checkpoint, adapter_name=\"adapter_1\")\n",
"model.set_adapter(\"adapter_1\")\n",
"\n",
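"# The model is already instantiated, so we place it on the first GPU with `device` only.\n",
"# Passing both `device` and `device_map` makes the pipeline warn that `device` overrides `device_map`.\n",
"pipe = pipeline(\n",
"    \"text-generation\",\n",
"    model=model,\n",
"    tokenizer=AutoTokenizer.from_pretrained(model_name),\n",
"    torch_dtype=torch.bfloat16,\n",
"    device=0,\n",
")\n",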
"\n",
"pipe(\"The key to life is\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7a3e927-487e-4450-a559-e9e1a5623b66",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "nemo3.10",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}