{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "***\n", "\n", "# MIDIstral Pixtral 12B Fine-Tuning Code\n", "\n", "***\n", "\n", "## Based upon fine-tuning code by Tomasz Stankiewicz\n", "\n", "## https://github.com/tomstaan/Clarivex-Pixtral-12B\n", "\n", "***\n", "\n", "### Project Los Angeles\n", "### Tegridy Code 2024\n", "\n", "***" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!python3 -m pip install --upgrade pip -q\n", "!pip3 install -U transformers\n", "!pip3 install -q accelerate datasets peft bitsandbytes hf_transfer flash_attn tensorboard\n", "!pip3 install ipywidgets\n", "!pip3 install --upgrade jinja2\n", "!pip3 install --upgrade peft\n", "!pip3 install -U pillow\n", "!pip3 install tf-keras\n", "\n", "# It can be a good idea to restart the kernel after this" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!sudo pip3 install tf-keras" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!sudo pip install -U numpy==1.26.1" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Enable fast weights download and upload\n", "import os\n", "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Download model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from PIL import Image\n", "from transformers import AutoProcessor, LlavaForConditionalGeneration\n", "from transformers import BitsAndBytesConfig\n", "\n", "model_id = \"mistral-community/pixtral-12b\"\n", "\n", "model = LlavaForConditionalGeneration.from_pretrained(\n", "    model_id,\n", "    torch_dtype=torch.bfloat16,\n", "    device_map='auto',\n", "    # attn_implementation=\"sdpa\",\n", ")\n", "\n", "processor = AutoProcessor.from_pretrained(model_id)\n", "\n", "# Extract the tokenizer from the processor\n", "tokenizer = processor.tokenizer\n", "\n", "# Set the padding side to 'left' for Flash Attention compatibility\n", "tokenizer.padding_side = \"left\"" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Chat Template" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "CHAT_TEMPLATE = \"\"\"\n", "{%- for message in messages %} \n", "  {%- if message.role == \"user\" %} \n", "    [INST] \n", "    {%- for item in message.content %} \n", "      {%- if item.type == \"text\" %} \n", "        {{ item.text }} \n", "      {%- elif item.type == \"image\" %} \n", "        \\n[IMG] \n", "      {%- endif %} \n", "    {%- endfor %} \n", "    [/INST] \n", "  {%- elif message.role == \"assistant\" %} \n", "    {%- for item in message.content %} \n", "      {%- if item.type == \"text\" %} \n", "        {{ item.text }} \n", "      {%- endif %} \n", "    {%- endfor %} \n", "  \n", "  {%- endif %} \n", "{%- endfor %} \n", "\"\"\"\n", "\n", "# Set the chat template for the processor; the replace() strips the two-space\n", "# indentation used above for readability\n", "processor.chat_template = CHAT_TEMPLATE.replace('  ', '')\n", "\n", "processor.tokenizer.pad_token = processor.tokenizer.eos_token" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Example conversation input with user and assistant roles\n", "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": [\n", "            {\"type\": \"text\", \"text\": \"Please describe the song music in detail. Thank you.\"},\n", "            {\"type\": \"image\"}\n", "        ]\n", "    },\n", "    {\n", "        \"role\": \"assistant\",\n", "        \"content\": [\n", "            {\"type\": \"text\", \"text\": \"The song 'Man In Black' by Johnny Cash in key A# has fast tempo and average pace with Acoustic Guitar(steel) lead, accompanying Acoustic Grand and predominant Acoustic Snare drums\"}\n", "        ]\n", "    }\n", "]\n", "\n", "# Apply the chat template to format the messages\n", "formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)\n", "\n", "# Output the formatted text\n", "print(\"Formatted text:\\n\", formatted_text)" ] },
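{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional, illustrative sanity check (not strictly required): confirm that the\n", "# templated example tokenizes cleanly and see how many tokens its text part uses.\n", "token_ids = processor.tokenizer(formatted_text)['input_ids']\n", "print('Token count for the formatted example:', len(token_ids))" ] },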
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Download dataset" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import io\n", "from datasets import load_dataset\n", "\n", "def deserialize_image(byte_data):\n", "    img_byte_arr = io.BytesIO(byte_data)\n", "    img = Image.open(img_byte_arr)\n", "    return img\n", "\n", "dataset = load_dataset(\"asigalov61/MIDIstral\", split='train').train_test_split(test_size=0.001)\n", "\n", "# Access the training and test sets\n", "train_dataset = dataset[\"train\"]\n", "eval_dataset = dataset[\"test\"]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(train_dataset)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "eval_dataset[0]" ] },
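{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check of one evaluation sample: decode the stored image bytes with\n", "# deserialize_image and show the question/answer fields that the data collator uses later.\n", "sample = eval_dataset[0]\n", "sample_image = deserialize_image(sample['image'])\n", "print('Image size:', sample_image.size, '| mode:', sample_image.mode)\n", "print('Question:', sample['question'])\n", "print('Answer:', sample['answer'])" ] },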
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation before fine-tuning" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from PIL import Image\n", "from torchvision.transforms.functional import to_pil_image, resize\n", "\n", "def run_model_evaluation(model, dataset, num_samples=None, device='cuda', constant_query=None):\n", "    model.eval()\n", "    results = []\n", "\n", "    # Limit the dataset if a specific number of samples is provided\n", "    if num_samples is not None:\n", "        dataset = torch.utils.data.Subset(dataset, range(num_samples))\n", "\n", "    for example in dataset:\n", "        image = deserialize_image(example[\"image\"])\n", "        if constant_query is None:\n", "            query = example[\"question\"]  # Use the question stored with the example\n", "        else:\n", "            query = constant_query  # Use the constant query if provided\n", "\n", "        # Display a reduced size version of the image\n", "        pil_image = image\n", "        aspect_ratio = pil_image.width / pil_image.height\n", "        new_width = 300\n", "        new_height = int(new_width / aspect_ratio)\n", "        display_image = resize(pil_image, (new_height, new_width))\n", "        display_image.show()  # This will open the image in the default image viewer\n", "\n", "        # Construct the message template\n", "        messages = [\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": [\n", "                    # {\"type\": \"text\", \"text\": \"Answer briefly.\"},\n", "                    {\"type\": \"text\", \"text\": query},\n", "                    {\"type\": \"image\"},  # You can comment this out if there are no images\n", "                    # {\"type\": \"image\"},  # Optionally add a second image; additional text items are also possible here\n", "                ]\n", "            }\n", "        ]\n", "\n", "        # Apply the chat template to preprocess input\n", "        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n", "        print(f\"Formatted prompt: {formatted_prompt}\")\n", "        text = processor.apply_chat_template(messages, add_generation_prompt=True)\n", "        inputs = processor(text=[text.strip()], images=[image], return_tensors=\"pt\", padding=True).to(device)\n", "        # inputs = processor(text=[text.strip()], images=[image, image2], return_tensors=\"pt\", padding=True).to(device)\n", "\n", "        # Generate output from the model\n", "        generated_ids = model.generate(**inputs, max_new_tokens=64)\n", "        generated_texts = processor.batch_decode(generated_ids[:, inputs[\"input_ids\"].shape[-1]:])\n", "\n", "        print(f\"Prediction: {generated_texts[0]}\\n\")\n", "\n", "        results.append(generated_texts[0])  # Store the result\n", "\n", "    return results" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Usage\n", "eval_results_before_fine_tuning = run_model_evaluation(model,\n", "                                                        eval_dataset,\n", "                                                        num_samples=2,\n", "                                                        device='cuda',\n", "                                                        constant_query='Please describe the song music in detail. Thank you.')\n", "\n", "print('eval_results_before_fine_tuning:', eval_results_before_fine_tuning)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "class MyDataCollator:\n", "    def __init__(self, processor):\n", "        self.processor = processor\n", "\n", "    def __call__(self, examples):\n", "        texts = []\n", "        images = []\n", "        assistant_responses = []  # To track assistant responses for proper masking\n", "        for example in examples:\n", "            image = deserialize_image(example[\"image\"])\n", "            question = example[\"question\"]  # Question text for this example\n", "            answer = example[\"answer\"]  # Target answer text for this example\n", "\n", "            messages = [\n", "                {\n", "                    \"role\": \"user\",\n", "                    \"content\": [\n", "                        {\"type\": \"text\", \"text\": question},\n", "                        {\"type\": \"image\"},  # Images after the text.\n", "                    ]\n", "                },\n", "                {\n", "                    \"role\": \"assistant\",\n", "                    \"content\": [\n", "                        {\"type\": \"text\", \"text\": answer}\n", "                    ]\n", "                }\n", "            ]\n", "\n", "            # Convert messages to the desired text format using processor's template\n", "            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)\n", "\n", "            texts.append(text.strip())\n", "            images.append([image])\n", "            assistant_responses.append(answer)  # Track assistant's response for later use\n", "\n", "        # Tokenize and process batch\n", "        batch = self.processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n", "\n", "        # Prepare labels; we will mask non-assistant tokens out of the loss\n", "        labels = batch[\"input_ids\"].clone()\n", "\n", "        # For each example, find assistant tokens and mask everything else\n", "        for i, (input_ids, assistant_response) in enumerate(zip(batch[\"input_ids\"], assistant_responses)):\n", "            # Tokenize just the assistant response\n", "            assistant_tokens = self.processor.tokenizer(assistant_response, return_tensors=\"pt\")[\"input_ids\"][0]\n", "\n", "            # Find where the assistant tokens start in the input sequence\n", "            start_idx = self.find_subsequence(input_ids, assistant_tokens)\n", "\n", "            if start_idx is not None:\n", "                # Mask everything except the assistant tokens\n", "                labels[i, :start_idx] = -100  # Ignore everything before the assistant's response\n", "                labels[i, start_idx + len(assistant_tokens):] = -100  # Ignore everything after\n", "\n", "        # Assign masked labels back to the batch\n", "        batch[\"labels\"] = labels\n", "\n", "        return batch\n", "\n", "    def find_subsequence(self, sequence, subsequence):\n", "        \"\"\"\n", "        Find the start index of a subsequence (assistant tokens) in a sequence (input tokens).\n", "        \"\"\"\n", "        seq_len = len(sequence)\n", "        sub_len = len(subsequence)\n", "\n", "        for i in range(seq_len - sub_len + 1):\n", "            if torch.equal(sequence[i:i + sub_len], subsequence):\n", "                return i\n", "        return None\n", "\n", "data_collator = MyDataCollator(processor)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Select a small batch of examples (e.g., 2 examples for quick testing)\n", "sample_batch = [train_dataset[i] for i in range(2)]\n", "\n", "# Call the data collator with the sample batch to process it\n", "processed_batch = data_collator(sample_batch)\n", "\n", "# Print the processed batch keys to check what's inside\n", "print(\"Processed batch keys:\", processed_batch.keys())\n", "\n", "# Print the padded, tokenized input IDs\n", "print(\"\\nTokenized input IDs:\")\n", "print(processed_batch[\"input_ids\"])" ] },
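{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check of the label masking produced by the collator: decode only the\n", "# tokens that are not set to -100 to confirm the loss targets the assistant answer alone.\n", "first_labels = processed_batch['labels'][0]\n", "unmasked_tokens = first_labels[first_labels != -100]\n", "print('Supervised target text:', processor.tokenizer.decode(unmasked_tokens))" ] },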
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "processed_batch[\"input_ids\"].shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(model)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from peft import LoraConfig\n", "\n", "lora_config = LoraConfig(\n", "    r=32,  # Rank (usually 8, 16, or 32 depending on model size and needs)\n", "    lora_alpha=32,  # Scaling factor for the low-rank updates\n", "    use_rslora=True,  # Rank-stabilized LoRA: scales updates by lora_alpha / sqrt(r)\n", "    target_modules=\"all-linear\",  # Apply LoRA to all linear layers\n", "    # modules_to_save=['lm_head','embed_tokens'],\n", "    lora_dropout=0.1,  # Dropout for low-rank adapter layers\n", "    bias=\"none\",  # Bias in adapter layers: \"none\", \"all\", \"lora_only\"\n", "    task_type=\"CAUSAL_LM\"  # Task type: \"CAUSAL_LM\", \"SEQ_2_SEQ_LM\", or \"TOKEN_CLS\"\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from peft import get_peft_model\n", "\n", "model = get_peft_model(model, lora_config)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.print_trainable_parameters()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments, Trainer\n", "\n", "# For main fine-tuning\n", "epochs = 1\n", "lr = 3e-5\n", "schedule = \"constant\"\n", "\n", "# Optional, for annealing\n", "# epochs = 0.4\n", "# lr = 3e-5\n", "# schedule = \"linear\"\n", "\n", "run_name = f\"MIDIstral-{lr}_lr-{epochs}_epochs-{schedule}_schedule\"\n", "\n", "training_args = TrainingArguments(\n", "    # max_steps=1,  # Optional: run only for one step, useful for debugging\n", "    num_train_epochs=epochs,  # Number of training epochs\n", "    per_device_train_batch_size=8,  # Batch size per device for training\n", "    per_device_eval_batch_size=8,  # Batch size per device for evaluation\n", "    gradient_accumulation_steps=1,  # Number of steps to accumulate gradients before updating\n", "    # warmup_steps=10,  # Optional: number of warmup steps (uncomment if needed)\n", "    learning_rate=lr,  # Learning rate for the optimizer\n", "    weight_decay=0.01,  # Weight decay to apply (for regularization)\n", "    logging_steps=0.001,  # Log every 0.1% of the total training steps (fractions are ratios of total steps)\n", "    output_dir=\"MIDIstral_pixtral\",  # Directory where the fine-tuned model will be saved. Make sure the name contains 'pixtral'\n", "    eval_strategy=\"steps\",  # Strategy for evaluation: perform evaluation every few steps\n", "    eval_steps=0.02,  # Evaluate every 2% of the total training steps\n", "    lr_scheduler_type=schedule,  # Set learning rate scheduler type\n", "    # save_strategy=\"steps\",  # Optional: save model every few steps (commented out)\n", "    # save_steps=250,  # Optional: how many steps between saves (commented out)\n", "    # save_total_limit=1,  # Optional: total number of checkpoints to keep (commented out)\n", "    bf16=True,  # Use bf16 precision for training\n", "    remove_unused_columns=False,  # Do not remove unused columns from the dataset\n", "    report_to=\"tensorboard\",  # Report results to TensorBoard for visualization\n", "    run_name=run_name,  # Set the run name for tracking experiments\n", "    logging_dir=f\"./logs/{run_name}\",  # Directory for logging\n", "    gradient_checkpointing=True,  # Enable gradient checkpointing to save VRAM\n", "    gradient_checkpointing_kwargs={'use_reentrant': True}  # Additional settings for gradient checkpointing\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,  # The model to be trained\n", "    args=training_args,  # Training arguments defined earlier\n", "    data_collator=data_collator,  # Data collator to handle batches\n", "    train_dataset=train_dataset,  # Training dataset\n", "    eval_dataset=eval_dataset,  # Evaluation dataset for computing loss or metrics\n", ")" ] },
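{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Rough, illustrative estimate of the total number of optimizer steps (assumes a\n", "# single GPU); the fractional logging_steps/eval_steps above are resolved against this total.\n", "approx_total_steps = int(len(train_dataset) / (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * epochs)\n", "print('Approximate total optimizer steps:', approx_total_steps)" ] },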
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.save_model('./MIDIstral/')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.push_to_hub(token='your-auth-token-here')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "processor.push_to_hub(\"asigalov61/MIDIstral_pixtral\", token='your-auth-token-here')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Inference" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import LlavaForConditionalGeneration, AutoProcessor\n", "import torch\n", "\n", "model = LlavaForConditionalGeneration.from_pretrained(\n", "    'asigalov61/MIDIstral_pixtral',\n", "    torch_dtype=torch.bfloat16,  # Adjust dtype if needed\n", "    device_map='auto'\n", ")\n", "processor = AutoProcessor.from_pretrained('asigalov61/MIDIstral_pixtral')\n", "tokenizer = processor.tokenizer\n", "tokenizer.padding_side = \"left\"  # For Flash Attention compatibility\n", "\n", "print(\"Model and processor loaded successfully from asigalov61/MIDIstral_pixtral.\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation after fine-tuning" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "eval_results_after_fine_tuning = run_model_evaluation(model, eval_dataset, num_samples=5, device='cuda', constant_query='Please write the most appropriate lyrics for the song. Thank you.')\n", "\n", "print('eval_results_before_fine_tuning:', eval_results_before_fine_tuning)\n", "print('eval_results_after_fine_tuning:', eval_results_after_fine_tuning)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "eval_dataset[0]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open('eval_results.txt', 'w') as f:\n", "    f.write('eval_results_before_fine_tuning: ' + str(eval_results_before_fine_tuning) + '\\n')\n", "    f.write('eval_results_after_fine_tuning: ' + str(eval_results_after_fine_tuning) + '\\n')" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 4 }