{
"cells": [
{
"cell_type": "markdown",
"id": "7c9d27250020aba6",
"metadata": {},
"source": [
"# Finetuning Llama 3.2 Model into Embedding Model\n",
"\n",
"## Goal\n",
"\n",
"While LLaMA 3.2 is a powerful large language model (LLM) pre-trained on diverse datasets, its application to specific downstream tasks—such as semantic search, document retrieval, or natural language understanding—requires adapting the model to effectively generate dense vector representations (embeddings). In this tutorial, we will demonstrate how to finetune this model and convert it into a state-of-the-art embedding model for retrieval-augmented generation (RAG) tasks.\n",
"\n",
"The key architectural change involves modifying the LLaMA model to optimize its performance in generating embeddings by replacing causal attention with bidirectional attention. This change enables the decoder-only model to create embeddings that are contextually relevant, semantically rich, and capable of improving the efficiency and accuracy of tasks like information retrieval, clustering, and text classification.\n",
"\n",
"Our primary goals for this tutorial are as follows:\n",
"\n",
" * Demonstrate the ease of automatically converting the model with essential architectural changes for embedding model training\n",
" * Improve the model's performance and accuracy in generating dense vector representations (embeddings)\n",
" * Provide guidelines for finetuning embedding models, including hyperparameter choices.\n"
]
},
{
"cell_type": "markdown",
"id": "45a0f1b96249be78",
"metadata": {},
"source": [
"# NeMo Tools and Resources\n",
"\n",
"* [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html)\n",
"\n",
"# Software Requirements\n",
"\n",
"* Access to latest NeMo Framework NGC Containers\n",
"\n",
"\n",
"# Hardware Requirements\n",
"\n",
"* This playbook has been tested on the following hardware: Single A6000, Single H100, 2xA6000, 8xH100. It can be scaled to multiple GPUs as well as multiple nodes by modifying the appropriate parameters.\n"
]
},
{
"cell_type": "markdown",
"id": "87846682e01e1a50",
"metadata": {},
"source": [
"#### Launch the NeMo Framework container as follows: \n",
"\n",
"Depending on the number of gpus, `--gpus` might need to adjust accordingly:\n",
"```\n",
"docker run -it -p 8080:8080 -p 8088:8088 --rm --gpus '\"device=0,1\"' --ipc=host --network host -v $(pwd):/workspace nvcr.io/nvidia/nemo:25.02\n",
"```\n",
"\n",
"#### Launch Jupyter Notebook as follows: \n",
"```\n",
"jupyter notebook --allow-root --ip 0.0.0.0 --port 8088 --no-browser --NotebookApp.token=''\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "a9283ea8354238d2",
"metadata": {},
"source": [
"# Data\n",
"\n",
"In this playbook, we will use the [Specter](https://huggingface.co/datasets/sentence-transformers/specter) datasets. \n",
"\n",
"The Specter dataset is a large collection of scientific papers and their abstracts, specifically designed for tasks like document classification, citation recommendation, and information retrieval. We will use the title-related-unrelated triplets from the dataset for finetuning. An example entry from this dataset is shown below:\n",
"\n",
"```\n",
"Anchor: Millimeter-wave CMOS digital controlled artificial dielectric differential mode transmission lines for reconfigurable ICs\n",
"Positive: CMP network-on-chip overlaid with multi-band RF-interconnect\n",
"Negative: Route packets, not wires: on-chip interconnection networks\n",
"```\n",
"\n",
"We will also demonstrate how to prepare and finetune the model using your own customized dataset"
]
},
{
"cell_type": "markdown",
"id": "7a44b99fbd6df0c7",
"metadata": {},
"source": [
"# Notebook Outline\n",
"\n",
"* Step 1: Auto download and convert the original Llama 3.2 model to NeMo2\n",
"* Step 2: (Optional) Prepare your own dataset\n",
"* Step 3: Finetuning Llama 3.2 model to a embedding model\n",
"* Step 4: (Optional) Convert the model to HuggingFace Transformer format"
]
},
{
"cell_type": "markdown",
"id": "7cfc90d2a48c52fa",
"metadata": {},
"source": [
"# Step 1: Auto Download and Convert the Original Llama 3.2 Model to NeMo2\n",
"\n",
"Llama 3.2 model can be automatically downloaded and converted to NeMo2 format with the following script:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc17124fe0d22afd",
"metadata": {},
"outputs": [],
"source": [
"%%writefile convert_llama_embedding.py\n",
"\n",
"from nemo.collections import llm\n",
"\n",
"if __name__ == '__main__':\n",
" llm.import_ckpt(\n",
" model=llm.LlamaEmbeddingModel(config=llm.Llama32EmbeddingConfig1B()),\n",
" source=\"hf://meta-llama/Llama-3.2-1B\",\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34c0353572a0883f",
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli login --token <YOUR HF TOKEN>\n",
"!torchrun convert_llama_embedding.py"
]
},
{
"cell_type": "markdown",
"id": "d06bac2161dcf78f",
"metadata": {},
"source": [
"The above script \n",
"- Downloads the LLaMA 3.2 1B model from Hugging Face (if not already downloaded).\n",
"- Automatically converts it into the NeMo format.\n",
"- Replaces all attention blocks with bidirectional attention, which plays a critical role in capturing the full context of a word or token in a sequence, as demonstrated by models like BERT (Bidirectional Encoder Representations from Transformers).\n",
"\n",
"Note:\n",
"- The script can only run in a Python environment, not in a Jupyter notebook.\n",
"- You need to have access to ```Llama-3.2-1B``` [repo on Hugging Face](https://huggingface.co/meta-llama/Llama-3.2-1B/tree/main). Log in using a token by `huggingface-cli login --token <YOUR HF TOKEN>`\n",
"\n",
"The conversion will create a ```Llama-3.2-1B``` folder in the default ```$NEMO_HOME``` directory. \n",
"```$NEMO_HOME``` centralizes and stores all models and datasets used for NeMo training. By default `$NEMO_HOME stores to ```~/.cache/nemo```.\n",
"This folder will be used in Step 3 for fine-tuning the model to generate the embedding model."
]
},
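{
"cell_type": "markdown",
"id": "bidirectional-attention-sketch",
"metadata": {},
"source": [
"To make the attention change concrete, the sketch below is a purely illustrative comparison (not the code NeMo runs internally) of the causal mask used by a decoder-only LM and the bidirectional (all-ones) mask used when training the embedding model:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"seq_len = 4\n",
"# Causal mask: token i may only attend to tokens <= i.\n",
"causal_mask = torch.ones(seq_len, seq_len).tril().bool()\n",
"# Bidirectional mask: every token attends to every other token,\n",
"# so each token's representation can use the full sentence context.\n",
"bidirectional_mask = torch.ones(seq_len, seq_len).bool()\n",
"\n",
"print(causal_mask)\n",
"print(bidirectional_mask)\n",
"```"
]
},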
{
"cell_type": "markdown",
"id": "313d0e57e0bfbb77",
"metadata": {},
"source": [
"# Step 2: (Optional) Prepare Your own Dataset\n",
"\n",
"If you want to use your own dataset for finetuning, prepare it as a single JSON file containing a list of JSON objects. Each object should include the following three fields:\n",
"\n",
"* question: query or a prompt that the model needs to process and understand. In embedding-based models (especially in tasks like question answering (QA) or information retrieval), the \"question\" typically represents a user's query that seeks relevant information from a database or corpus of documents.\n",
"* pos_doc (Positive Document): \"pos_doc\" refers to a ```list``` of positive document — a document that is relevant to the given question. In other words, this document contains information that answers or is closely related to the question. \n",
"* neg_doc (Negative Document): \"neg_doc\" refers to a ```list``` of negative document — a document that is irrelevant or less relevant to the question. This document either does not contain any useful information or contains information that does not answer or is only tangentially related to the question.\n",
"\n",
"You can create the datamodule using NeMo's [CustomRetrievalDataModule](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/gpt/data/retrieval.py#L31) to create the datamodule used for finetuning. ``CustomRetrievalDataModule`` takes one JSON `data_root` and performs train/val/test split, and it stores the processed dataset file to `$NEMO_HOME/datasets/<dataset_identifier>`. You can customize the ``query_key``, `pos_doc_key`, and `neg_doc_key` to match your own JSON dataset.\n",
"\n",
"You can use [Specter](https://huggingface.co/datasets/sentence-transformers/specter/viewer/triplet) dataset as an example to better understand the question/pos_doc/neg_doc requires for embedding model training to prepare your own dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cde5aac56242d3ab",
"metadata": {},
"outputs": [],
"source": [
"# JSON File contains a list of JSON objects with question, pos_doc, and neg_doc\n",
"[\n",
" {\n",
" \"question\": \"What do floor indicaters often consist of?\",\n",
" \"pos_doc\": [\"In addition to the call buttons, elevators usually have floor indicators (often illuminated by LED) and direction lanterns...\"],\n",
" \"neg_doc\": [\"neg doc1\", \"neg doc2\", \"neg doc3\", ...]\n",
" },\n",
" {\n",
" \"question\": \"An welchem Tag wurde das St. Helena Radio abgeschaltet?\",\n",
" \"pos_doc\": [\"Radio St. Helena, die am Weihnachten 1967 begann, bot einen lokalen Radio-Service, der eine Reichweite von ca. 100 km (62 mi)...\"],\n",
" \"neg_doc\": [\"neg doc1\", \"neg doc2\", \"neg doc3\", ...]\n",
" }\n",
"]"
]
},
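{
"cell_type": "markdown",
"id": "custom-dataset-json-sketch",
"metadata": {},
"source": [
"As a minimal sketch of how such a file can be assembled (the file name and triplets below are placeholders; replace them with your own data), you could write the list of objects with the standard `json` module and then pass the resulting path to `CustomRetrievalDataModule` as `data_root`:\n",
"\n",
"```python\n",
"import json\n",
"\n",
"# Placeholder triplets; replace with your own (question, positive, negatives) data.\n",
"examples = [\n",
"    {\n",
"        'question': 'What do floor indicators often consist of?',\n",
"        'pos_doc': ['In addition to the call buttons, elevators usually have floor indicators ...'],\n",
"        'neg_doc': ['An unrelated passage.', 'Another unrelated passage.'],\n",
"    },\n",
"]\n",
"\n",
"# Write a single JSON file containing the list of objects.\n",
"with open('my_retrieval_data.json', 'w', encoding='utf-8') as f:\n",
"    json.dump(examples, f, ensure_ascii=False, indent=2)\n",
"```"
]
},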
{
"cell_type": "markdown",
"id": "bdea5dd972cbde5b",
"metadata": {},
"source": [
"# Step 3: Finetuning Llama 3.2 model into an embedding model\n",
"\n",
"For this step we use the NeMo2 predefined recipe. We will modify the recipe and showcase how you can adjust it to fit your own workflow. Typically, this involves changing the dataset, learning rate scheduler, and default parallelism etc.\n",
"\n",
"First we define the recipe and executor for using NeMo2 as follows:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6443580805b02989",
"metadata": {},
"outputs": [],
"source": [
"import nemo_run as run\n",
"from nemo.collections import llm\n",
"\n",
"def configure_recipe(nodes: int = 1, gpus_per_node: int = 1):\n",
" recipe = llm.recipes.llama_embedding_1b.finetune_recipe(\n",
" name=\"llama32_1b_embedding_finetuning\",\n",
" num_nodes=nodes,\n",
" num_gpus_per_node=gpus_per_node,\n",
" peft_scheme=None,\n",
" )\n",
" return recipe\n",
"\n",
"def local_executor_torchrun(devices: int = 1) -> run.LocalExecutor:\n",
" executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\")\n",
" return executor\n",
"\n",
"# Instantiate the recipe\n",
"# Make sure you set the gpus_per_node as expected\n",
"recipe = configure_recipe(gpus_per_node=1) "
]
},
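{
"cell_type": "markdown",
"id": "peft-scheme-note",
"metadata": {},
"source": [
"Note that `peft_scheme=None` selects full-parameter finetuning of the 1B model rather than a parameter-efficient scheme such as LoRA. `configure_recipe` only returns a configuration object; nothing is executed yet, and we customize this configuration in the cells below."
]
},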
{
"cell_type": "markdown",
"id": "697d71ea5ad76d52",
"metadata": {},
"source": [
"You can learn more about NeMo Executor [here](https://github.com/NVIDIA/NeMo-Run/blob/main/docs/source/guides/execution.md).\n",
"\n",
"Let's understand a bit in depth of how the `recipe` works here:\n",
"\n",
"`recipe` has few components: \n",
"- model: the model to train\n",
"- data: the datamodule to train,\n",
"- trainer: Pytorch Lightning trainer that defines training strategy \n",
"- log: logger for logging the training progress\n",
"- optim: optimizer to use for training,\n",
"- resume: resume connector to resume from a checkpoint prior training\n",
"\n",
"The default `recipe` initializes all the essential components required for finetuning the LLaMA 3.2 1B embedding model.\n",
"\n",
"An important point to remember is that at this stage, you are only configuring your task; the underlying code is not executed yet.\n",
"This allows you to modify the configuration to fit your own custom training workflow.\n",
"\n",
"In this task, `model` is initialized to `LlamaEmbeddingModel`, which, by default, uses ranking loss to learn the embedding.\n",
"Ranking loss penalizes the model based on how well it ranks the correct documents relative to incorrect ones. By default, the loss function only calculates the hard negatives provided by the dataset. However, you can treat examples within the batch as additional negatives by modifying the model's [config](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/gpt/model/llama_embedding.py#L117):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b6cb18418767f2",
"metadata": {},
"outputs": [],
"source": [
"# (Optional) Change the loss objective to also take consideration of in-batch negatives:\n",
"recipe.model.config.in_batch_negatives = True"
]
},
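{
"cell_type": "markdown",
"id": "ranking-loss-sketch",
"metadata": {},
"source": [
"Conceptually, the ranking objective is a softmax cross-entropy over similarity scores between each query and its candidate documents, where the positive document should receive the highest score. The snippet below is a simplified, self-contained sketch of this idea; it is not NeMo's actual implementation, and details such as the temperature value and normalization may differ:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def ranking_loss(q, pos, neg, in_batch_negatives=False, temperature=0.02):\n",
"    # q:   [B, D] query embeddings\n",
"    # pos: [B, D] positive document embeddings\n",
"    # neg: [B, N, D] hard-negative document embeddings\n",
"    q, pos, neg = F.normalize(q, dim=-1), F.normalize(pos, dim=-1), F.normalize(neg, dim=-1)\n",
"    pos_scores = (q * pos).sum(-1, keepdim=True)           # [B, 1]\n",
"    neg_scores = torch.einsum('bd,bnd->bn', q, neg)        # [B, N]\n",
"    scores = torch.cat([pos_scores, neg_scores], dim=-1)   # [B, 1 + N]\n",
"    if in_batch_negatives:\n",
"        # Treat the other queries' positives in the batch as extra negatives.\n",
"        in_batch = q @ pos.T                                # [B, B]\n",
"        in_batch = in_batch.masked_fill(torch.eye(q.size(0), dtype=torch.bool), float('-inf'))\n",
"        scores = torch.cat([scores, in_batch], dim=-1)\n",
"    # The correct (positive) document sits at index 0 for every query.\n",
"    labels = torch.zeros(q.size(0), dtype=torch.long)\n",
"    return F.cross_entropy(scores / temperature, labels)\n",
"```"
]
},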
{
"cell_type": "markdown",
"id": "38765b20d36b550a",
"metadata": {},
"source": [
"For this particular task of finetuning an embedding model, we want to adjust the recipe with smaller learning rate, and this can be achieved by:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12f3aa89e02a3dad",
"metadata": {},
"outputs": [],
"source": [
"# (Optional) Modify the LR scheduler\n",
"recipe.optim.config.lr = 5e-6\n",
"recipe.optim.lr_scheduler.min_lr = 5e-7"
]
},
{
"cell_type": "markdown",
"id": "f0a328e84a81f848",
"metadata": {},
"source": [
"Optionally, you can also change the parallelism settings, typically for the small 1B embedding model, tensor parallelism/ pipeline parallelism/ context parallelism are not required to hold the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "249c19a39a05b914",
"metadata": {},
"outputs": [],
"source": [
"# (Optional) Modify the TP/PP/CP settings\n",
"# For small model (that fits well in single GPU), using TP1PP1CP1 typically yields the best performance.\n",
"recipe.trainer.strategy.tensor_model_parallel_size = 1\n",
"recipe.trainer.strategy.pipeline_model_parallel_size = 1\n",
"recipe.trainer.strategy.context_parallel_size = 1"
]
},
{
"cell_type": "markdown",
"id": "d0ddb50cf59d4091",
"metadata": {},
"source": [
"You can also configure trainer-related setting such as decreasing the default training steps:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65806f6c8372b750",
"metadata": {},
"outputs": [],
"source": [
"# For the purpose of the demo, we limit the max_steps to 100\n",
"recipe.trainer.max_steps = 100\n",
"recipe.trainer.val_check_interval = 10\n",
"recipe.trainer.limit_val_batches = 5\n",
"recipe.log.use_datetime_version = False"
]
},
{
"cell_type": "markdown",
"id": "4cc642814056a1ce",
"metadata": {},
"source": [
"Datamodule is swappable as well. The default `recipe` uses `SpecterDataModule`, if you want to use your customized dataset as described above, replace the data module as followed:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b170f3d761809d43",
"metadata": {},
"outputs": [],
"source": [
"def get_customize_dataloader(\n",
" data_root='path/to/json/data',\n",
" dataset_identifier='id_to_store_dataset',\n",
" seq_length=512,\n",
" micro_batch_size=16,\n",
" global_batch_size=64,\n",
" tokenizer=None,\n",
" num_workers=16,\n",
"):\n",
" return run.Config(\n",
" llm.CustomRetrievalDataModule,\n",
" data_root=data_root,\n",
" dataset_identifier=dataset_identifier,\n",
" seq_length=seq_length,\n",
" micro_batch_size=micro_batch_size,\n",
" global_batch_size=global_batch_size,\n",
" tokenizer=tokenizer,\n",
" num_workers=num_workers,\n",
" )\n",
"# To override the datamodule, uncomment below\n",
"# recipe.data = dataloader"
]
},
{
"cell_type": "markdown",
"id": "674f74e08d38a532",
"metadata": {},
"source": [
"After configure the training procedure properly, we can run the training by instantiate the `executor` and use `nemorun` to start the training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10dc6d15a6635040",
"metadata": {},
"outputs": [],
"source": [
"executor = local_executor_torchrun(devices=recipe.trainer.devices)\n",
"run.run(recipe, executor=executor)"
]
},
{
"cell_type": "markdown",
"id": "da4d579461e81304",
"metadata": {},
"source": [
"# Step 4: Evaluate the Llama Embedding Model\n",
"\n",
"After successfully training a checkpoint, we should evaluate the effectiveness of the trained model. This can be achieved by the following script using NeMo Framework for inference:\n",
"\n",
"(The script assumes the trained checkpoint is saved as `nemo_experiments/llama32_1b_embedding_finetuning/checkpoints/model_name=0--val_loss=0.47-step=99-consumed_samples=6400.0-last`)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "147f98459257601",
"metadata": {},
"outputs": [],
"source": [
"%%writefile nemo_inference.py\n",
"from pathlib import Path\n",
"\n",
"import torch\n",
"import torch.distributed\n",
"\n",
"import nemo.lightning as nl\n",
"from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model\n",
"from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer\n",
"from nemo.lightning import io\n",
"from nemo.utils.exp_manager import TimingCallback\n",
"from nemo.lightning.ckpt_utils import ckpt_to_context_subdir\n",
"\n",
"if __name__ == \"__main__\":\n",
" model_nemo_path = 'nemo_experiments/llama32_1b_embedding_finetuning/checkpoints/model_name=0--val_loss=0.47-step=99-consumed_samples=6400.0-last'\n",
"\n",
" # Setup Trainer\n",
" strategy = nl.MegatronStrategy(\n",
" tensor_model_parallel_size=1,\n",
" pipeline_model_parallel_size=1,\n",
" context_parallel_size=1,\n",
" sequence_parallel=False,\n",
" setup_optimizers=False,\n",
" store_optimizer_states=False,\n",
" )\n",
"\n",
" trainer = nl.Trainer(\n",
" accelerator=\"gpu\",\n",
" devices=1,\n",
" num_nodes=1,\n",
" strategy=strategy,\n",
" plugins=nl.MegatronMixedPrecision(\n",
" precision=\"32\",\n",
" params_dtype=torch.float32,\n",
" pipeline_dtype=torch.bfloat16,\n",
" autocast_enabled=False,\n",
" grad_reduce_in_fp32=False,\n",
" ),\n",
" callbacks=[TimingCallback()]\n",
" )\n",
" queries = ['how much protein should a female eat',\n",
" 'summit define'\n",
" ]\n",
" passages = [\n",
" \"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.\",\n",
" \"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.\"\n",
" ]\n",
"\n",
" # Each input text should start with \"query: \" or \"passage: \".\n",
" # For tasks other than retrieval, you can simply use the \"query: \" prefix.\n",
" prefix_query = \"query: \"\n",
" prefix_passage = \"passage: \"\n",
" queries = [f\"{prefix_query} {query}\" for query in queries]\n",
" passages = [f\"{prefix_passage} {passage}\" for passage in passages]\n",
"\n",
" # Resume Model\n",
" path = Path(model_nemo_path)\n",
" model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath=\"model\")\n",
" _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)\n",
" tokenizer = get_nmt_tokenizer(library='huggingface', model_name=str(ckpt_to_context_subdir(path) / 'nemo_tokenizer'))\n",
"\n",
" def preprocess_input(text_inputs, add_bos=False, add_eos=False):\n",
" # construct input\n",
" bos = tokenizer.bos\n",
" eos = tokenizer.eos\n",
" pad = tokenizer.pad_id if tokenizer.pad_id else tokenizer.eos\n",
" encoded_input = [tokenizer.text_to_ids(s) for s in text_inputs]\n",
" if add_bos:\n",
" encoded_input = [[bos] + s for s in encoded_input]\n",
" if add_eos:\n",
" encoded_input = [ s + [eos] for s in encoded_input]\n",
" max_length = max([len(s) for s in encoded_input])\n",
" # pad to max_length, and mask\n",
" attention_mask = torch.stack([torch.cat((torch.ones(len(s), dtype=torch.long), torch.zeros(max_length - len(s), dtype=torch.long))) for s in encoded_input])\n",
" encoded_input = torch.Tensor([s + [pad] * (max_length - len(s)) for s in encoded_input]).to(dtype=torch.long)\n",
" position_ids = torch.arange(0, max_length, dtype=torch.long, device=encoded_input.device)\n",
" inputs = {\n",
" \"input_ids\": encoded_input.to(device=model.device),\n",
" \"attention_mask\": attention_mask.to(device=model.device),\n",
" \"position_ids\": position_ids.to(device=model.device),\n",
" }\n",
" return inputs\n",
" print('model setup correctly')\n",
" model.eval()\n",
" with torch.no_grad():\n",
" embeddings_queries = model.encode(**preprocess_input(queries))\n",
" embeddings_passages = model.encode(**preprocess_input(passages))\n",
"\n",
" scores = (embeddings_queries @ embeddings_passages.T)\n",
" print(scores.tolist())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2233bbd-b9c3-424f-993a-a4e24e3311f8",
"metadata": {},
"outputs": [],
"source": [
"!torchrun nemo_inference.py"
]
},
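{
"cell_type": "markdown",
"id": "inference-output-note",
"metadata": {},
"source": [
"If finetuning worked as expected, each query should score highest against its corresponding passage, i.e., the diagonal entries of the printed 2x2 score matrix should be noticeably larger than the off-diagonal ones."
]
},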
{
"cell_type": "markdown",
"id": "7cde40c4e2c6e9ea",
"metadata": {},
"source": [
"# Step 5: (Optional) Convert the Model to HuggingFace Transformer format\n",
"\n",
"Use `llm.export_ckpt` to automatically convert the finetuned model into HuggingFace model. Similar to converting Hugging Face model to NeMo, you will need to run this in Python script as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cee4c3f148c52c6e",
"metadata": {},
"outputs": [],
"source": [
"%%writefile convert_to_hf.py\n",
"from pathlib import Path\n",
"from nemo.collections import llm\n",
"if __name__ == '__main__':\n",
" llm.export_ckpt(\n",
" path = Path('nemo_experiments/llama32_1b_embedding_finetuning/checkpoints/model_name=0--val_loss=0.47-step=99-consumed_samples=6400.0-last'),\n",
" target='hf',\n",
" output_path=Path('path/to/converted/hf/ckpt'),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4f29e180e22ebef",
"metadata": {},
"outputs": [],
"source": [
"!torchrun convert_to_hf.py"
]
}
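,
{
"cell_type": "markdown",
"id": "hf-usage-sketch",
"metadata": {},
"source": [
"Once the conversion has finished, the exported checkpoint can be used like any other Hugging Face model. The snippet below is only a minimal sketch: it assumes the exported folder loads with `AutoModel`/`AutoTokenizer` (pass `trust_remote_code=True` in case the export ships custom modeling code) and uses simple mean pooling over the last hidden state; check the exported config for the pooling and prefix conventions that match your training setup.\n",
"\n",
"```python\n",
"import torch\n",
"from transformers import AutoModel, AutoTokenizer\n",
"\n",
"ckpt = 'path/to/converted/hf/ckpt'  # the output_path used in convert_to_hf.py\n",
"tokenizer = AutoTokenizer.from_pretrained(ckpt)\n",
"model = AutoModel.from_pretrained(ckpt, trust_remote_code=True).eval()\n",
"\n",
"texts = ['query: how much protein should a female eat',\n",
"         'passage: As a general guideline, the CDC recommends 46 grams of protein per day for women ages 19 to 70.']\n",
"batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')\n",
"\n",
"with torch.no_grad():\n",
"    hidden = model(**batch).last_hidden_state            # [B, T, D]\n",
"    mask = batch['attention_mask'].unsqueeze(-1)          # [B, T, 1]\n",
"    emb = (hidden * mask).sum(1) / mask.sum(1)            # mean pooling over tokens\n",
"    emb = torch.nn.functional.normalize(emb, dim=-1)\n",
"\n",
"print(emb @ emb.T)  # cosine similarities between the two texts\n",
"```"
]
}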
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}