{"cells":[{"cell_type":"markdown","id":"7a849d13","metadata":{"papermill":{"duration":0.007288,"end_time":"2023-11-03T19:09:30.155119","exception":false,"start_time":"2023-11-03T19:09:30.147831","status":"completed"},"tags":[],"id":"7a849d13"},"source":["# Entrainement inspiré par mosaicml/mpt-7b-instruct"]},{"cell_type":"markdown","id":"e00d5549","metadata":{"papermill":{"duration":0.005709,"end_time":"2023-11-03T19:09:30.167758","exception":false,"start_time":"2023-11-03T19:09:30.162049","status":"completed"},"tags":[],"id":"e00d5549"},"source":["## Installation des librairies manquantes"]},{"cell_type":"code","execution_count":null,"id":"ac5f5285","metadata":{"papermill":{"duration":72.735112,"end_time":"2023-11-03T19:10:42.907773","exception":false,"start_time":"2023-11-03T19:09:30.172661","status":"completed"},"tags":[],"id":"ac5f5285"},"outputs":[],"source":["! pip install bitsandbytes\n","! pip install einops\n","! pip install peft\n","! pip install trl\n","\n","# Bug selon la version de datasets, besoin d'installer une version plus récente que celle de l'environnement pré-installé :\n","! pip uninstall datasets -y\n","! pip install datasets==2.13.1\n","\n","import datasets\n","datasets.__version__"]},{"cell_type":"markdown","id":"55c58670","metadata":{"papermill":{"duration":0.011597,"end_time":"2023-11-03T19:10:42.931585","exception":false,"start_time":"2023-11-03T19:10:42.919988","status":"completed"},"tags":[],"id":"55c58670"},"source":["## Import des librairies"]},{"cell_type":"code","execution_count":null,"id":"46d628f9","metadata":{"papermill":{"duration":24.437257,"end_time":"2023-11-03T19:11:07.380970","exception":false,"start_time":"2023-11-03T19:10:42.943713","status":"completed"},"tags":[],"id":"46d628f9"},"outputs":[],"source":["import einops\n","import torch\n","from datasets import load_dataset\n","from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments\n","from peft import LoraConfig\n","from trl import SFTTrainer"]},{"cell_type":"markdown","id":"acc377fc","metadata":{"papermill":{"duration":0.011928,"end_time":"2023-11-03T19:11:07.405379","exception":false,"start_time":"2023-11-03T19:11:07.393451","status":"completed"},"tags":[],"id":"acc377fc"},"source":["## Téléchargement du dataset pour le fine tuning"]},{"cell_type":"code","execution_count":null,"id":"66d57569","metadata":{"id":"66d57569","papermill":{"duration":2.120443,"end_time":"2023-11-03T19:11:09.537837","exception":false,"start_time":"2023-11-03T19:11:07.417394","status":"completed"},"tags":[]},"outputs":[],"source":["\n","dataset_name = \"Laurent1/MedQuad-MedicalQnADataset_128tokens_max\"\n","# On fine tune les 5000 premieres questions sinon c'est un peu long...\n","dataset = load_dataset(dataset_name, split='train[:5120]')\n","dataset\n"]},{"cell_type":"markdown","id":"da64af41","metadata":{"id":"da64af41","papermill":{"duration":0.013199,"end_time":"2023-11-03T19:11:09.564367","exception":false,"start_time":"2023-11-03T19:11:09.551168","status":"completed"},"tags":[]},"source":["## Téléchargement du model pre-entrainé et de son tokenizer"]},{"cell_type":"code","execution_count":null,"id":"f0d34d51","metadata":{"id":"f0d34d51","papermill":{"duration":111.325421,"end_time":"2023-11-03T19:13:00.902982","exception":false,"start_time":"2023-11-03T19:11:09.577561","status":"completed"},"tags":[]},"outputs":[],"source":["model_name = \"ibm/mpt-7b-instruct2\"\n","\n","# BitsAndBytes permet le fine tuning avec \"quantification\" pour réduire l'impact mémoire et les calculs\n","bnb_config = BitsAndBytesConfig(\n"," load_in_4bit=True,\n"," bnb_4bit_quant_type=\"nf4\",\n"," bnb_4bit_compute_dtype=torch.float16,\n",")\n","\n","model = AutoModelForCausalLM.from_pretrained(\n"," \"ibm/mpt-7b-instruct2\",\n"," device_map=\"auto\",\n"," torch_dtype=torch.float16, #torch.bfloat16,\n"," trust_remote_code=True\n"," )\n","\n","tokenizer = AutoTokenizer.from_pretrained(model_name)\n","tokenizer.pad_token = tokenizer.eos_token\n","tokenizer.padding_side = \"right\""]},{"cell_type":"markdown","id":"210f91f0","metadata":{"id":"210f91f0","papermill":{"duration":0.017635,"end_time":"2023-11-03T19:13:00.938752","exception":false,"start_time":"2023-11-03T19:13:00.921117","status":"completed"},"tags":[]},"source":["## Configuration du peft LoRa"]},{"cell_type":"code","execution_count":null,"id":"f319e182","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:00.975007Z","iopub.status.busy":"2023-11-03T19:13:00.974381Z","iopub.status.idle":"2023-11-03T19:13:00.979872Z","shell.execute_reply":"2023-11-03T19:13:00.978679Z"},"id":"f319e182","papermill":{"duration":0.026255,"end_time":"2023-11-03T19:13:00.982184","exception":false,"start_time":"2023-11-03T19:13:00.955929","status":"completed"},"tags":[]},"outputs":[],"source":["lora_alpha = 16\n","lora_dropout = 0.1\n","lora_r = 32\n","\n","peft_config = LoraConfig(\n"," lora_alpha=lora_alpha,\n"," lora_dropout=lora_dropout,\n"," r=lora_r,\n"," bias=\"none\",\n"," task_type=\"CAUSAL_LM\",\n"," target_modules=[\n"," \"Wqkv\",\n"," \"out_proj\",\n"," \"up_proj\",\n"," \"down_proj\",\n"," ]\n",")"]},{"cell_type":"markdown","id":"2865fb08","metadata":{"id":"2865fb08","papermill":{"duration":0.017888,"end_time":"2023-11-03T19:13:01.017800","exception":false,"start_time":"2023-11-03T19:13:00.999912","status":"completed"},"tags":[]},"source":["## Préparation de l'entraineur (Supervised Fine-tuning Trainer)"]},{"cell_type":"markdown","id":"646bb0f6","metadata":{"id":"646bb0f6","papermill":{"duration":0.017992,"end_time":"2023-11-03T19:13:01.053796","exception":false,"start_time":"2023-11-03T19:13:01.035804","status":"completed"},"tags":[]},"source":["Utilisation de [`SFTTrainer` de la librairie TRL](https://huggingface.co/docs/trl/main/en/sft_trainer) qui est un wrapper de Trainer facilite le fine tuning avec LoRa"]},{"cell_type":"code","execution_count":null,"id":"a09bf573","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:01.091811Z","iopub.status.busy":"2023-11-03T19:13:01.091529Z","iopub.status.idle":"2023-11-03T19:13:01.100031Z","shell.execute_reply":"2023-11-03T19:13:01.099221Z"},"papermill":{"duration":0.029762,"end_time":"2023-11-03T19:13:01.101982","exception":false,"start_time":"2023-11-03T19:13:01.072220","status":"completed"},"tags":[],"id":"a09bf573"},"outputs":[],"source":["output_dir = \"/YOUR DIRECTORY\"\n","per_device_train_batch_size = 1\n","gradient_accumulation_steps = 16\n","optim = \"paged_adamw_32bit\"\n","save_steps = 64\n","logging_steps = 64\n","learning_rate = 1e-4\n","max_grad_norm = 0.3\n","max_steps = 1600\n","warmup_ratio = 0.03\n","lr_scheduler_type = \"linear\"\n","\n","training_arguments = TrainingArguments(\n"," output_dir=output_dir,\n"," per_device_train_batch_size=per_device_train_batch_size,\n"," gradient_accumulation_steps=gradient_accumulation_steps,\n"," optim=optim,\n"," save_steps=save_steps,\n"," logging_steps=logging_steps,\n"," learning_rate=learning_rate,\n"," fp16=True,\n"," max_grad_norm=max_grad_norm,\n"," max_steps=max_steps,\n"," warmup_ratio=warmup_ratio,\n"," group_by_length=True,\n"," lr_scheduler_type=lr_scheduler_type,\n"," report_to = 'none',\n"," save_total_limit = 1\n",")"]},{"cell_type":"code","execution_count":null,"id":"65c351f4","metadata":{"id":"65c351f4","papermill":{"duration":2.126939,"end_time":"2023-11-03T19:13:03.247048","exception":false,"start_time":"2023-11-03T19:13:01.120109","status":"completed"},"tags":[]},"outputs":[],"source":["trainer = SFTTrainer(\n"," model=model,\n"," train_dataset=dataset,\n"," peft_config=peft_config,\n"," dataset_text_field=\"text\",\n"," max_seq_length= 512,\n"," tokenizer=tokenizer,\n"," args=training_arguments,\n",")"]},{"cell_type":"markdown","id":"f288484c","metadata":{"id":"f288484c","papermill":{"duration":0.018129,"end_time":"2023-11-03T19:13:03.284278","exception":false,"start_time":"2023-11-03T19:13:03.266149","status":"completed"},"tags":[]},"source":["## Entrainement du model"]},{"cell_type":"code","execution_count":null,"id":"961d4e60","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:03.322815Z","iopub.status.busy":"2023-11-03T19:13:03.322506Z","iopub.status.idle":"2023-11-03T20:53:46.633189Z","shell.execute_reply":"2023-11-03T20:53:46.632279Z"},"id":"961d4e60","outputId":"4db61972-d1d9-4c43-d25a-a122da136bb7","papermill":{"duration":6043.332295,"end_time":"2023-11-03T20:53:46.635133","exception":false,"start_time":"2023-11-03T19:13:03.302838","status":"completed"},"tags":[]},"outputs":[{"name":"stderr","output_type":"stream","text":["You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"]},{"data":{"text/html":["\n","
\n"," \n"," \n"," [1600/1600 1:40:32, Epoch 5/5]\n","
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
StepTraining Loss
641.618400
1281.084200
1921.021800
2561.014300
3200.960500
3840.905900
4480.885200
5120.847400
5760.889400
6400.861000
7040.800400
7680.768600
8320.750300
8960.780200
9600.762700
10240.698600
10880.672600
11520.693100
12160.708900
12800.662700
13440.630400
14080.624600
14720.627200
15360.628000
16000.603300

"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/plain":["TrainOutput(global_step=1600, training_loss=0.819993417263031, metrics={'train_runtime': 6042.1301, 'train_samples_per_second': 4.237, 'train_steps_per_second': 0.265, 'total_flos': 1.0172436457734144e+17, 'train_loss': 0.819993417263031, 'epoch': 5.0})"]},"execution_count":8,"metadata":{},"output_type":"execute_result"}],"source":["trainer.train()"]},{"cell_type":"code","execution_count":null,"id":"9726dd41","metadata":{"execution":{"iopub.execute_input":"2023-11-03T20:53:46.718306Z","iopub.status.busy":"2023-11-03T20:53:46.717562Z","iopub.status.idle":"2023-11-03T20:53:57.048468Z","shell.execute_reply":"2023-11-03T20:53:57.047305Z"},"papermill":{"duration":10.351815,"end_time":"2023-11-03T20:53:57.050634","exception":false,"start_time":"2023-11-03T20:53:46.698819","status":"completed"},"tags":[],"id":"9726dd41","outputId":"9cc7d650-0e20-43df-ed15-f910be03418c"},"outputs":[{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1417: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n"," warnings.warn(\n"]},{"name":"stdout","output_type":"stream","text":["Below is an instruction from Human. Write a response.\n"," ### Instruction:\n"," How to diagnose Parasites - Baylisascaris infection?\n"," ### Response:\n"," The infection is diagnosed by identification of the parasite in stool samples.\n"," \n","The infection is usually diagnosed after the person has been hospitalized and the diagnosis is confirmed by identification of the parasite in stool samples.\n"," \n","The stool samples are sent to a laboratory for examination.\n"," ### Instruction:\n"," How to prevent and control Parasites - Baylisascaris infection?\n"," ### Response:\n"," The best way to prevent infection is to avoid contact with raccoons\n"]}],"source":["\n","text = \"Below is an instruction from Human. Write a response.\\n ### Instruction:\\n How to diagnose Parasites - Baylisascaris infection ?\\n ### Response:\"\n","inputs = tokenizer(text, return_tensors=\"pt\").to('cuda')\n","out = model.generate(**inputs, max_new_tokens=100)\n","\n","print(tokenizer.decode(out[0]))\n"]}],"metadata":{"kernelspec":{"display_name":"Python 3 (ipykernel)","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.13"},"papermill":{"default_parameters":{},"duration":6274.702649,"end_time":"2023-11-03T20:54:00.492244","environment_variables":{},"exception":null,"input_path":"__notebook__.ipynb","output_path":"__notebook__.ipynb","parameters":{},"start_time":"2023-11-03T19:09:25.789595","version":"2.4.0"},"colab":{"provenance":[]}},"nbformat":4,"nbformat_minor":5}