{"cells":[{"cell_type":"markdown","id":"7a849d13","metadata":{"papermill":{"duration":0.007288,"end_time":"2023-11-03T19:09:30.155119","exception":false,"start_time":"2023-11-03T19:09:30.147831","status":"completed"},"tags":[],"id":"7a849d13"},"source":["# Entrainement inspiré par mosaicml/mpt-7b-instruct"]},{"cell_type":"markdown","id":"e00d5549","metadata":{"papermill":{"duration":0.005709,"end_time":"2023-11-03T19:09:30.167758","exception":false,"start_time":"2023-11-03T19:09:30.162049","status":"completed"},"tags":[],"id":"e00d5549"},"source":["## Installation des librairies manquantes"]},{"cell_type":"code","execution_count":null,"id":"ac5f5285","metadata":{"papermill":{"duration":72.735112,"end_time":"2023-11-03T19:10:42.907773","exception":false,"start_time":"2023-11-03T19:09:30.172661","status":"completed"},"tags":[],"id":"ac5f5285"},"outputs":[],"source":["! pip install bitsandbytes\n","! pip install einops\n","! pip install peft\n","! pip install trl\n","\n","# Bug selon la version de datasets, besoin d'installer une version plus récente que celle de l'environnement pré-installé :\n","! pip uninstall datasets -y\n","! pip install datasets==2.13.1\n","\n","import datasets\n","datasets.__version__"]},{"cell_type":"markdown","id":"55c58670","metadata":{"papermill":{"duration":0.011597,"end_time":"2023-11-03T19:10:42.931585","exception":false,"start_time":"2023-11-03T19:10:42.919988","status":"completed"},"tags":[],"id":"55c58670"},"source":["## Import des librairies"]},{"cell_type":"code","execution_count":null,"id":"46d628f9","metadata":{"papermill":{"duration":24.437257,"end_time":"2023-11-03T19:11:07.380970","exception":false,"start_time":"2023-11-03T19:10:42.943713","status":"completed"},"tags":[],"id":"46d628f9"},"outputs":[],"source":["import einops\n","import torch\n","from datasets import load_dataset\n","from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments\n","from peft import LoraConfig\n","from trl import SFTTrainer"]},{"cell_type":"markdown","id":"acc377fc","metadata":{"papermill":{"duration":0.011928,"end_time":"2023-11-03T19:11:07.405379","exception":false,"start_time":"2023-11-03T19:11:07.393451","status":"completed"},"tags":[],"id":"acc377fc"},"source":["## Téléchargement du dataset pour le fine tuning"]},{"cell_type":"code","execution_count":null,"id":"66d57569","metadata":{"id":"66d57569","papermill":{"duration":2.120443,"end_time":"2023-11-03T19:11:09.537837","exception":false,"start_time":"2023-11-03T19:11:07.417394","status":"completed"},"tags":[]},"outputs":[],"source":["\n","dataset_name = \"Laurent1/MedQuad-MedicalQnADataset_128tokens_max\"\n","# On fine tune les 5000 premieres questions sinon c'est un peu long...\n","dataset = load_dataset(dataset_name, split='train[:5120]')\n","dataset\n"]},{"cell_type":"markdown","id":"da64af41","metadata":{"id":"da64af41","papermill":{"duration":0.013199,"end_time":"2023-11-03T19:11:09.564367","exception":false,"start_time":"2023-11-03T19:11:09.551168","status":"completed"},"tags":[]},"source":["## Téléchargement du model pre-entrainé et de son tokenizer"]},{"cell_type":"code","execution_count":null,"id":"f0d34d51","metadata":{"id":"f0d34d51","papermill":{"duration":111.325421,"end_time":"2023-11-03T19:13:00.902982","exception":false,"start_time":"2023-11-03T19:11:09.577561","status":"completed"},"tags":[]},"outputs":[],"source":["model_name = \"ibm/mpt-7b-instruct2\"\n","\n","# BitsAndBytes permet le fine tuning avec \"quantification\" pour réduire l'impact mémoire et les 
calculs\n","bnb_config = BitsAndBytesConfig(\n"," load_in_4bit=True,\n"," bnb_4bit_quant_type=\"nf4\",\n"," bnb_4bit_compute_dtype=torch.float16,\n",")\n","\n","model = AutoModelForCausalLM.from_pretrained(\n"," \"ibm/mpt-7b-instruct2\",\n"," device_map=\"auto\",\n"," torch_dtype=torch.float16, #torch.bfloat16,\n"," trust_remote_code=True\n"," )\n","\n","tokenizer = AutoTokenizer.from_pretrained(model_name)\n","tokenizer.pad_token = tokenizer.eos_token\n","tokenizer.padding_side = \"right\""]},{"cell_type":"markdown","id":"210f91f0","metadata":{"id":"210f91f0","papermill":{"duration":0.017635,"end_time":"2023-11-03T19:13:00.938752","exception":false,"start_time":"2023-11-03T19:13:00.921117","status":"completed"},"tags":[]},"source":["## Configuration du peft LoRa"]},{"cell_type":"code","execution_count":null,"id":"f319e182","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:00.975007Z","iopub.status.busy":"2023-11-03T19:13:00.974381Z","iopub.status.idle":"2023-11-03T19:13:00.979872Z","shell.execute_reply":"2023-11-03T19:13:00.978679Z"},"id":"f319e182","papermill":{"duration":0.026255,"end_time":"2023-11-03T19:13:00.982184","exception":false,"start_time":"2023-11-03T19:13:00.955929","status":"completed"},"tags":[]},"outputs":[],"source":["lora_alpha = 16\n","lora_dropout = 0.1\n","lora_r = 32\n","\n","peft_config = LoraConfig(\n"," lora_alpha=lora_alpha,\n"," lora_dropout=lora_dropout,\n"," r=lora_r,\n"," bias=\"none\",\n"," task_type=\"CAUSAL_LM\",\n"," target_modules=[\n"," \"Wqkv\",\n"," \"out_proj\",\n"," \"up_proj\",\n"," \"down_proj\",\n"," ]\n",")"]},{"cell_type":"markdown","id":"2865fb08","metadata":{"id":"2865fb08","papermill":{"duration":0.017888,"end_time":"2023-11-03T19:13:01.017800","exception":false,"start_time":"2023-11-03T19:13:00.999912","status":"completed"},"tags":[]},"source":["## Préparation de l'entraineur (Supervised Fine-tuning Trainer)"]},{"cell_type":"markdown","id":"646bb0f6","metadata":{"id":"646bb0f6","papermill":{"duration":0.017992,"end_time":"2023-11-03T19:13:01.053796","exception":false,"start_time":"2023-11-03T19:13:01.035804","status":"completed"},"tags":[]},"source":["Utilisation de [`SFTTrainer` de la librairie TRL](https://huggingface.co/docs/trl/main/en/sft_trainer) qui est un wrapper de Trainer facilite le fine tuning avec LoRa"]},{"cell_type":"code","execution_count":null,"id":"a09bf573","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:01.091811Z","iopub.status.busy":"2023-11-03T19:13:01.091529Z","iopub.status.idle":"2023-11-03T19:13:01.100031Z","shell.execute_reply":"2023-11-03T19:13:01.099221Z"},"papermill":{"duration":0.029762,"end_time":"2023-11-03T19:13:01.101982","exception":false,"start_time":"2023-11-03T19:13:01.072220","status":"completed"},"tags":[],"id":"a09bf573"},"outputs":[],"source":["output_dir = \"/YOUR DIRECTORY\"\n","per_device_train_batch_size = 1\n","gradient_accumulation_steps = 16\n","optim = \"paged_adamw_32bit\"\n","save_steps = 64\n","logging_steps = 64\n","learning_rate = 1e-4\n","max_grad_norm = 0.3\n","max_steps = 1600\n","warmup_ratio = 0.03\n","lr_scheduler_type = \"linear\"\n","\n","training_arguments = TrainingArguments(\n"," output_dir=output_dir,\n"," per_device_train_batch_size=per_device_train_batch_size,\n"," gradient_accumulation_steps=gradient_accumulation_steps,\n"," optim=optim,\n"," save_steps=save_steps,\n"," logging_steps=logging_steps,\n"," learning_rate=learning_rate,\n"," fp16=True,\n"," max_grad_norm=max_grad_norm,\n"," max_steps=max_steps,\n"," 
{"cell_type":"markdown","id":"2865fb08","metadata":{"id":"2865fb08","papermill":{"duration":0.017888,"end_time":"2023-11-03T19:13:01.017800","exception":false,"start_time":"2023-11-03T19:13:00.999912","status":"completed"},"tags":[]},"source":["## Preparing the trainer (Supervised Fine-tuning Trainer)"]},{"cell_type":"markdown","id":"646bb0f6","metadata":{"id":"646bb0f6","papermill":{"duration":0.017992,"end_time":"2023-11-03T19:13:01.053796","exception":false,"start_time":"2023-11-03T19:13:01.035804","status":"completed"},"tags":[]},"source":["We use [`SFTTrainer` from the TRL library](https://huggingface.co/docs/trl/main/en/sft_trainer), a wrapper around `Trainer` that makes fine-tuning with LoRA easier."]},{"cell_type":"code","execution_count":null,"id":"a09bf573","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:01.091811Z","iopub.status.busy":"2023-11-03T19:13:01.091529Z","iopub.status.idle":"2023-11-03T19:13:01.100031Z","shell.execute_reply":"2023-11-03T19:13:01.099221Z"},"papermill":{"duration":0.029762,"end_time":"2023-11-03T19:13:01.101982","exception":false,"start_time":"2023-11-03T19:13:01.072220","status":"completed"},"tags":[],"id":"a09bf573"},"outputs":[],"source":["output_dir = \"/YOUR DIRECTORY\"\n","per_device_train_batch_size = 1\n","gradient_accumulation_steps = 16\n","optim = \"paged_adamw_32bit\"\n","save_steps = 64\n","logging_steps = 64\n","learning_rate = 1e-4\n","max_grad_norm = 0.3\n","max_steps = 1600\n","warmup_ratio = 0.03\n","lr_scheduler_type = \"linear\"\n","\n","training_arguments = TrainingArguments(\n","    output_dir=output_dir,\n","    per_device_train_batch_size=per_device_train_batch_size,\n","    gradient_accumulation_steps=gradient_accumulation_steps,\n","    optim=optim,\n","    save_steps=save_steps,\n","    logging_steps=logging_steps,\n","    learning_rate=learning_rate,\n","    fp16=True,\n","    max_grad_norm=max_grad_norm,\n","    max_steps=max_steps,\n","    warmup_ratio=warmup_ratio,\n","    group_by_length=True,\n","    lr_scheduler_type=lr_scheduler_type,\n","    report_to='none',\n","    save_total_limit=1\n",")"]},{"cell_type":"code","execution_count":null,"id":"65c351f4","metadata":{"id":"65c351f4","papermill":{"duration":2.126939,"end_time":"2023-11-03T19:13:03.247048","exception":false,"start_time":"2023-11-03T19:13:01.120109","status":"completed"},"tags":[]},"outputs":[],"source":["trainer = SFTTrainer(\n","    model=model,\n","    train_dataset=dataset,\n","    peft_config=peft_config,\n","    dataset_text_field=\"text\",\n","    max_seq_length=512,\n","    tokenizer=tokenizer,\n","    args=training_arguments,\n",")"]},{"cell_type":"markdown","id":"f288484c","metadata":{"id":"f288484c","papermill":{"duration":0.018129,"end_time":"2023-11-03T19:13:03.284278","exception":false,"start_time":"2023-11-03T19:13:03.266149","status":"completed"},"tags":[]},"source":["## Training the model"]},{"cell_type":"code","execution_count":null,"id":"961d4e60","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:03.322815Z","iopub.status.busy":"2023-11-03T19:13:03.322506Z","iopub.status.idle":"2023-11-03T20:53:46.633189Z","shell.execute_reply":"2023-11-03T20:53:46.632279Z"},"id":"961d4e60","outputId":"4db61972-d1d9-4c43-d25a-a122da136bb7","papermill":{"duration":6043.332295,"end_time":"2023-11-03T20:53:46.635133","exception":false,"start_time":"2023-11-03T19:13:03.302838","status":"completed"},"tags":[]},"outputs":[{"name":"stderr","output_type":"stream","text":["You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"]},{"data":{"text/html":["\n","
Step | \n","Training Loss | \n","
---|---|
64 | \n","1.618400 | \n","
128 | \n","1.084200 | \n","
192 | \n","1.021800 | \n","
256 | \n","1.014300 | \n","
320 | \n","0.960500 | \n","
384 | \n","0.905900 | \n","
448 | \n","0.885200 | \n","
512 | \n","0.847400 | \n","
576 | \n","0.889400 | \n","
640 | \n","0.861000 | \n","
704 | \n","0.800400 | \n","
768 | \n","0.768600 | \n","
832 | \n","0.750300 | \n","
896 | \n","0.780200 | \n","
960 | \n","0.762700 | \n","
1024 | \n","0.698600 | \n","
1088 | \n","0.672600 | \n","
1152 | \n","0.693100 | \n","
1216 | \n","0.708900 | \n","
1280 | \n","0.662700 | \n","
1344 | \n","0.630400 | \n","
1408 | \n","0.624600 | \n","
1472 | \n","0.627200 | \n","
1536 | \n","0.628000 | \n","
1600 | \n","0.603300 | \n","
"],"text/plain":["