File size: 5,093 Bytes
6ecf14b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import torch\n",
    "import torch.multiprocessing as mp\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer, TrainingArguments\n",
    "from datasets import load_dataset\n",
    "from trl import SFTTrainer\n",
    "from peft import LoraConfig\n",
    "\n",
    "def init_distributed(rank, world_size):\n",
    "    \"\"\"Join the single-node NCCL process group as worker ``rank``.\n",
    "\n",
    "    Args:\n",
    "        rank: Index of this worker; also used as its CUDA device index.\n",
    "        world_size: Total number of workers in the group.\n",
    "    \"\"\"\n",
    "    # Rendezvous over the loopback interface: all workers run on this machine.\n",
    "    os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
    "    os.environ[\"MASTER_PORT\"] = \"12345\"\n",
    "    if rank == 0:\n",
    "        print(\"Initializing distributed process group...\")\n",
    "    # Bind this process to its own GPU before NCCL starts; otherwise every\n",
    "    # rank's default CUDA device is GPU 0.\n",
    "    torch.cuda.set_device(rank)\n",
    "    torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank)\n",
    "\n",
    "def cleanup_distributed():\n",
    "    \"\"\"Destroy the default process group if one is active.\n",
    "\n",
    "    Does nothing when no group was initialized (for example because\n",
    "    initialization failed), so it is safe to call from cleanup paths.\n",
    "    \"\"\"\n",
    "    if torch.distributed.is_available() and torch.distributed.is_initialized():\n",
    "        torch.distributed.destroy_process_group()\n",
    "\n",
    "def main_worker(rank, world_size):\n",
    "    \"\"\"Run QLoRA fine-tuning of Llama-3-8B-Instruct as one worker of a single-node job.\n",
    "\n",
    "    Args:\n",
    "        rank: Index of this worker; also used as its CUDA device index.\n",
    "        world_size: Total number of worker processes.\n",
    "\n",
    "    The process group is torn down on exit, even when training raises.\n",
    "    \"\"\"\n",
    "    # Transformers/Accelerate detect a distributed launch from these\n",
    "    # torchrun-style variables; without them each process would train on its\n",
    "    # own as if it were a single-process job.\n",
    "    os.environ[\"RANK\"] = str(rank)\n",
    "    os.environ[\"LOCAL_RANK\"] = str(rank)\n",
    "    os.environ[\"WORLD_SIZE\"] = str(world_size)\n",
    "\n",
    "    init_distributed(rank, world_size)\n",
    "    try:\n",
    "        # Load the dataset\n",
    "        dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
    "        dataset = load_dataset(dataset_name, split=\"train\")\n",
    "        # Keep only the first 100 rows (smoke-test size)\n",
    "        dataset = dataset.select(range(min(100, len(dataset))))\n",
    "        # Load the model + tokenizer\n",
    "        model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "        # Reuse EOS as the padding token so batches can be padded.\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "        bnb_config = BitsAndBytesConfig(\n",
    "            load_in_4bit=True,\n",
    "            bnb_4bit_quant_type=\"nf4\",\n",
    "            bnb_4bit_compute_dtype=torch.float16,\n",
    "        )\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name,\n",
    "            quantization_config=bnb_config,\n",
    "            trust_remote_code=True,\n",
    "            use_cache=False,\n",
    "            # Place the quantized weights directly on this worker's GPU;\n",
    "            # moving a bitsandbytes model with .to() afterwards is not reliable.\n",
    "            device_map={\"\": rank},\n",
    "        )\n",
    "        # No manual DDP wrapping here: SFTTrainer needs the bare model to attach\n",
    "        # the LoRA adapters, and the Trainer wraps it for distributed training.\n",
    "\n",
    "        # PEFT config\n",
    "        lora_alpha = 16\n",
    "        lora_dropout = 0.1\n",
    "        lora_r = 32  # 64\n",
    "        peft_config = LoraConfig(\n",
    "            lora_alpha=lora_alpha,\n",
    "            lora_dropout=lora_dropout,\n",
    "            r=lora_r,\n",
    "            bias=\"none\",\n",
    "            task_type=\"CAUSAL_LM\",\n",
    "            target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
    "            modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
    "        )\n",
    "        # Args\n",
    "        max_seq_length = 512\n",
    "        output_dir = \"./results\"\n",
    "        per_device_train_batch_size = 2  # reduced batch size to avoid OOM\n",
    "        gradient_accumulation_steps = 2\n",
    "        optim = \"adamw_torch\"\n",
    "        save_steps = 10\n",
    "        logging_steps = 1\n",
    "        learning_rate = 2e-4\n",
    "        max_grad_norm = 0.3\n",
    "        max_steps = 1  # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
    "        warmup_ratio = 0.1\n",
    "        lr_scheduler_type = \"cosine\"\n",
    "\n",
    "        training_arguments = TrainingArguments(\n",
    "            output_dir=output_dir,\n",
    "            per_device_train_batch_size=per_device_train_batch_size,\n",
    "            gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "            optim=optim,\n",
    "            save_steps=save_steps,\n",
    "            logging_steps=logging_steps,\n",
    "            learning_rate=learning_rate,\n",
    "            fp16=True,\n",
    "            max_grad_norm=max_grad_norm,\n",
    "            max_steps=max_steps,\n",
    "            warmup_ratio=warmup_ratio,\n",
    "            group_by_length=True,\n",
    "            lr_scheduler_type=lr_scheduler_type,\n",
    "            gradient_checkpointing=True,  # gradient checkpointing\n",
    "            # NOTE(review): gradient checkpointing + LoRA under DDP typically\n",
    "            # fails with unused-parameter detection enabled - confirm on your versions.\n",
    "            ddp_find_unused_parameters=False,\n",
    "            report_to=\"wandb\",\n",
    "        )\n",
    "        # Trainer\n",
    "        trainer = SFTTrainer(\n",
    "            model=model,\n",
    "            train_dataset=dataset,\n",
    "            peft_config=peft_config,\n",
    "            dataset_text_field=\"context\",\n",
    "            max_seq_length=max_seq_length,\n",
    "            tokenizer=tokenizer,\n",
    "            args=training_arguments,\n",
    "        )\n",
    "        # Train :)\n",
    "        trainer.train()\n",
    "    finally:\n",
    "        # Always release the process group, even if loading or training failed.\n",
    "        cleanup_distributed()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    world_size = torch.cuda.device_count()\n",
    "    if world_size == 0:\n",
    "        raise RuntimeError(\"No CUDA devices found; NCCL training needs at least one GPU.\")\n",
    "\n",
    "    # One worker process per visible GPU; the rank doubles as the device index.\n",
    "    processes = []\n",
    "    for rank in range(world_size):\n",
    "        p = mp.Process(target=main_worker, args=(rank, world_size))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()\n",
    "\n",
    "    # Surface worker crashes instead of finishing as if training succeeded.\n",
    "    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]\n",
    "    if failed:\n",
    "        raise RuntimeError(f\"Worker(s) failed (rank, exit code): {failed}\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}