# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/haotian-liu/LLaVA/
import json
import os
import random
import time
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler, RandomSampler, Sampler
from transformers import PreTrainedModel, Trainer
from transformers.modeling_utils import unwrap_model
from transformers.trainer import ALL_LAYERNORM_LAYERS  # ShardedDDPOption,
from transformers.trainer import get_parameter_names, has_length, is_sagemaker_mp_enabled, logger

from llava.train.sequence_parallel import get_pg_manager
from llava.trl.trainer import DPOTrainer

import numpy as np


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, "no ignore status")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return
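

# Illustrative usage of the two state-dict helpers above (a sketch; the
# "mm_projector" key substring and the output filename are assumptions, adjust
# them to the actual module names):
#
#   non_lora_sd = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
#   adapter_sd = get_mm_adapter_state_maybe_zero_3(model.named_parameters(), ["mm_projector"])
#   torch.save(adapter_sd, "mm_projector.bin")
#
# Both helpers gather ZeRO-3 partitioned parameters (via `maybe_zero_3`) before
# moving them to CPU, so they are safe to call under DeepSpeed ZeRO-3.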


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `num_chunks` chunks whose total lengths are roughly equal.
    """
    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks
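

# Illustrative behavior of `split_to_even_chunks` (a sketch, values chosen by hand):
# with indices [0, 1, 2, 3, 4, 5], lengths [10, 9, 1, 1, 5, 5] and num_chunks=2,
# the greedy assignment above yields [[0, 3, 5], [1, 2, 4]], i.e. two chunks of
# three samples with total lengths 16 and 15, so each data-parallel rank gets a
# similar amount of work per step.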


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [
        lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)
    ]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]
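

# Note on the length convention used above (inherited from LLaVA): multimodal
# samples carry positive lengths and text-only samples carry negative lengths,
# so the sign encodes the modality and abs(length) is the token count. A sketch
# with made-up values:
#
#   lengths = [128, 512, -96, -300]        # two multimodal samples, two text samples
#   idx = get_modality_length_grouped_indices(lengths, batch_size=2, world_size=1)
#
# Each full megabatch then contains samples of a single modality (only the
# leftover tail may mix), which keeps the per-step workload more uniform.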


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]
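

# Layout of the result of `get_length_grouped_indices` (illustrative): the flat
# list reads as
#
#   [megabatch 0: rank-0 batch, rank-1 batch, ..., megabatch 1: rank-0 batch, ...]
#
# where each megabatch holds world_size * batch_size indices, sorted by length
# inside the megabatch and then balanced across ranks by split_to_even_chunks.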


class VILADistributedSampler(DistributedSampler):
    """This class is implemented by Jason Lu."""

    def __init__(
        self,
        dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size=None,
        # NOTE: this is the total size but not per-worker
        sample_len_list=None,
        force_accumulation=True,
        sp_degree: int = 1,
        gradient_accumulation_steps: int = 1,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1)
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = True  # always True
        self.sp_degree = max(1, sp_degree)
        self.bs_divisible_by_sp = batch_size % self.sp_degree == 0

        # Consider sequence parallelism
        if self.sp_degree > 1:  # Sequence Parallelism is enabled
            PROCESS_GROUP_MANAGER = get_pg_manager()
            self.dp_rank = PROCESS_GROUP_MANAGER.dp_rank
            self.dp_num_replicas = num_replicas // sp_degree
            self.corresponding_ranks = list(range(self.dp_rank * self.sp_degree, (self.dp_rank + 1) * self.sp_degree))
        else:
            self.dp_rank = rank
            self.dp_num_replicas = num_replicas

        self.batch_size = batch_size
        self.global_batch_size = batch_size * self.dp_num_replicas

        # NOTE: org_ is without drop last
        self.org_sample_len_list = self.per_replica_samples = sample_len_list
        assert sum(sample_len_list) == len(self.dataset)

        if self.drop_last:  # type: ignore[arg-type]
            self.per_replica_samples = [
                sample_len
                // (self.num_replicas * self.batch_size * gradient_accumulation_steps // self.sp_degree)
                * self.batch_size
                * gradient_accumulation_steps
                // self.sp_degree
                for sample_len in self.per_replica_samples
            ]
            self.num_samples = sum(self.per_replica_samples)
        else:
            raise NotImplementedError

        self.total_size = self.num_samples * self.num_replicas
        self.total_samples = [samples * self.num_replicas for samples in self.per_replica_samples]

        self.shuffle = shuffle
        self.seed = seed

        # whether to force accumulate
        self.force_accumulation = force_accumulation

    def __len__(self) -> int:
        return self.num_samples * self.sp_degree

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # 1. split the full indices first (note: without drop last at this moment)
        indices_list = []
        for i in range(len(self.org_sample_len_list)):
            indices_list.append(
                indices[sum(self.org_sample_len_list[:i]) : sum(self.org_sample_len_list[:i]) + self.total_samples[i]]
            )

        assert sum([len(indices) for indices in indices_list]) == self.total_size, (
            sum([len(indices) for indices in indices_list]),
            self.total_size,
        )

        if (
            self.sp_degree > 1 and self.bs_divisible_by_sp
        ):  # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism
            dp_indices_dict = {}  # {rank: indices_list}
            all_indices_dict = {}  # {rank: all_indices}
            for i in self.corresponding_ranks:
                dp_indices_list = []
                for idx, indices in enumerate(indices_list):
                    dp_indices_list.append(
                        indices[i * self.per_replica_samples[idx] : (i + 1) * self.per_replica_samples[idx]]
                    )
                random.seed(self.seed + self.epoch)
                for indice in range(len(dp_indices_list)):
                    random.shuffle(dp_indices_list[indice])
                dp_indices_dict[i] = dp_indices_list.copy()

            for rank, dp_indices_list in dp_indices_dict.items():
                dp_indices_list = sorted(dp_indices_list, key=lambda x: -len(x))
                dp_all_indices = [-1] * self.num_samples
                indices_available = list(range(self.num_samples))

                for indice in dp_indices_list:
                    original_indices = range(len(indice))
                    transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
                    mapped_indices = [indices_available[idx] for idx in transformed_indices]
                    # update indices_available
                    for idx in reversed(transformed_indices):
                        del indices_available[idx]
                    for i, idx in enumerate(mapped_indices):
                        dp_all_indices[idx] = indice[i]

                all_indices_dict[rank] = dp_all_indices

            # Interleaving Merge
            merged_indices = []
            interleaved_indices = []
            for item_idx in range(len(all_indices_dict[self.corresponding_ranks[0]])):
                for rank in self.corresponding_ranks:
                    interleaved_indices.append(all_indices_dict[rank][item_idx])
            merged_indices.append(interleaved_indices)
            all_indices = merged_indices[0]
        else:
            # let's first do subsample
            for idx, indices in enumerate(indices_list):
                indices_list[idx] = indices[
                    self.rank * self.per_replica_samples[idx] : (self.rank + 1) * self.per_replica_samples[idx]
                ]

            random.seed(self.seed + self.epoch)
            for indice in range(len(indices_list)):
                random.shuffle(indices_list[indice])

            indices_list = sorted(indices_list, key=lambda x: -len(x))
            all_indices = [-1] * self.num_samples
            indices_available = list(range(self.num_samples))

            for indice in indices_list:
                original_indices = range(len(indice))
                transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
                mapped_indices = [indices_available[idx] for idx in transformed_indices]
                # update indices_available
                for idx in reversed(transformed_indices):
                    del indices_available[idx]
                for i, idx in enumerate(mapped_indices):
                    all_indices[idx] = indice[i]

        assert -1 not in all_indices
        return iter(all_indices)
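

# Minimal usage sketch for VILADistributedSampler (illustrative; the dataset
# attributes and hyper-parameters below are assumptions):
#
#   sampler = VILADistributedSampler(
#       train_dataset,
#       num_replicas=dist.get_world_size(),
#       rank=dist.get_rank(),
#       seed=42,
#       batch_size=4,
#       sample_len_list=[len(d) for d in train_dataset.datasets],  # one entry per concatenated sub-dataset
#       sp_degree=1,
#       gradient_accumulation_steps=1,
#   )
#   loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, sampler=sampler)
#
# The sampler spreads each sub-dataset's samples evenly over the epoch on every
# rank, so a global batch mixes the sub-datasets in a stable ratio.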


class LongVILADistributedSampler(VILADistributedSampler):
    """This class is implemented by Yukang Chen."""

    def __iter__(self):
        def batch_shuffle(indices):
            # Shuffle whole batches while keeping the membership of each batch intact.
            batch_indices = list(range(indices[0] // self.batch_size, indices[-1] // self.batch_size + 1))
            random.shuffle(batch_indices)
            indices_shuffled = [
                batch_indices[i // self.batch_size] * self.batch_size + index % self.batch_size
                for i, index in enumerate(indices)
            ]
            return indices_shuffled

        indices = list(range(len(self.dataset)))

        # 1. split the full indices first (note: without drop last at this moment)
        indices_list = []
        for i in range(len(self.org_sample_len_list)):
            indices_list.append(
                indices[sum(self.org_sample_len_list[:i]) : sum(self.org_sample_len_list[:i]) + self.total_samples[i]]
            )

        assert sum([len(indices) for indices in indices_list]) == self.total_size, (
            sum([len(indices) for indices in indices_list]),
            self.total_size,
        )

        if self.sp_degree > 1:  # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism
            dp_indices_dict = {}  # {rank: indices_list}
            all_indices_dict = {}  # {rank: all_indices}
            for i in self.corresponding_ranks:
                dp_indices_list = []
                for idx, indices in enumerate(indices_list):
                    dp_indices_list.append(
                        indices[i * self.per_replica_samples[idx] : (i + 1) * self.per_replica_samples[idx]]
                    )
                random.seed(self.seed + self.epoch)
                for indice in range(len(dp_indices_list)):
                    dp_indices_list[indice] = batch_shuffle(dp_indices_list[indice])
                dp_indices_dict[i] = dp_indices_list.copy()

            for rank, dp_indices_list in dp_indices_dict.items():
                dp_indices_list = sorted(dp_indices_list, key=lambda x: -len(x))
                dp_all_indices = [-1] * self.num_samples
                indices_available = list(range(self.num_samples))

                for indice in dp_indices_list:
                    original_indices = range(len(indice))
                    transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
                    mapped_indices = [indices_available[idx] for idx in transformed_indices]
                    # update indices_available
                    for idx in reversed(transformed_indices):
                        del indices_available[idx]
                    for i, idx in enumerate(mapped_indices):
                        dp_all_indices[idx] = indice[i]

                all_indices_dict[rank] = dp_all_indices

            # Interleaving Merge
            merged_indices = []
            interleaved_indices = []
            for item_idx in range(len(all_indices_dict[self.corresponding_ranks[0]])):
                for rank in self.corresponding_ranks:
                    interleaved_indices.append(all_indices_dict[rank][item_idx])
            merged_indices.append(interleaved_indices)
            all_indices = merged_indices[0]
        else:
            # let's first do subsample
            for idx, indices in enumerate(indices_list):
                indices_list[idx] = indices[
                    self.rank * self.per_replica_samples[idx] : (self.rank + 1) * self.per_replica_samples[idx]
                ]

            random.seed(self.seed + self.epoch)
            for indice in range(len(indices_list)):
                indices_list[indice] = batch_shuffle(indices_list[indice])

            indices_list = sorted(indices_list, key=lambda x: -len(x))
            all_indices = [-1] * self.num_samples
            indices_available = list(range(self.num_samples))

            for indice in indices_list:
                original_indices = range(len(indice))
                transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices]
                mapped_indices = [indices_available[idx] for idx in transformed_indices]
                # update indices_available
                for idx in reversed(transformed_indices):
                    del indices_available[idx]
                for i, idx in enumerate(mapped_indices):
                    all_indices[idx] = indice[i]

        assert -1 not in all_indices
        return iter(all_indices)
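

# Illustrative effect of the batch-level shuffle used above (a sketch with
# batch_size=2): given contiguous indices [0, 1, 2, 3, 4, 5], element-wise
# shuffling could scatter a batch across steps, whereas batch_shuffle only
# permutes the batches themselves, e.g. [4, 5, 0, 1, 2, 3], so samples that were
# packed together (important for long-sequence training) stay in the same batch.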


def get_length_grouped_batches(
    lengths: List[int],
    batch_size: int,
    world_size: int,
    generator=None,
    merge: bool = True,
) -> List:
    N = len(lengths)
    M = world_size * batch_size

    if N < M:
        # fallback: just random permute everything
        idx = np.arange(N)
        if generator is not None:
            seed = generator.initial_seed()
            rng = np.random.RandomState(seed)
        else:
            rng = np.random.RandomState()
        rng.shuffle(idx)
        if merge:
            return idx.tolist()
        else:
            # one megabatch only
            out = [idx.tolist()]
            # pad to world_size empty lists if needed
            return [out + [[]] * (world_size - 1)]

    # 1) build RNG
    if generator is not None:
        seed = generator.initial_seed()
        rng = np.random.RandomState(seed)
    else:
        rng = np.random.RandomState()

    # 2) keys for lexsort: primary = -length, secondary = random
    lengths_arr = np.array(lengths, dtype=np.int64)
    key_length = -lengths_arr
    key_rand = rng.permutation(N)

    # 3) single global lexsort (last key is primary)
    sorted_idx = np.lexsort((key_rand, key_length))

    # 4) trim to full megabatches
    num_mb = len(sorted_idx) // M
    trimmed = sorted_idx[: num_mb * M]

    # 5) reshape to [num_mb, M]
    mb = trimmed.reshape(num_mb, M)

    # 6) optional shuffle of whole megabatches
    rng.shuffle(mb)

    # 7) split each row into [world_size, batch_size]
    mb = mb.reshape(num_mb, world_size, batch_size)

    if merge:
        # flatten in order megabatch -> replica -> sample
        return mb.reshape(-1).tolist()
    else:
        # build nested Python lists: [ [ [..], [..], ... ], ... ]
        return [
            [mb[i, r].tolist() for r in range(world_size)]
            for i in range(num_mb)
        ]
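

# Illustrative output of get_length_grouped_batches (a sketch, values made up):
# with lengths=[5, 9, 3, 7, 2, 8], batch_size=1 and world_size=2, merge=False
# could return [[[1], [5]], [[3], [0]], [[2], [4]]] (the megabatch order varies
# with the RNG): three megabatches, each a list with one batch per replica, so
# ranks in the same step see comparable sequence lengths; merge=True flattens
# the same structure into a single index list.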


# def get_length_grouped_batches(
#     lengths: List[int],
#     batch_size: int,
#     world_size: int,
#     generator=None,
#     merge: bool = True,
# ) -> List:
#     """
#     Create length-grouped megabatches.
#
#     First, a random permutation of indices is computed. Then we split
#     into megabatches of size (world_size * batch_size) and sort each
#     megabatch by descending length. Finally, each megabatch is split
#     into `world_size` chunks (one per replica).
#
#     If merge is True, a flat list is returned; if False, the nested
#     structure is kept.
#     """
#     indices = torch.randperm(len(lengths), generator=generator)
#     megabatch_size = world_size * batch_size
#     # Partition indices into megabatches
#     megabatches = [
#         indices[i : i + megabatch_size].tolist()
#         for i in range(0, len(lengths), megabatch_size)
#     ]
#     # Within each megabatch, sort indices in descending order of length.
#     sorted_megabatches = [
#         sorted(megabatch, key=lambda i: lengths[i], reverse=True)
#         for megabatch in megabatches
#     ]
#     # Split each sorted megabatch evenly among replicas.
#     split_megabatches = [
#         split_to_even_chunks(megabatch, lengths, world_size)
#         for megabatch in sorted_megabatches
#     ]
#     if merge:
#         # Flatten into a single list.
#         return [i for megabatch in split_megabatches for batch in megabatch for i in batch]
#     else:
#         # Return the nested structure: list of megabatches, each containing a list (of length world_size) of batches.
#         return split_megabatches


class LengthGroupedVILADistributedSampler(DistributedSampler):
    """
    A sampler that groups examples by (approximate) length and then
    distributes them across replicas following VILA's accumulation logic.

    Parameters:
        - dataset: the dataset to sample from.
        - batch_size: batch size per replica.
        - lengths: a list of lengths (one per example in the dataset).
        - num_replicas: total number of distributed replicas (if not provided,
          will be inferred from torch.distributed).
        - rank: the rank of the current process.
        - shuffle: whether to shuffle groups.
        - seed: base random seed.
        - drop_last: whether to drop the tail of incomplete megabatches (defaults to True).
        - sp_degree: sequence-parallel degree.
        - gradient_accumulation_steps: used for scaling the effective batch size.
        - group_by_modality: whether the lengths are modality-aware (stored, but not
          used by the current implementation).
        - generator: optional torch.Generator for determinism.
        - force_accumulation: whether to force the VILA accumulation ordering.
    """

    def __init__(
        self,
        dataset,
        batch_size: int,
        lengths: List[int],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = True,
        sp_degree: int = 1,
        gradient_accumulation_steps: int = 1,
        group_by_modality: bool = True,
        generator=None,
        force_accumulation: bool = True,
    ):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
        )
        self.dataset = dataset
        self.batch_size = batch_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality
        self.sp_degree = max(1, sp_degree)
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.force_accumulation = force_accumulation
        self.seed = seed
        self.epoch = 0  # This should be updated externally at each epoch.
        self.world_size = self.num_replicas  # from DistributedSampler
        self.bs_divisible_by_sp = batch_size % self.sp_degree == 0

        if self.sp_degree > 1:
            # Get sequence parallelism group info.
            PROCESS_GROUP_MANAGER = get_pg_manager()
            self.dp_rank = PROCESS_GROUP_MANAGER.dp_rank
            self.dp_num_replicas = self.num_replicas // self.sp_degree
            self.corresponding_ranks = list(range(self.dp_rank * self.sp_degree, (self.dp_rank + 1) * self.sp_degree))
        else:
            self.dp_rank = self.rank
            self.dp_num_replicas = self.num_replicas

        # Compute the number of full megabatches (each of size world_size * batch_size).
        megabatch_size = self.world_size * self.batch_size
        num_full_megabatches = len(self.dataset) // megabatch_size
        # For each full megabatch, each replica gets batch_size examples.
        self.num_samples = num_full_megabatches * self.batch_size

    def __len__(self) -> int:
        # When using sequence parallelism, the effective number may be scaled.
        return self.num_samples * (self.sp_degree if self.sp_degree > 1 else 1)

    def __iter__(self):
        # Get the nested list of length-grouped batches.
        # Each element in "megabatches" is a list of length world_size, one per replica.
        megabatches = get_length_grouped_batches(
            self.lengths,
            self.batch_size,
            self.world_size,
            generator=self.generator,
            merge=False,
        )

        # For each megabatch, select the batch corresponding to this replica.
        indices_list = []
        for megabatch in megabatches:
            if self.rank < len(megabatch):
                indices_list.append(megabatch[self.rank])
        total_samples = sum(len(lst) for lst in indices_list)

        if self.sp_degree > 1 and self.bs_divisible_by_sp:
            # --- Sequence Parallelism branch ---
            # For each of the corresponding sequence-parallel ranks, split each batch.
            dp_indices_dict = {}
            all_indices_dict = {}
            for r in self.corresponding_ranks:
                dp_indices_list = []
                for lst in indices_list:
                    # Split each list into sp_degree equal parts.
                    part_size = len(lst) // self.sp_degree
                    dp_indices_list.append(lst[r * part_size : (r + 1) * part_size])
                random.seed(self.seed + self.epoch)
                for sublist in dp_indices_list:
                    random.shuffle(sublist)
                dp_indices_dict[r] = dp_indices_list.copy()

            # Now, for each sequence-parallel rank, remap the indices.
            for r, dp_list in dp_indices_dict.items():
                # Sort the sublists by descending length.
                dp_list = sorted(dp_list, key=lambda x: -len(x))
                num_samples_r = sum(len(x) for x in dp_list)
                dp_all_indices = [-1] * num_samples_r
                indices_available = list(range(num_samples_r))
                for sublist in dp_list:
                    n = len(sublist)
                    transformed_indices = [i * len(indices_available) // n for i in range(n)]
                    mapped_indices = [indices_available[j] for j in transformed_indices]
                    for j in sorted(transformed_indices, reverse=True):
                        del indices_available[j]
                    for i, pos in enumerate(mapped_indices):
                        dp_all_indices[pos] = sublist[i]
                all_indices_dict[r] = dp_all_indices

            # Interleave the indices from all sequence-parallel ranks.
            merged_indices = []
            # Assumes each dp_all_indices list is of the same length.
            interleaved_length = len(next(iter(all_indices_dict.values())))
            for i in range(interleaved_length):
                for r in self.corresponding_ranks:
                    merged_indices.append(all_indices_dict[r][i])
            final_indices = merged_indices
        else:
            # --- Non-sequence-parallel branch ---
            random.seed(self.seed + self.epoch)
            for sublist in indices_list:
                random.shuffle(sublist)
            # Sort the groups by descending length.
            indices_list = sorted(indices_list, key=lambda x: -len(x))
            dp_all_indices = [-1] * total_samples
            indices_available = list(range(total_samples))
            for sublist in indices_list:
                n = len(sublist)
                transformed_indices = [i * len(indices_available) // n for i in range(n)]
                mapped_indices = [indices_available[j] for j in transformed_indices]
                for j in sorted(transformed_indices, reverse=True):
                    del indices_available[j]
                for i, pos in enumerate(mapped_indices):
                    dp_all_indices[pos] = sublist[i]
            final_indices = dp_all_indices

        assert -1 not in final_indices, "Some indices were not assigned properly."
        return iter(final_indices)
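

# Minimal usage sketch for LengthGroupedVILADistributedSampler (illustrative;
# `modality_lengths` and the hyper-parameters are assumptions):
#
#   sampler = LengthGroupedVILADistributedSampler(
#       train_dataset,
#       batch_size=4,
#       lengths=train_dataset.modality_lengths,
#       num_replicas=dist.get_world_size(),
#       rank=dist.get_rank(),
#       seed=42,
#   )
#   sampler.set_epoch(epoch)  # re-seed the per-epoch shuffling
#
# Compared with VILADistributedSampler, this sampler also sorts samples by
# length so that a replica's batch in a given step has similar lengths, which
# reduces padding and idle time across ranks.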


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        if self.group_by_modality:
            indices = get_modality_length_grouped_indices(
                self.lengths, self.batch_size, self.world_size, generator=self.generator
            )
        else:
            indices = get_length_grouped_indices(
                self.lengths, self.batch_size, self.world_size, generator=self.generator
            )
        return iter(indices)
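

# Minimal usage sketch for LengthGroupedSampler (illustrative values):
#
#   g = torch.Generator()
#   g.manual_seed(42)
#   sampler = LengthGroupedSampler(
#       batch_size=4,
#       world_size=2,
#       lengths=[128, -64, 256, -32, 512, -96, 300, -200],  # signed modality lengths
#       generator=g,
#       group_by_modality=True,
#   )
#   order = list(sampler)  # a permutation of range(len(lengths))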


class VILADPOTrainer(DPOTrainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        # Always using Jason's sampler.
        sample_len_list = self.args.sample_lens
        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
        num_replicas = self.args.world_size
        rank = self.args.process_index
        return VILADistributedSampler(
            self.train_dataset,
            num_replicas=num_replicas,
            rank=rank,
            seed=seed,
            batch_size=self.args.train_batch_size,
            sample_len_list=sample_len_list,
            sp_degree=self.args.seq_parallel_size,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        if self.eval_dataset is None or not has_length(self.eval_dataset):
            return None

        # Always using Jason's sampler.
        sample_len_list = self.args.eval_sample_lens
        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
        return VILADistributedSampler(
            eval_dataset,
            num_replicas=self.args.world_size,
            rank=self.args.process_index,
            seed=seed,
            batch_size=self.args.eval_batch_size,
            sample_len_list=sample_len_list,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()
        # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        #     return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.mm_projector_lr is not None:
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.mm_projector_lr,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.mm_projector_lr,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if 0:  # self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer
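
    # Illustrative shape of the parameter groups built above (a sketch; the
    # weight_decay and mm_projector_lr values are assumptions):
    #
    #   [{"weight_decay": 0.1},              # decayed params outside mm_projector, base lr
    #    {"weight_decay": 0.0},              # no-decay params (biases, norms), base lr
    #    {"weight_decay": 0.1, "lr": 2e-5},  # decayed params inside mm_projector
    #    {"weight_decay": 0.0, "lr": 2e-5}]  # no-decay params inside mm_projector
    #
    # Groups without an explicit "lr" fall back to the optimizer's default
    # learning rate (args.learning_rate).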

    def save_model(self, output_dir: Optional[str], _internal_call: bool):
        ## save tuned model separately
        if self.is_deepspeed_enabled:
            state_dict = self.accelerator.get_state_dict(self.deepspeed)
        else:
            # TODO(ligeng): fix save_model for multi-node training on large models (e.g., Llama-70b)
            state_dict = self.model.state_dict()

        if self.args.should_save:
            return self.model.save_pretrained(output_dir, state_dict=state_dict)


class LLaVATrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_accepts_loss_kwargs = True

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        print("AF3 sampler")
        sample_len_list = self.args.sample_lens
        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
        num_replicas = self.args.world_size
        rank = self.args.process_index
        longvila_sampler = self.args.longvila_sampler

        if self.args.group_by_modality_length:
            sampler = LengthGroupedVILADistributedSampler
            if not isinstance(self.train_dataset, ConcatDataset):
                lengths = self.train_dataset.modality_lengths
            else:
                lengths = []
                for d in self.train_dataset.datasets:
                    lengths += d.modality_lengths
            return sampler(
                self.train_dataset,
                lengths=lengths,
                num_replicas=num_replicas,
                rank=rank,
                seed=seed,
                batch_size=self.args.train_batch_size,
                sp_degree=self.args.seq_parallel_size,
                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
                group_by_modality=True,
            )
        else:
            sampler = LongVILADistributedSampler if longvila_sampler else VILADistributedSampler
            return sampler(
                self.train_dataset,
                num_replicas=num_replicas,
                rank=rank,
                seed=seed,
                batch_size=self.args.train_batch_size,
                sample_len_list=sample_len_list,
                sp_degree=self.args.seq_parallel_size,
                gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        if self.eval_dataset is None or not has_length(self.eval_dataset):
            return None

        sample_len_list = self.args.eval_sample_lens
        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
        return VILADistributedSampler(
            eval_dataset,
            num_replicas=self.args.world_size,
            rank=self.args.process_index,
            seed=seed,
            batch_size=self.args.eval_batch_size,
            sample_len_list=sample_len_list,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

    def _inner_training_loop(self, batch_size: Optional[int] = None, *args, **kwargs):
        # NOTE(zhijianl): In the latest transformers, if the batch size in the training arguments differs from
        # the one in the training state, the batch size from the state is used by default. This can be
        # problematic when resuming with different batch sizes or gradient accumulation steps. To prevent this,
        # we enforce using the batch size specified in the training arguments.
        batch_size = self.args.train_batch_size
        return super()._inner_training_loop(batch_size, *args, **kwargs)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()
        # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        #     return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.mm_projector_lr is not None:
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.mm_projector_lr,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.mm_projector_lr,
                    },
                ]
            elif self.args.vision_tower_lr is not None:
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "vision_tower" in name]
                # projector_lora_A_parameters = [name for name in projector_parameters if "lora_A" in name]
                # projector_lora_B_parameters = [name for name in projector_parameters if "lora_B" in name]
                # other_lora_A_parameters = [name for name in opt_model.named_parameters() if "lora_A" in name and name not in projector_parameters]
                # other_lora_B_parameters = [name for name in opt_model.named_parameters() if "lora_B" in name and name not in projector_parameters]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.vision_tower_lr,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.vision_tower_lr,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if 0:  # self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def save_model(self, output_dir: Optional[str], _internal_call: bool):
        ## save tuned model separately
        if self.is_deepspeed_enabled:
            state_dict = self.accelerator.get_state_dict(self.deepspeed)
        else:
            # TODO(ligeng): fix save_model for multi-node training on large models (e.g., Llama-70b)
            state_dict = self.model.state_dict()

        if self.args.lora_enable:
            non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters())
            os.makedirs(output_dir, exist_ok=True)
            torch.save(
                non_lora_state_dict,
                os.path.join(output_dir, "non_lora_trainables.bin"),
            )
            # config
            self.model._name_or_path = output_dir
            self.model.architectures = [self.model.__class__.__name__]
            self.model.config.save_pretrained(output_dir)

        if self.args.should_save:
            return self.model.save_pretrained(output_dir, state_dict=state_dict)

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        if self.args.debug_e2e and self.control.should_training_stop:
            # Only save log history if the current process is rank 0
            if dist.get_rank() == 0:
                with open(f"{self.args.output_dir}/log_history.json", "w") as f:
                    json.dump(self.state.log_history, f, indent=4)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)