***

# MIDIstral Pixtral 12B Fine-Tuning Code

***

## Based upon fine-tuning code by Tomasz Stankiewicz

## https://github.com/tomstaan/Clarivex-Pixtral-12B

***

### Project Los Angeles
### Tegridy Code 2024

***

# Setup

In [None]:
!python3 -m pip install --upgrade pip -q
!pip3 install -U transformers
!pip3 install -q accelerate datasets peft bitsandbytes hf_transfer flash_attn tensorboard
!pip3 install ipywidgets
!pip3 install --upgrade jinja2
!pip3 install --upgrade peft
!pip3 install -U pillow
!pip3 install pip install tf-keras

# It is a good idea to restart the kernel after this

In [None]:
# Install tf-keras again, this time system-wide with sudo.
# NOTE(review): presumably a workaround for the user-level install above not being
# visible to the kernel's interpreter — confirm it is still needed.
!sudo pip3 install tf-keras

In [None]:
# Pin NumPy to the 1.x release 1.26.1 (system-wide, with sudo).
# NOTE(review): presumably to keep dependencies compiled against NumPy 1.x working — confirm.
!sudo pip install -U numpy==1.26.1

In [None]:
# Enable fast weights download and upload
import os
# huggingface_hub reads this flag when it is first imported, so it has to be set
# before the transformers import in the next cell.
os.environ.update(HF_HUB_ENABLE_HF_TRANSFER="1")

# Download model

In [None]:
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig

# Base checkpoint that will be fine-tuned.
model_id = "mistral-community/pixtral-12b"

# Weights in bfloat16; device_map='auto' lets transformers/accelerate place them
# across the available devices.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    # attn_implementation="sdpa",
)

# The processor bundles the image preprocessing and the tokenizer used below.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

# Pad on the left. The original note cites Flash Attention compatibility; left padding
# is also the usual choice for batched generation with a decoder-only language model.
tokenizer.padding_side = "left"

# Chat Template

In [None]:
# Jinja chat template for Pixtral-style turns. A user text + image turn followed by an
# assistant turn renders as:
#   [INST]<user text>\n[IMG][/INST]<assistant text></s>
# The template is assembled from adjacent string literals, so indentation never leaks
# into the prompt and no post-processing is needed. Deleting all spaces from an indented
# template would also break tags such as "{%- for message in messages %}".
# The newline before [IMG] is emitted through a Jinja expression because transformers
# renders chat templates with trim_blocks=True. That setting would swallow a literal
# newline placed directly after a block tag.
# The assistant turn ends with </s> so a completed answer carries an explicit
# end-of-sequence marker.
CHAT_TEMPLATE = (
    "{%- for message in messages %}"
        "{%- if message.role == 'user' %}"
            "[INST]"
            "{%- for item in message.content %}"
                "{%- if item.type == 'text' %}{{ item.text }}"
                "{%- elif item.type == 'image' %}{{ '\\n[IMG]' }}"
                "{%- endif %}"
            "{%- endfor %}"
            "[/INST]"
        "{%- elif message.role == 'assistant' %}"
            "{%- for item in message.content %}"
                "{%- if item.type == 'text' %}{{ item.text }}"
                "{%- endif %}"
            "{%- endfor %}"
            "</s>"
        "{%- endif %}"
    "{%- endfor %}"
)

# Install the template on the processor (used by processor.apply_chat_template).
processor.chat_template = CHAT_TEMPLATE

# Reuse the end-of-sequence token for padding.
processor.tokenizer.pad_token = processor.tokenizer.eos_token

In [None]:
# Example conversation input with user and assistant roles
# One sample user/assistant exchange, rendered to eyeball the chat template's output.
messages = [
    dict(
        role="user",
        content=[
            dict(type="text", text="Please describe the song music in detail. Thank you."),
            dict(type="image"),
        ],
    ),
    dict(
        role="assistant",
        content=[
            dict(
                type="text",
                text="The song 'Man In Black' by Johnny Cash in key A# has fast tempo and average pace with Acoustic Guitar(steel) lead, accompanying Acoustic Grand and predominant Acoustic Snare drums",
            ),
        ],
    ),
]

formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
print("Formatted text:\n", formatted_text)

# Download dataset

In [None]:
from PIL import Image
import io
from datasets import load_dataset

def deserialize_image(byte_data):
    """Turn one dataset ``"image"`` value into a PIL image.

    The primary input is raw encoded image bytes (anything ``io.BytesIO`` accepts). For
    convenience two other forms are accepted as well:

    * an object that already is a ``PIL.Image.Image``, which is returned unchanged;
    * a dict in the Hugging Face ``datasets`` ``Image``-feature storage layout
      (``{"bytes": ..., "path": ...}``). The bytes are used when present, otherwise
      the path is opened.

    Args:
        byte_data: encoded image bytes, a PIL image, or a ``{"bytes", "path"}`` dict.

    Returns:
        The image as returned by ``PIL.Image.open`` (or the input image itself).
    """
    if isinstance(byte_data, Image.Image):
        return byte_data
    if isinstance(byte_data, dict):
        raw = byte_data.get("bytes")
        if raw is None and byte_data.get("path"):
            return Image.open(byte_data["path"])
        byte_data = raw
    return Image.open(io.BytesIO(byte_data))

# Hold out 0.1% of the rows for evaluation. A fixed seed makes the split reproducible:
# re-running the notebook (e.g. after a kernel restart) evaluates on the same rows
# instead of drawing a new random split each time.
SPLIT_SEED = 42
dataset = load_dataset("asigalov61/MIDIstral", split='train').train_test_split(test_size=0.001, seed=SPLIT_SEED)

# Access the training and test sets
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

In [None]:
# Number of rows in the training split.
len(train_dataset)

In [None]:
# Show the first held-out example (its fields and values).
eval_dataset[0]

# Evaluation before fine-tuning

In [None]:
import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image, resize

def _evaluation_query(example, constant_query):
    """Return the text prompt to use for one evaluation example.

    ``constant_query`` wins when given. Otherwise the example's ``"question"`` field is
    used (the same column ``MyDataCollator`` trains on), falling back to
    ``example["query"]["en"]`` for rows that use that layout instead.
    """
    if constant_query is not None:
        return constant_query
    if "question" in example:
        return example["question"]
    return example["query"]["en"]


def _show_preview(image, width=300):
    """Open a copy of ``image`` resized to ``width`` pixels wide (aspect ratio kept) in the default viewer."""
    aspect_ratio = image.width / image.height
    height = max(1, int(width / aspect_ratio))  # very wide images must not round to 0 px
    resize(image, (height, width)).show()


def run_model_evaluation(model, dataset, num_samples=None, device='cuda', constant_query=None,
                         max_new_tokens=64, show_images=True):
    """Generate one prediction per evaluation example and return the decoded texts.

    For every example the image is decoded with ``deserialize_image`` and optionally
    previewed. It is then paired with a text prompt in a single user turn. That turn is
    rendered through the notebook-global ``processor``'s chat template and passed to
    ``model.generate``. The rendered prompt and the prediction are printed.

    Args:
        model: vision-language model with a ``generate`` method; switched to eval mode.
        dataset: iterable of example dicts with an ``"image"`` field; must also support
            ``len`` and integer indexing when ``num_samples`` is given.
        num_samples: evaluate only the first ``num_samples`` examples (clamped to the
            dataset length); ``None`` evaluates every example.
        device: device the processed inputs are moved to.
        constant_query: prompt used for every example; when ``None`` the prompt is taken
            from the example itself (``"question"``, else ``["query"]["en"]``).
        max_new_tokens: generation budget per example.
        show_images: open a reduced-size preview of each image via ``Image.show``.

    Returns:
        list[str]: decoded predictions (special tokens removed), one per evaluated example.
    """
    model.eval()
    results = []

    # Clamp so asking for more samples than exist does not raise IndexError mid-run.
    if num_samples is not None:
        dataset = torch.utils.data.Subset(dataset, range(min(num_samples, len(dataset))))

    for example in dataset:
        image = deserialize_image(example["image"])
        query = _evaluation_query(example, constant_query)

        if show_images:
            _show_preview(image)

        # A single user turn: the text prompt followed by one image.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": query},
                    {"type": "image"},
                ],
            }
        ]

        # Render the prompt once; the same string is both logged and fed to the processor.
        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        print(f"Formatted prompt: {formatted_prompt}")
        inputs = processor(
            text=[formatted_prompt.strip()], images=[image], return_tensors="pt", padding=True
        ).to(device)

        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Keep only the tokens produced after the prompt. Strip special tokens (such as
        # the end-of-sequence marker) so the returned text is just the answer.
        new_tokens = generated_ids[:, inputs["input_ids"].shape[-1]:]
        prediction = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]

        print(f"Prediction: {prediction}\n")
        results.append(prediction)

    return results




In [None]:
# Usage
# Baseline: run the not-yet-fine-tuned model on the first two held-out examples.
eval_results_before_fine_tuning = run_model_evaluation(
    model,
    eval_dataset,
    constant_query='Please describe the song music in detail. Thank you.',
    device='cuda',
    num_samples=2,
)
print('eval_results_before_fine_tuning:', eval_results_before_fine_tuning)

# Fine-tuning

In [None]:
import torch

class MyDataCollator:
    """Turn raw dataset rows into a padded multimodal training batch with masked labels.

    Each row must provide ``"image"`` (decoded with ``deserialize_image``), ``"question"``
    (user text) and ``"answer"`` (assistant text). A row is rendered with the processor's
    chat template as one user turn (text, then image) plus one assistant turn. The whole
    batch is then tokenized with padding.

    ``labels`` is a copy of ``input_ids`` in which every position outside the assistant
    answer is set to -100, so the loss only covers the answer. An end-of-sequence token
    directly after the answer stays supervised so the model learns where to stop.
    Padding positions are always ignored.
    """

    def __init__(self, processor):
        # Processor providing apply_chat_template, __call__ (text + images) and .tokenizer.
        self.processor = processor

    def __call__(self, examples):
        """Build the training batch for a list of dataset rows.

        Returns:
            The processor's batch (``input_ids``, ``attention_mask``, image tensors, ...)
            with an added ``"labels"`` tensor.
        """
        texts = []
        images = []
        assistant_responses = []  # raw answers, used to locate the supervised span
        for example in examples:
            image = deserialize_image(example["image"])
            question = example["question"]
            answer = example["answer"]

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image"},  # Images after the text.
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer},
                    ],
                },
            ]

            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image])  # one image per sample, given as a per-sample list
            assistant_responses.append(answer)

        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        tokenizer = self.processor.tokenizer
        eos_id = tokenizer.eos_token_id
        inst_end_id = tokenizer.convert_tokens_to_ids("[/INST]")
        if inst_end_id is None or inst_end_id == tokenizer.unk_token_id:
            inst_end_id = None  # "[/INST]" is not a single known token: no fallback anchor

        labels = batch["input_ids"].clone()
        for i, (row, response) in enumerate(zip(batch["input_ids"], assistant_responses)):
            # add_special_tokens=False: with the default the tokenizer may prepend BOS,
            # and a BOS-prefixed sequence can never be found in the middle of the prompt.
            # That would leave the whole row unmasked.
            response_ids = tokenizer(response, add_special_tokens=False, return_tensors="pt")["input_ids"][0]
            span = self._response_span(row, response_ids, inst_end_id=inst_end_id, eos_id=eos_id)
            if span is None:
                continue  # answer not located: keep the row's labels (padding is masked below)
            start, end = span
            labels[i, :start] = -100  # prompt, image tokens and left padding
            labels[i, end:] = -100    # anything after the answer

        # Padding never contributes to the loss (the pad token is set to EOS in this notebook).
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100

        batch["labels"] = labels
        return batch

    def _response_span(self, row, response_ids, inst_end_id=None, eos_id=None):
        """Locate the assistant answer inside one tokenized row.

        Returns ``(start, end)`` such that ``row[start:end]`` is the answer. The span is
        extended by one when the token right after the answer equals ``eos_id``. If the
        answer tokens do not occur verbatim (tokenization can differ at the
        prompt/answer boundary), the span falls back to everything after the last
        ``inst_end_id`` token. Returns ``None`` when neither strategy applies.
        """
        # Search from the end: the assistant answer is the last part of the sequence.
        start = self.find_subsequence(row, response_ids, reverse=True)
        if start is not None:
            end = start + len(response_ids)
            if eos_id is not None and end < len(row) and int(row[end]) == eos_id:
                end += 1  # supervise the stop token too
            return start, end
        if inst_end_id is not None:
            hits = torch.nonzero(row == inst_end_id).flatten()
            if len(hits) > 0:
                return int(hits[-1]) + 1, len(row)
        return None

    def find_subsequence(self, sequence, subsequence, reverse=False):
        """Return the start index of ``subsequence`` inside ``sequence``, or ``None``.

        Args:
            sequence: 1-D tensor to search in.
            subsequence: 1-D tensor to look for (same dtype as ``sequence``).
            reverse: when True return the last occurrence instead of the first.

        An empty ``subsequence`` or one longer than ``sequence`` yields ``None``.
        """
        seq_len = len(sequence)
        sub_len = len(subsequence)
        if sub_len == 0 or sub_len > seq_len:
            return None

        # Only positions whose first token matches can start a match; check just those.
        candidates = torch.nonzero(sequence[: seq_len - sub_len + 1] == subsequence[0]).flatten().tolist()
        if reverse:
            candidates.reverse()
        for i in candidates:
            if torch.equal(sequence[i:i + sub_len], subsequence):
                return i
        return None
 
# Collator instance (bound to the global processor); passed to the Trainer below.
data_collator = MyDataCollator(processor)

In [None]:
import torch

# Smoke-test the collator on the first two training rows before launching training.
sample_batch = [train_dataset[i] for i in range(2)]
processed_batch = data_collator(sample_batch)

print("Processed batch keys:", processed_batch.keys())

# The collator tokenizes with padding=True, so these ids are already padded to one length.
print("\nTokenized input IDs (after padding):")
print(processed_batch["input_ids"])

# Supervised tokens per row (label != -100). If masking works this is close to each
# answer's tokenized length, not the full row length.
print("\nSupervised label tokens per row:")
print((processed_batch["labels"] != -100).sum(dim=1))

In [None]:
# Shape of the padded input_ids tensor (one row per example in the sample batch).
processed_batch["input_ids"].shape

In [None]:
# Print the module tree, e.g. to see which linear layers LoRA's "all-linear" will wrap.
print(model)

In [None]:
from peft import LoraConfig

# LoRA adapter configuration.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",        # adapters for a causal (next-token) language model
    target_modules="all-linear",  # wrap every linear layer (per the PEFT docs, the output layer is excluded)
    r=32,                         # adapter rank
    lora_alpha=32,                # numerator of the adapter scaling factor
    use_rslora=True,              # rank-stabilised LoRA: scale by lora_alpha / sqrt(r) instead of lora_alpha / r
    lora_dropout=0.1,             # dropout on the adapter path
    bias="none",                  # no bias terms are trained
    # modules_to_save=['lm_head', 'embed_tokens'],
)

In [None]:
from peft import get_peft_model

# Wrap the base model so that only the LoRA adapter weights are trainable.
model = get_peft_model(model=model, peft_config=lora_config)

In [None]:
# Report trainable (LoRA adapter) parameters versus total parameters.
model.print_trainable_parameters()

In [None]:
from transformers import TrainingArguments, Trainer

# Main fine-tuning run.
epochs = 1
lr = 3e-5
schedule = "constant"

# Optional follow-up pass for annealing:
# epochs = 0.4
# lr = 3e-5
# schedule = "linear"

run_name = f"MIDIstral-{lr}_lr-{epochs}_epochs-{schedule}_schedule"

training_args = TrainingArguments(
    # --- output and logging ---
    output_dir="MIDIstral_pixtral",        # checkpoint directory; keep "pixtral" in the name
    run_name=run_name,
    logging_dir=f"./logs/{run_name}",
    report_to="tensorboard",
    logging_steps=0.001,                   # < 1 means a ratio of total steps: log every 0.1% of training
    # --- optimisation and schedule ---
    num_train_epochs=epochs,
    learning_rate=lr,
    lr_scheduler_type=schedule,
    weight_decay=0.01,
    # warmup_steps=10,
    # max_steps=1,                         # debugging: stop after a single step
    # --- batching ---
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    # --- evaluation ---
    eval_strategy="steps",
    eval_steps=0.02,                       # ratio of total steps: evaluate every 2% of training
    # --- optional checkpointing ---
    # save_strategy="steps",
    # save_steps=250,
    # save_total_limit=1,
    # --- precision and memory ---
    bf16=True,
    gradient_checkpointing=True,           # recompute activations during backward to save VRAM
    gradient_checkpointing_kwargs={'use_reentrant': True},
    # The custom collator reads the raw "image"/"question"/"answer" columns, so keep them.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

In [None]:
# Launch fine-tuning with the training arguments configured above.
trainer.train()

In [None]:
# Save the trained model to ./MIDIstral/.
# NOTE(review): `model` was wrapped by get_peft_model, so this presumably writes only the
# LoRA adapter weights and config, not a merged full model — confirm.
trainer.save_model('./MIDIstral/')

In [None]:
import os

# Push the trained model to the Hugging Face Hub. The access token comes from the HF_TOKEN
# environment variable rather than being pasted into (and saved with) the notebook.
# When it is unset, None is passed and huggingface_hub uses its stored login credentials.
trainer.push_to_hub(token=os.environ.get("HF_TOKEN"))

In [None]:
import os

# Push the processor (including the custom chat template and pad token) next to the model.
# The token is read from HF_TOKEN instead of being hard-coded in the notebook; None
# falls back to huggingface_hub's stored login credentials.
processor.push_to_hub("asigalov61/MIDIstral_pixtral", token=os.environ.get("HF_TOKEN"))

# Inference

In [None]:
from transformers import LlavaForConditionalGeneration, AutoProcessor
import torch

import gc

FINETUNED_REPO = 'asigalov61/MIDIstral_pixtral'

# Release the training-time model before loading a second 12B copy. `trainer` still
# references it, so both names are dropped. pop() is a no-op in a fresh kernel where
# neither exists.
for _stale_name in ("trainer", "model"):
    globals().pop(_stale_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model = LlavaForConditionalGeneration.from_pretrained(
    FINETUNED_REPO,
    torch_dtype=torch.bfloat16,  # Adjust dtype if needed
    device_map='auto'
)
processor = AutoProcessor.from_pretrained(FINETUNED_REPO)
tokenizer = processor.tokenizer
tokenizer.padding_side = "left"  # same padding side as during training setup

print(f"Model and processor loaded successfully from {FINETUNED_REPO}.")

# Evaluation

In [None]:
# Evaluate the fine-tuned model on held-out examples, then print both result lists.
# NOTE(review): this run uses a different prompt (lyrics) and more samples (5) than the
# pre-fine-tuning baseline (description prompt, 2 samples). The two lists are therefore
# not a like-for-like comparison — confirm this is intended.
eval_results_after_fine_tuning = run_model_evaluation(
    model,
    eval_dataset,
    constant_query='Please write the most appropriate lyrics for the song. Thank you.',
    device='cuda',
    num_samples=5,
)

for _label, _results in (
    ('eval_results_before_fine_tuning:', eval_results_before_fine_tuning),
    ('eval_results_after_fine_tuning:', eval_results_after_fine_tuning),
):
    print(_label, _results)

In [None]:
# Show the first held-out example again, e.g. to compare it with the predictions above.
eval_dataset[0]

In [None]:
# Persist both prediction lists for later inspection. The encoding is explicit because
# generated text (e.g. lyrics) may contain non-ASCII characters that the platform's
# default encoding cannot represent.
with open('eval_results.txt', 'w', encoding='utf-8') as f:
    f.write('eval_results_before_fine_tuning: ' + str(eval_results_before_fine_tuning) + '\n')
    f.write('eval_results_after_fine_tuning: ' + str(eval_results_after_fine_tuning) + '\n')