{ "cells": [ { "cell_type": "markdown", "id": "8a904d03-8c1a-4eb4-bd77-a26ec985d992", "metadata": {}, "source": [ "## Load the Dataset" ] }, { "cell_type": "code", "execution_count": 1, "id": "f398bea9-d201-43ae-95e8-d4fad1de651c", "metadata": {}, "outputs": [], "source": [ "# Opening modified dataset uploaded to HF\n", "from datasets import load_dataset\n", "\n", "dataset_name = \"RaviNaik/oasst1-chatml\"\n", "dataset = load_dataset(dataset_name, split=\"train\")" ] }, { "cell_type": "markdown", "id": "b74449fb-9f01-4c96-8a7e-6661180be1f3", "metadata": {}, "source": [ "## Load the Model, Tokenizer and configure bnb" ] }, { "cell_type": "code", "execution_count": 2, "id": "ab5a26da-b9e0-424f-b482-034c10f049ce", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig" ] }, { "cell_type": "code", "execution_count": 3, "id": "8f5d677d-082a-4d96-9fab-4d66981ecc27", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "88aa19589a2b46f9888c214db3867598", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_name = \"microsoft/phi-2\"\n", "\n", "# 4-bit NF4 weight quantization with fp16 compute dtype.\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=torch.float16,\n", ")\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " quantization_config=bnb_config,\n", " trust_remote_code=True,\n", " device_map=\"cuda:0\"\n", ")\n", "model.config.use_cache = False" ] }, { "cell_type": "code", "execution_count": 4, "id": "0d649bc1-69d0-4683-b3e7-b999f4a52ce7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ]
} ], "source": [ "# device_map applies only to models; tokenizers have no device placement.\n", "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", "# Reuse EOS as the padding token.\n", "# NOTE(review): the LM data collator masks pad-token labels, so with pad == eos the\n", "# EOS labels are masked too; confirm the fine-tuned model still learns to stop.\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "markdown", "id": "2304e9e1-fe0a-48c9-8eaa-605af8d93ea1", "metadata": {}, "source": [ "## Display Model Layers" ] }, { "cell_type": "code", "execution_count": 5, "id": "0d572c85-fc5b-407d-bdd9-667c3cdb74bd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PhiForCausalLM(\n", " (transformer): PhiModel(\n", " (embd): Embedding(\n", " (wte): Embedding(51200, 2560)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (h): ModuleList(\n", " (0-31): 32 x ParallelBlock(\n", " (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n", " (resid_dropout): Dropout(p=0.1, inplace=False)\n", " (mixer): MHA(\n", " (rotary_emb): RotaryEmbedding()\n", " (Wqkv): Linear4bit(in_features=2560, out_features=7680, bias=True)\n", " (out_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)\n", " (inner_attn): SelfAttention(\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (inner_cross_attn): CrossAttention(\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (mlp): MLP(\n", " (fc1): Linear4bit(in_features=2560, out_features=10240, bias=True)\n", " (fc2): Linear4bit(in_features=10240, out_features=2560, bias=True)\n", " (act): NewGELUActivation()\n", " )\n", " )\n", " )\n", " )\n", " (lm_head): CausalLMHead(\n", " (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n", " (linear): Linear(in_features=2560, out_features=51200, bias=True)\n", " )\n", " (loss): CausalLMLoss(\n", " (loss_fct): CrossEntropyLoss()\n", " )\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "id": "7d628159-3f0e-43eb-bf7b-12129f54e0df", "metadata": {}, "source": [ "## Configure LoRA for finetuning" ] }, { "cell_type": "code", "execution_count": 6, "id": 
"45d0e0e8-a8d8-4a05-9c7d-4f6217b57310", "metadata": {}, "outputs": [], "source": [ "from peft import LoraConfig" ] }, { "cell_type": "code", "execution_count": 7, "id": "2cb2780d-609d-4fba-902d-a2526863b02c", "metadata": {}, "outputs": [], "source": [ "lora_alpha = 16\n", "lora_dropout = 0.1\n", "lora_r = 64\n", "\n", "peft_config = LoraConfig(\n", " lora_alpha=lora_alpha,\n", " lora_dropout=lora_dropout,\n", " r=lora_r,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " target_modules=[\n", " \"Wqkv\",\n", " \"out_proj\",\n", " \"fc1\",\n", " \"fc2\",\n", " ]\n", ")" ] }, { "cell_type": "markdown", "id": "f1ed1b63-2263-4574-9497-180acd38ec14", "metadata": {}, "source": [ "## Configure Training Params" ] }, { "cell_type": "code", "execution_count": 8, "id": "f5074cda-91b9-4077-9d9f-b74d34778767", "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments" ] }, { "cell_type": "code", "execution_count": 9, "id": "a793135a-a797-42b1-935d-fb969772a91f", "metadata": {}, "outputs": [], "source": [ "output_dir = \"./results\"\n", "per_device_train_batch_size = 4\n", "gradient_accumulation_steps = 4\n", "optim = \"paged_adamw_32bit\"\n", "save_steps = 100\n", "logging_steps = 10\n", "learning_rate = 2e-4\n", "max_grad_norm = 0.3\n", "max_steps = 500\n", "warmup_ratio = 0.03\n", "# The plain \"constant\" scheduler ignores warmup settings entirely; use the\n", "# warmup variant so warmup_ratio actually takes effect.\n", "lr_scheduler_type = \"constant_with_warmup\"\n", "\n", "training_arguments = TrainingArguments(\n", " output_dir=output_dir,\n", " per_device_train_batch_size=per_device_train_batch_size,\n", " gradient_accumulation_steps=gradient_accumulation_steps,\n", " optim=optim,\n", " save_steps=save_steps,\n", " logging_steps=logging_steps,\n", " learning_rate=learning_rate,\n", " fp16=True,\n", " max_grad_norm=max_grad_norm,\n", " max_steps=max_steps,\n", " warmup_ratio=warmup_ratio,\n", " group_by_length=True,\n", " lr_scheduler_type=lr_scheduler_type,\n", " gradient_checkpointing=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": 
"606d810a-8eaa-42df-9804-9f6bd421dee6", "metadata": {}, "outputs": [], "source": [ "from trl import SFTTrainer" ] }, { "cell_type": "code", "execution_count": 11, "id": "d554e892-5966-4cb3-9ef4-ede9a14a50e6", "metadata": {}, "outputs": [], "source": [ "max_seq_length = 256\n", "\n", "trainer = SFTTrainer(\n", " model=model,\n", " train_dataset=dataset,\n", " peft_config=peft_config,\n", " dataset_text_field=\"text\",\n", " max_seq_length=max_seq_length,\n", " tokenizer=tokenizer,\n", " args=training_arguments,\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "id": "129c1914-447d-4e95-9346-fe9eb892eb12", "metadata": {}, "outputs": [], "source": [ "# Upcast the LayerNorm modules to fp32 for numerical stability.\n", "# Phi-2 names its norms `ln`, so filtering on \"norm\" in the module name\n", "# matches nothing; select by module type instead.\n", "# nn.Module.to() converts parameters in place, so no reassignment is needed.\n", "for module in trainer.model.modules():\n", "    if isinstance(module, torch.nn.LayerNorm):\n", "        module.to(torch.float32)" ] }, { "cell_type": "markdown", "id": "a9f421a1-073d-4042-a092-650d7f6d27ec", "metadata": {}, "source": [ "## Begin Training" ] }, { "cell_type": "code", "execution_count": 13, "id": "ca7336ee-fc7d-479e-a384-24737ab74007", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a CodeGenTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", "/home/ravi.naik/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "
---|---|
10 | \n", "1.720800 | \n", "
20 | \n", "1.548300 | \n", "
30 | \n", "1.550700 | \n", "
40 | \n", "1.543500 | \n", "
50 | \n", "1.506600 | \n", "
60 | \n", "1.501400 | \n", "
70 | \n", "1.559600 | \n", "
80 | \n", "1.552900 | \n", "
90 | \n", "1.506300 | \n", "
100 | \n", "1.471700 | \n", "
110 | \n", "1.484100 | \n", "
120 | \n", "1.498100 | \n", "
130 | \n", "1.517600 | \n", "
140 | \n", "1.481100 | \n", "
150 | \n", "1.485100 | \n", "
160 | \n", "1.516800 | \n", "
170 | \n", "1.505000 | \n", "
180 | \n", "1.484600 | \n", "
190 | \n", "1.495200 | \n", "
200 | \n", "1.447100 | \n", "
210 | \n", "1.520700 | \n", "
220 | \n", "1.444500 | \n", "
230 | \n", "1.464800 | \n", "
240 | \n", "1.480200 | \n", "
250 | \n", "1.444100 | \n", "
260 | \n", "1.543900 | \n", "
270 | \n", "1.512700 | \n", "
280 | \n", "1.441300 | \n", "
290 | \n", "1.502200 | \n", "
300 | \n", "1.476900 | \n", "
310 | \n", "1.478200 | \n", "
320 | \n", "1.481000 | \n", "
330 | \n", "1.433600 | \n", "
340 | \n", "1.404000 | \n", "
350 | \n", "1.401000 | \n", "
360 | \n", "1.424400 | \n", "
370 | \n", "1.429100 | \n", "
380 | \n", "1.388700 | \n", "
390 | \n", "1.402600 | \n", "
400 | \n", "1.417900 | \n", "
410 | \n", "1.358200 | \n", "
420 | \n", "1.460700 | \n", "
430 | \n", "1.417800 | \n", "
440 | \n", "1.447300 | \n", "
450 | \n", "1.429200 | \n", "
460 | \n", "1.388100 | \n", "
470 | \n", "1.433200 | \n", "
480 | \n", "1.431600 | \n", "
490 | \n", "1.491200 | \n", "
500 | \n", "1.406600 | \n", "
"
],
"text/plain": [
"