""" finetune Phi-4-multimodal-instruct on an speech task scipy==1.15.1 peft==0.13.2 backoff==2.2.1 transformers==4.46.1 accelerate==1.3.0 """ import argparse import json import os from pathlib import Path import torch import sacrebleu from accelerate import Accelerator from accelerate.utils import gather_object from datasets import load_dataset from torch.utils.data import Dataset from tqdm import tqdm from transformers import ( AutoModelForCausalLM, AutoProcessor, BatchFeature, Trainer, TrainingArguments, StoppingCriteria, StoppingCriteriaList, ) INSTSRUCTION = { "en_zh-CN": "Translate the audio to Mandarin.", "en_id": "Translate the audio to Indonesian.", "en_sl": "Translate the audio to Slovenian.", } TOKENIZER = { "en_zh-CN": "zh", "en_ja": "ja-mecab", } ANSWER_SUFFIX = "<|end|><|endoftext|>" _IGNORE_INDEX = -100 _TRAIN_SIZE = 50000 _EVAL_SIZE = 200 class MultipleTokenBatchStoppingCriteria(StoppingCriteria): """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs.""" def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None: """Initialize the multiple token batch stopping criteria. Args: stop_tokens: Stop-tokens. batch_size: Batch size. """ self.stop_tokens = stop_tokens self.max_stop_tokens = stop_tokens.shape[-1] self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: # Only gather the maximum number of inputs compatible with stop tokens # and checks whether generated inputs are equal to `stop_tokens` generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens) equal_generated_inputs = torch.all(generated_inputs, dim=2) # Mark the position where a stop token has been produced for each input in the batch, # but only if the corresponding entry is not already set sequence_idx = torch.any(equal_generated_inputs, dim=1) sequence_set_mask = self.stop_tokens_idx == 0 self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1] return torch.all(self.stop_tokens_idx) class CoVoSTDataset(Dataset): def __init__(self, processor, data_dir, split, lang="en_zh-CN", rank=0, world_size=1): self.data = load_dataset("facebook/covost2", lang, data_dir=data_dir, split=split, trust_remote_code=True ) self.training = "train" in split self.processor = processor self.instruction = INSTSRUCTION[lang] if world_size > 1: self.data = self.data.shard(world_size, rank) def __len__(self): return len(self.data) def __getitem__(self, idx): """ {'client_id': '0013037a1d45cc33460806cc3f8ecee9d536c45639ba4cbbf1564f1c051f53ff3c9f89ef2f1bf04badf55b3a2e7654c086f903681a7b6299616cff6f67598eff', 'file': '{data_dir}/clips/common_voice_en_699711.mp3', 'audio': {'path': '{data_dir}/clips/common_voice_en_699711.mp3', 'array': array([-1.28056854e-09, -1.74622983e-09, -1.16415322e-10, ..., 3.92560651e-10, 6.62794264e-10, -3.89536581e-09]), 'sampling_rate': 16000}, 'sentence': '"She\'ll be all right."', 'translation': '她会没事的。', 'id': 'common_voice_en_699711'} """ data = self.data[idx] user_message = { 'role': 'user', 'content': '<|audio_1|>\n' + self.instruction, } prompt = self.processor.tokenizer.apply_chat_template( [user_message], tokenize=False, add_generation_prompt=True ) inputs = self.processor(text=prompt, audios=[(data["audio"]["array"], data["audio"]["sampling_rate"])], return_tensors='pt') answer = f"{data['translation']}{ANSWER_SUFFIX}" answer_ids = self.processor.tokenizer(answer, 
        if self.training:
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
        else:
            input_ids = inputs.input_ids
            labels = answer_ids

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
        }


def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output


def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is a full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])

        output[slices] = t
        index += t.shape[dim]

    return output


def covost_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full(
                (inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool
            )
        )

    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        print(e)
        print(input_ids_list)
        print(labels_list)
        raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_mode': 2,  # speech mode
        }
    )


def create_model(model_name_or_path, use_flash_attention=False):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
        _attn_implementation='flash_attention_2' if use_flash_attention else 'sdpa',
        trust_remote_code=True,
    ).to('cuda')

    return model


@torch.no_grad()
def evaluate(
    model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1
):
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
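    # Under distributed launch each rank evaluates its own shard of the eval set
    # (see CoVoSTDataset sharding); predictions and references are merged across
    # ranks with gather_object below, and BLEU is computed on rank 0 only.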
    model.eval()
    all_generated_texts = []
    all_labels = []

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=covost_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=8,
        prefetch_factor=2,
        pin_memory=True,
    )
    stop_tokens = ["<|end|>", processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(
        stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt"
    )["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(f'cuda:{local_rank}')

    for inputs in tqdm(
        eval_dataloader, disable=(rank != 0) or disable_tqdm, desc='running eval'
    ):
        stopping_criteria = StoppingCriteriaList(
            [MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=inputs.input_ids.size(0))]
        )
        inputs = inputs.to(f'cuda:{local_rank}')
        generated_ids = model.generate(
            **inputs,
            eos_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=64,
            stopping_criteria=stopping_criteria,
        )

        stop_tokens_idx = stopping_criteria[0].stop_tokens_idx.reshape(inputs.input_ids.size(0), -1)[:, 0]

        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            generated_ids.shape[-1],
        )
        generated_text = [
            processor.decode(
                _pred_ids[inputs["input_ids"].shape[1]:_stop_tokens_idx],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            for _pred_ids, _stop_tokens_idx in zip(generated_ids, stop_tokens_idx)
        ]
        all_generated_texts.extend(generated_text)

        # Strip only the literal answer suffix from the reference
        # (str.rstrip would also strip trailing characters that merely appear in the suffix)
        labels = [
            processor.decode(_label_ids[_label_ids != 0]).removesuffix(ANSWER_SUFFIX)
            for _label_ids in inputs["labels"]
        ]
        all_labels.extend(labels)

    all_generated_texts = gather_object(all_generated_texts)
    all_labels = gather_object(all_labels)

    if rank == 0:
        assert len(all_generated_texts) == len(all_labels)
        bleu = sacrebleu.corpus_bleu(all_generated_texts, [all_labels])
        print(bleu)
        if save_path:
            with open(save_path, 'w') as f:
                save_dict = {
                    'all_generated_texts': all_generated_texts,
                    'all_labels': all_labels,
                    'score': bleu.score,
                }
                json.dump(save_dict, f)

        return bleu.score
    return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='microsoft/Phi-4-multimodal-instruct',
        help='Model name or path to load from',
    )
    parser.add_argument(
        "--common_voice_dir",
        type=str,
        default="CommonVoice/EN",
        help="Unzipped Common Voice audio dataset directory (version 4.0), see https://commonvoice.mozilla.org/en/datasets",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="en_sl",
        help="Language pair for translation.",
    )
    parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument('--output_dir', type=str, default='./output/', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument(
        '--batch_size_per_gpu',
        type=int,
        default=32,
        help='Batch size per GPU (adjust this to fit in GPU memory)',
    )
    parser.add_argument(
        '--num_train_epochs', type=int, default=1, help='Number of training epochs'
    )
    parser.add_argument('--learning_rate', type=float, default=4.0e-5, help='Learning rate')
    parser.add_argument('--wd', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--no-tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
    args = parser.parse_args()

    accelerator = Accelerator()

    with accelerator.local_main_process_first():
        processor = AutoProcessor.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
        )
        model = create_model(
            args.model_name_or_path,
            use_flash_attention=args.use_flash_attention,
        )
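
    # Phi-4-multimodal ships separate LoRA adapters for vision and speech;
    # select the speech adapter before fine-tuning on this speech task.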
    model.set_lora_adapter('speech')

    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    eval_dataset = CoVoSTDataset(processor,
                                 data_dir=args.common_voice_dir,
                                 split=f'test[:{_EVAL_SIZE}]',
                                 lang=args.lang,
                                 rank=rank,
                                 world_size=world_size)

    train_dataset = CoVoSTDataset(processor,
                                  data_dir=args.common_voice_dir,
                                  split=f'train[:{_TRAIN_SIZE}]',
                                  lang=args.lang)

    num_gpus = accelerator.num_processes
    print(f'training on {num_gpus} GPUs')
    assert (
        args.batch_size % (num_gpus * args.batch_size_per_gpu) == 0
    ), 'Batch size must be divisible by the number of GPUs times the per-GPU batch size'
    gradient_accumulation_steps = args.batch_size // (num_gpus * args.batch_size_per_gpu)

    if args.use_flash_attention:
        fp16 = False
        bf16 = True
    else:
        fp16 = True
        bf16 = False

    # hard coded training args
    training_args = TrainingArguments(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size_per_gpu,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim='adamw_torch',
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        learning_rate=args.learning_rate,
        weight_decay=args.wd,
        max_grad_norm=1.0,
        lr_scheduler_type='linear',
        warmup_steps=50,
        logging_steps=10,
        output_dir=args.output_dir,
        save_strategy='no',
        save_total_limit=10,
        save_only_model=True,
        bf16=bf16,
        fp16=fp16,
        remove_unused_columns=False,
        report_to='none',
        deepspeed=None,
        disable_tqdm=not args.tqdm,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=True,  # for unused SigLIP layers
    )

    # eval before fine-tuning
    out_path = Path(training_args.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    score = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_before.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'BLEU Score before finetuning: {score}')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=covost_collate_fn,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model()
    if accelerator.is_main_process:
        processor.save_pretrained(training_args.output_dir)
    accelerator.wait_for_everyone()

    # eval after fine-tuning (load saved checkpoint)
    # first try to clear GPU memory
    del model
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    # reload the model for inference
    model = AutoModelForCausalLM.from_pretrained(
        training_args.output_dir,
        torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
        trust_remote_code=True,
        _attn_implementation='flash_attention_2' if args.use_flash_attention else 'sdpa',
    ).to('cuda')

    score = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_after.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'BLEU Score after finetuning: {score}')


if __name__ == '__main__':
    main()
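
# Example launch commands (illustrative only; the script file name, GPU count, and the
# Common Voice path are assumptions to adapt to your setup):
#
#   # single GPU
#   python finetune_speech.py --common_voice_dir /path/to/CommonVoice/EN --lang en_sl
#
#   # multi-GPU via Accelerate (the script reads RANK/WORLD_SIZE/LOCAL_RANK set by the launcher)
#   accelerate launch --num_processes 4 finetune_speech.py \
#       --common_voice_dir /path/to/CommonVoice/EN --lang en_sl \
#       --batch_size 128 --batch_size_per_gpu 32 --use_flash_attention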