Adapters
medical
Laurent1 commited on
Commit
2c33a4b
1 Parent(s): 11ff4c3

Upload ibm-mpt-7b-instruct2-QLoRa-medical.ipynb

Browse files
ibm-mpt-7b-instruct2-QLoRa-medical.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"cells":[{"cell_type":"markdown","id":"7a849d13","metadata":{"papermill":{"duration":0.007288,"end_time":"2023-11-03T19:09:30.155119","exception":false,"start_time":"2023-11-03T19:09:30.147831","status":"completed"},"tags":[],"id":"7a849d13"},"source":["# Entrainement inspiré par mosaicml/mpt-7b-instruct"]},{"cell_type":"markdown","id":"e00d5549","metadata":{"papermill":{"duration":0.005709,"end_time":"2023-11-03T19:09:30.167758","exception":false,"start_time":"2023-11-03T19:09:30.162049","status":"completed"},"tags":[],"id":"e00d5549"},"source":["## Installation des librairies manquantes"]},{"cell_type":"code","execution_count":null,"id":"ac5f5285","metadata":{"papermill":{"duration":72.735112,"end_time":"2023-11-03T19:10:42.907773","exception":false,"start_time":"2023-11-03T19:09:30.172661","status":"completed"},"tags":[],"id":"ac5f5285"},"outputs":[],"source":["! pip install bitsandbytes\n","! pip install einops\n","! pip install peft\n","! pip install trl\n","\n","# Bug selon la version de datasets, besoin d'installer une version plus récente que celle de l'environnement pré-installé :\n","! pip uninstall datasets -y\n","! pip install datasets==2.13.1\n","\n","import datasets\n","datasets.__version__"]},{"cell_type":"markdown","id":"55c58670","metadata":{"papermill":{"duration":0.011597,"end_time":"2023-11-03T19:10:42.931585","exception":false,"start_time":"2023-11-03T19:10:42.919988","status":"completed"},"tags":[],"id":"55c58670"},"source":["## Import des librairies"]},{"cell_type":"code","execution_count":null,"id":"46d628f9","metadata":{"papermill":{"duration":24.437257,"end_time":"2023-11-03T19:11:07.380970","exception":false,"start_time":"2023-11-03T19:10:42.943713","status":"completed"},"tags":[],"id":"46d628f9"},"outputs":[],"source":["import einops\n","import torch\n","from datasets import load_dataset\n","from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments\n","from peft import LoraConfig\n","from trl import SFTTrainer"]},{"cell_type":"markdown","id":"acc377fc","metadata":{"papermill":{"duration":0.011928,"end_time":"2023-11-03T19:11:07.405379","exception":false,"start_time":"2023-11-03T19:11:07.393451","status":"completed"},"tags":[],"id":"acc377fc"},"source":["## Téléchargement du dataset pour le fine tuning"]},{"cell_type":"code","execution_count":null,"id":"66d57569","metadata":{"id":"66d57569","papermill":{"duration":2.120443,"end_time":"2023-11-03T19:11:09.537837","exception":false,"start_time":"2023-11-03T19:11:07.417394","status":"completed"},"tags":[]},"outputs":[],"source":["\n","dataset_name = \"Laurent1/MedQuad-MedicalQnADataset_128tokens_max\"\n","# On fine tune les 5000 premieres questions sinon c'est un peu long...\n","dataset = load_dataset(dataset_name, split='train[:5120]')\n","dataset\n"]},{"cell_type":"markdown","id":"da64af41","metadata":{"id":"da64af41","papermill":{"duration":0.013199,"end_time":"2023-11-03T19:11:09.564367","exception":false,"start_time":"2023-11-03T19:11:09.551168","status":"completed"},"tags":[]},"source":["## Téléchargement du model pre-entrainé et de son tokenizer"]},{"cell_type":"code","execution_count":null,"id":"f0d34d51","metadata":{"id":"f0d34d51","papermill":{"duration":111.325421,"end_time":"2023-11-03T19:13:00.902982","exception":false,"start_time":"2023-11-03T19:11:09.577561","status":"completed"},"tags":[]},"outputs":[],"source":["model_name = \"ibm/mpt-7b-instruct2\"\n","\n","# BitsAndBytes permet le fine tuning avec \"quantification\" pour réduire l'impact mémoire et les calculs\n","bnb_config = BitsAndBytesConfig(\n"," load_in_4bit=True,\n"," bnb_4bit_quant_type=\"nf4\",\n"," bnb_4bit_compute_dtype=torch.float16,\n",")\n","\n","model = AutoModelForCausalLM.from_pretrained(\n"," \"ibm/mpt-7b-instruct2\",\n"," device_map=\"auto\",\n"," torch_dtype=torch.float16, #torch.bfloat16,\n"," trust_remote_code=True\n"," )\n","\n","tokenizer = AutoTokenizer.from_pretrained(model_name)\n","tokenizer.pad_token = tokenizer.eos_token\n","tokenizer.padding_side = \"right\""]},{"cell_type":"markdown","id":"210f91f0","metadata":{"id":"210f91f0","papermill":{"duration":0.017635,"end_time":"2023-11-03T19:13:00.938752","exception":false,"start_time":"2023-11-03T19:13:00.921117","status":"completed"},"tags":[]},"source":["## Configuration du peft LoRa"]},{"cell_type":"code","execution_count":null,"id":"f319e182","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:00.975007Z","iopub.status.busy":"2023-11-03T19:13:00.974381Z","iopub.status.idle":"2023-11-03T19:13:00.979872Z","shell.execute_reply":"2023-11-03T19:13:00.978679Z"},"id":"f319e182","papermill":{"duration":0.026255,"end_time":"2023-11-03T19:13:00.982184","exception":false,"start_time":"2023-11-03T19:13:00.955929","status":"completed"},"tags":[]},"outputs":[],"source":["lora_alpha = 16\n","lora_dropout = 0.1\n","lora_r = 32\n","\n","peft_config = LoraConfig(\n"," lora_alpha=lora_alpha,\n"," lora_dropout=lora_dropout,\n"," r=lora_r,\n"," bias=\"none\",\n"," task_type=\"CAUSAL_LM\",\n"," target_modules=[\n"," \"Wqkv\",\n"," \"out_proj\",\n"," \"up_proj\",\n"," \"down_proj\",\n"," ]\n",")"]},{"cell_type":"markdown","id":"2865fb08","metadata":{"id":"2865fb08","papermill":{"duration":0.017888,"end_time":"2023-11-03T19:13:01.017800","exception":false,"start_time":"2023-11-03T19:13:00.999912","status":"completed"},"tags":[]},"source":["## Préparation de l'entraineur (Supervised Fine-tuning Trainer)"]},{"cell_type":"markdown","id":"646bb0f6","metadata":{"id":"646bb0f6","papermill":{"duration":0.017992,"end_time":"2023-11-03T19:13:01.053796","exception":false,"start_time":"2023-11-03T19:13:01.035804","status":"completed"},"tags":[]},"source":["Utilisation de [`SFTTrainer` de la librairie TRL](https://huggingface.co/docs/trl/main/en/sft_trainer) qui est un wrapper de Trainer facilite le fine tuning avec LoRa"]},{"cell_type":"code","execution_count":null,"id":"a09bf573","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:01.091811Z","iopub.status.busy":"2023-11-03T19:13:01.091529Z","iopub.status.idle":"2023-11-03T19:13:01.100031Z","shell.execute_reply":"2023-11-03T19:13:01.099221Z"},"papermill":{"duration":0.029762,"end_time":"2023-11-03T19:13:01.101982","exception":false,"start_time":"2023-11-03T19:13:01.072220","status":"completed"},"tags":[],"id":"a09bf573"},"outputs":[],"source":["output_dir = \"/YOUR DIRECTORY\"\n","per_device_train_batch_size = 1\n","gradient_accumulation_steps = 16\n","optim = \"paged_adamw_32bit\"\n","save_steps = 64\n","logging_steps = 64\n","learning_rate = 1e-4\n","max_grad_norm = 0.3\n","max_steps = 1600\n","warmup_ratio = 0.03\n","lr_scheduler_type = \"linear\"\n","\n","training_arguments = TrainingArguments(\n"," output_dir=output_dir,\n"," per_device_train_batch_size=per_device_train_batch_size,\n"," gradient_accumulation_steps=gradient_accumulation_steps,\n"," optim=optim,\n"," save_steps=save_steps,\n"," logging_steps=logging_steps,\n"," learning_rate=learning_rate,\n"," fp16=True,\n"," max_grad_norm=max_grad_norm,\n"," max_steps=max_steps,\n"," warmup_ratio=warmup_ratio,\n"," group_by_length=True,\n"," lr_scheduler_type=lr_scheduler_type,\n"," report_to = 'none',\n"," save_total_limit = 1\n",")"]},{"cell_type":"code","execution_count":null,"id":"65c351f4","metadata":{"id":"65c351f4","papermill":{"duration":2.126939,"end_time":"2023-11-03T19:13:03.247048","exception":false,"start_time":"2023-11-03T19:13:01.120109","status":"completed"},"tags":[]},"outputs":[],"source":["trainer = SFTTrainer(\n"," model=model,\n"," train_dataset=dataset,\n"," peft_config=peft_config,\n"," dataset_text_field=\"text\",\n"," max_seq_length= 512,\n"," tokenizer=tokenizer,\n"," args=training_arguments,\n",")"]},{"cell_type":"markdown","id":"f288484c","metadata":{"id":"f288484c","papermill":{"duration":0.018129,"end_time":"2023-11-03T19:13:03.284278","exception":false,"start_time":"2023-11-03T19:13:03.266149","status":"completed"},"tags":[]},"source":["## Entrainement du model"]},{"cell_type":"code","execution_count":null,"id":"961d4e60","metadata":{"execution":{"iopub.execute_input":"2023-11-03T19:13:03.322815Z","iopub.status.busy":"2023-11-03T19:13:03.322506Z","iopub.status.idle":"2023-11-03T20:53:46.633189Z","shell.execute_reply":"2023-11-03T20:53:46.632279Z"},"id":"961d4e60","outputId":"4db61972-d1d9-4c43-d25a-a122da136bb7","papermill":{"duration":6043.332295,"end_time":"2023-11-03T20:53:46.635133","exception":false,"start_time":"2023-11-03T19:13:03.302838","status":"completed"},"tags":[]},"outputs":[{"name":"stderr","output_type":"stream","text":["You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"]},{"data":{"text/html":["\n"," <div>\n"," \n"," <progress value='1600' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n"," [1600/1600 1:40:32, Epoch 5/5]\n"," </div>\n"," <table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: left;\">\n"," <th>Step</th>\n"," <th>Training Loss</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <td>64</td>\n"," <td>1.618400</td>\n"," </tr>\n"," <tr>\n"," <td>128</td>\n"," <td>1.084200</td>\n"," </tr>\n"," <tr>\n"," <td>192</td>\n"," <td>1.021800</td>\n"," </tr>\n"," <tr>\n"," <td>256</td>\n"," <td>1.014300</td>\n"," </tr>\n"," <tr>\n"," <td>320</td>\n"," <td>0.960500</td>\n"," </tr>\n"," <tr>\n"," <td>384</td>\n"," <td>0.905900</td>\n"," </tr>\n"," <tr>\n"," <td>448</td>\n"," <td>0.885200</td>\n"," </tr>\n"," <tr>\n"," <td>512</td>\n"," <td>0.847400</td>\n"," </tr>\n"," <tr>\n"," <td>576</td>\n"," <td>0.889400</td>\n"," </tr>\n"," <tr>\n"," <td>640</td>\n"," <td>0.861000</td>\n"," </tr>\n"," <tr>\n"," <td>704</td>\n"," <td>0.800400</td>\n"," </tr>\n"," <tr>\n"," <td>768</td>\n"," <td>0.768600</td>\n"," </tr>\n"," <tr>\n"," <td>832</td>\n"," <td>0.750300</td>\n"," </tr>\n"," <tr>\n"," <td>896</td>\n"," <td>0.780200</td>\n"," </tr>\n"," <tr>\n"," <td>960</td>\n"," <td>0.762700</td>\n"," </tr>\n"," <tr>\n"," <td>1024</td>\n"," <td>0.698600</td>\n"," </tr>\n"," <tr>\n"," <td>1088</td>\n"," <td>0.672600</td>\n"," </tr>\n"," <tr>\n"," <td>1152</td>\n"," <td>0.693100</td>\n"," </tr>\n"," <tr>\n"," <td>1216</td>\n"," <td>0.708900</td>\n"," </tr>\n"," <tr>\n"," <td>1280</td>\n"," <td>0.662700</td>\n"," </tr>\n"," <tr>\n"," <td>1344</td>\n"," <td>0.630400</td>\n"," </tr>\n"," <tr>\n"," <td>1408</td>\n"," <td>0.624600</td>\n"," </tr>\n"," <tr>\n"," <td>1472</td>\n"," <td>0.627200</td>\n"," </tr>\n"," <tr>\n"," <td>1536</td>\n"," <td>0.628000</td>\n"," </tr>\n"," <tr>\n"," <td>1600</td>\n"," <td>0.603300</td>\n"," </tr>\n"," </tbody>\n","</table><p>"],"text/plain":["<IPython.core.display.HTML object>"]},"metadata":{},"output_type":"display_data"},{"data":{"text/plain":["TrainOutput(global_step=1600, training_loss=0.819993417263031, metrics={'train_runtime': 6042.1301, 'train_samples_per_second': 4.237, 'train_steps_per_second': 0.265, 'total_flos': 1.0172436457734144e+17, 'train_loss': 0.819993417263031, 'epoch': 5.0})"]},"execution_count":8,"metadata":{},"output_type":"execute_result"}],"source":["trainer.train()"]},{"cell_type":"code","execution_count":null,"id":"9726dd41","metadata":{"execution":{"iopub.execute_input":"2023-11-03T20:53:46.718306Z","iopub.status.busy":"2023-11-03T20:53:46.717562Z","iopub.status.idle":"2023-11-03T20:53:57.048468Z","shell.execute_reply":"2023-11-03T20:53:57.047305Z"},"papermill":{"duration":10.351815,"end_time":"2023-11-03T20:53:57.050634","exception":false,"start_time":"2023-11-03T20:53:46.698819","status":"completed"},"tags":[],"id":"9726dd41","outputId":"9cc7d650-0e20-43df-ed15-f910be03418c"},"outputs":[{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1417: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n"," warnings.warn(\n"]},{"name":"stdout","output_type":"stream","text":["Below is an instruction from Human. Write a response.\n"," ### Instruction:\n"," How to diagnose Parasites - Baylisascaris infection?\n"," ### Response:\n"," The infection is diagnosed by identification of the parasite in stool samples.\n"," \n","The infection is usually diagnosed after the person has been hospitalized and the diagnosis is confirmed by identification of the parasite in stool samples.\n"," \n","The stool samples are sent to a laboratory for examination.\n"," ### Instruction:\n"," How to prevent and control Parasites - Baylisascaris infection?\n"," ### Response:\n"," The best way to prevent infection is to avoid contact with raccoons\n"]}],"source":["\n","text = \"Below is an instruction from Human. Write a response.\\n ### Instruction:\\n How to diagnose Parasites - Baylisascaris infection ?\\n ### Response:\"\n","inputs = tokenizer(text, return_tensors=\"pt\").to('cuda')\n","out = model.generate(**inputs, max_new_tokens=100)\n","\n","print(tokenizer.decode(out[0]))\n"]}],"metadata":{"kernelspec":{"display_name":"Python 3 (ipykernel)","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.13"},"papermill":{"default_parameters":{},"duration":6274.702649,"end_time":"2023-11-03T20:54:00.492244","environment_variables":{},"exception":null,"input_path":"__notebook__.ipynb","output_path":"__notebook__.ipynb","parameters":{},"start_time":"2023-11-03T19:09:25.789595","version":"2.4.0"},"colab":{"provenance":[]}},"nbformat":4,"nbformat_minor":5}