# Provenance: code_SAS_VLM2Vec / src / trainer_add_CRD.py
# (uploaded by MgGladys via the upload-large-folder tool, revision 0a937d7, verified)
# VLM2Vec/src/trainer_add_CRD.py
import collections
import contextlib
import functools
import shutil
import sys
import time
from datetime import timedelta
from packaging import version
from accelerate import skip_first_batches, DistributedType, InitProcessGroupKwargs
from transformers import PretrainedConfig
from transformers.trainer import Trainer, TRAINING_ARGS_NAME, TRAINER_STATE_NAME
import torch.distributed as dist
from typing import Optional
import os
import torch
import math
import torch.nn as nn
from src.data.collator.train_collator import split_vlm_inputs, get_dense_rep, split_and_process_vlm_inputs
from src.model.model_add_CRD import MMEBModel
from src.loss_add_CRD import SimpleContrastiveLoss, DistributedContrastiveLoss, MultiLayerCRDLoss, DistributedMultiLayerCRDLoss
from src.grad_cache.grad_cache import GradCache
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.trainer_callback import (
ExportableState,
TrainerState,
)
from transformers.trainer_utils import (
TrainOutput,
has_length,
speed_metrics, seed_worker,
)
from transformers.trainer_pt_utils import (
get_model_param_count,
)
from transformers.trainer import FSDP_MODEL_NAME
from transformers.utils import (
XLA_FSDPV2_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_torch_xla_available,
logging, is_sagemaker_mp_enabled,
CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME
)
from src.utils import batch_to_device
from src.utils import print_master, print_rank
if is_apex_available():
from apex import amp
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
from torch_xla import __version__ as XLA_VERSION
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
pass
else:
IS_XLA_FSDPV2_POST_2_2 = False
logger = logging.get_logger(__name__)
# =============== Helper: locate last LM block for grad diagnostics ===============
def _locate_lm_layers_modulelist(encoder):
"""
Try common paths to get ModuleList of LM blocks in Qwen/LLaMA-like models.
"""
candidates = [
("model", "language_model", "layers"),
("model", "model", "layers"),
("model", "layers"),
("language_model", "layers"),
("transformer", "layers"),
]
for path in candidates:
obj = encoder
ok = True
for p in path:
if hasattr(obj, p):
obj = getattr(obj, p)
else:
ok = False
break
if ok and isinstance(obj, torch.nn.ModuleList) and len(obj) > 0:
return obj
return None
def _grad_norm(params):
tot = 0.0
for p in params:
if p.grad is not None:
g = p.grad.detach().float()
tot += (g * g).sum().item()
return math.sqrt(tot) if tot > 0 else 0.0
# ================================================================================
class MMEBTrainer(Trainer):
    """HF ``Trainer`` specialization for MMEB-style contrastive training.

    Differences from the stock ``Trainer``:
      * ``compute_loss`` forwards (query, target) batches straight to the
        model, which computes the loss internally.
      * ``get_train_dataloader`` builds a plain ``DataLoader`` without
        ``accelerator.prepare`` so non-tensor batch fields survive collation.
      * ``_save`` strips the ``encoder.`` prefix from the state dict and
        delegates to ``encoder.save_pretrained``.
      * ``_inner_training_loop`` is a lightly modified copy of the upstream
        HF loop that adds a teacher-gradient logging hook right before
        ``optimizer.step()``.
    """

    def __init__(self, *args, **kwargs):
        super(MMEBTrainer, self).__init__(*args, **kwargs)
        # World-size / DDP flags used to rescale the reported loss in
        # distributed runs (see subclasses' training_step).
        ws = dist.get_world_size() if dist.is_initialized() else 1
        self.is_ddp = dist.is_initialized() and ws > 1
        self._dist_loss_scale_factor = ws if self.is_ddp else 1
        self.processor = self.processing_class

    def get_batch_samples(self, epoch_iterator, num_batches):
        """Pull up to ``num_batches`` batches and count supervised tokens.

        Returns ``(batch_samples, num_items_in_batch)`` where the count is the
        number of non ``-100`` label positions across the pulled batches, or
        ``None`` when it cannot be computed.
        """
        batch_samples = []
        num_items_in_batch = None
        for _ in range(num_batches):
            try:
                batch_samples += [next(epoch_iterator)]
            except StopIteration:
                break
        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
            # For now we don't support object detection
            try:
                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
            except (TypeError, AttributeError):
                # Labels are not tensors (or batches are not dicts); best-effort only.
                pass
        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.item()
        return batch_samples, num_items_in_batch

    def compute_loss(self, model, inputs, *args, **kwargs):
        """Unpack the (query, target) pair and let the model compute the loss."""
        qry_inputs, tgt_inputs = inputs
        return model(qry=qry_inputs, tgt=tgt_inputs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the wrapped encoder (and tokenizer/args) to ``output_dir``.

        The trainer's model stores all weights under an ``encoder.`` prefix;
        the prefix is stripped so the checkpoint loads directly with
        ``from_pretrained``.
        """
        os.makedirs(output_dir, exist_ok=True)
        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        # override original trainer's method: always use a plain RandomSampler
        # (no group-by-length / distributed sampler logic).
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
        return RandomSampler(self.train_dataset)

    def get_train_dataloader(self) -> DataLoader:
        """
        Override the original trainer's method to disable self.accelerator.prepare,
        since it would wrap a DataLoaderDispatcher and lead to:
        (1) `RuntimeError: You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`.`
        (2) all outputs of dataloader must be tensors
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_dataset = self.train_dataset
        data_collator = self.data_collator
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        else:
            # Iterable datasets manage their own ordering; sampling/shuffling off.
            dataloader_params["sampler"] = None
            dataloader_params["shuffle"] = False
            dataloader_params["drop_last"] = True
            dataloader_params["prefetch_factor"] = None  # tune on both prefetch_factor and persistent_workers will cause hang at epoch2
        return DataLoader(train_dataset, **dataloader_params)

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """Reload the full MMEB model from a checkpoint directory.

        NOTE(review): relies on ``self.model_args`` which is only set by the
        GradCacheLateProcessTrainer subclass — confirm before resuming with the
        base class directly.
        """
        self.model_args.checkpoint_path = resume_from_checkpoint
        logger.info(f"Loading checkpoint from {resume_from_checkpoint}")
        self.model = MMEBModel.load(self.model_args)
        self.model_wrapped = self.model

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        """Forked copy of ``transformers.Trainer._inner_training_loop``.

        Local modification: a ``_maybe_log_teacher_grad`` hook (defined by the
        subclass) is called between ``on_pre_optimizer_step`` and
        ``optimizer.step()`` to log the teacher branch's gradient norm.
        Everything else follows the upstream HF control flow.
        """
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        if self.args.auto_find_batch_size:
            if self.state.train_batch_size != self._train_batch_size:
                from accelerate.utils import release_memory

                (self.model_wrapped,) = release_memory(self.model_wrapped)
                self.model_wrapped = self.model
                # Check for DeepSpeed *after* the intial pass and modify the config
                if self.is_deepspeed_enabled:
                    # Temporarily unset `self.args.train_batch_size`
                    original_bs = self.args.per_device_train_batch_size
                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                    self.propagate_args_to_deepspeed(True)
                    self.args.per_device_train_batch_size = original_bs
            self.state.train_batch_size = self._train_batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
        len_dataloader = None
        num_train_tokens = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
                if args.include_tokens_per_second:
                    num_train_tokens = (
                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                    )
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
                if args.include_tokens_per_second:
                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )
        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
        model = self._wrap_model(self.model_wrapped)
        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False
        if delay_optimizer_creation:
            if use_accelerator_prepare:
                self._fsdp_qlora_plugin_updates()
                self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            # In this case we are in DDP + LOMO, which should be supported
            self.optimizer = self.accelerator.prepare(self.optimizer)
        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model
        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)
        # important: at this point:
        # self.model is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None
        # @ruimeng use steps_trained_in_current_epoch to skip batches for finding buggy data
        # steps_trained_in_current_epoch = 42
        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            self.compare_trainer_and_checkpoint_args(self.args, self.state)
            self._load_callback_state()
            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
        grad_norm: Optional[float] = None
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
        if args.eval_on_start:
            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            if hasattr(epoch_dataloader.dataset, "set_epoch"):
                epoch_dataloader.dataset.set_epoch(epoch)
            if args.past_index >= 0:
                self._past = None
            steps_in_epoch = (
                len(epoch_dataloader)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)
            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True
            step = -1
            epoch_iterator = iter(epoch_dataloader)
            # Number of batches in the final (possibly partial) accumulation window.
            remainder = num_examples % args.gradient_accumulation_steps
            num_items_in_batch = None
            if remainder == 0:
                remainder = args.gradient_accumulation_steps
            update_step = -1
            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
            for _ in range(total_updates):
                update_step += 1
                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                for i, inputs in enumerate(batch_samples):
                    step += 1
                    total_batched_samples += 1
                    # NOTE(review): `inputs` is a (qry, tgt) pair here; this counter is
                    # computed but currently unused (presumably kept for debugging).
                    dataset_stat = collections.Counter(inputs[0]['global_dataset_name'])
                    is_last_step_and_steps_less_than_grad_acc = (
                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
                    )
                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
                        total_batched_samples % args.gradient_accumulation_steps == 0
                    )
                    if not do_sync_step:
                        self.accelerator.gradient_state._set_sync_gradients(False)
                    else:
                        self.accelerator.gradient_state._set_sync_gradients(True)
                    if self.args.include_num_input_tokens_seen:
                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
                        if main_input_name not in inputs:
                            logger.warning("Tried to track the number of tokens seen, however the current model is not configured properly.")
                        else:
                            input_tokens = inputs[main_input_name].numel()
                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
                    if rng_to_sync:
                        self._load_rng_state(resume_from_checkpoint)
                        rng_to_sync = False
                    # Skip past any already trained steps if resuming training
                    if steps_trained_in_current_epoch > 0:
                        steps_trained_in_current_epoch -= 1
                        if steps_trained_progress_bar is not None:
                            steps_trained_progress_bar.update(1)
                        if steps_trained_in_current_epoch == 0:
                            self._load_rng_state(resume_from_checkpoint)
                        continue
                    elif steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.close()
                        steps_trained_progress_bar = None
                    if step % args.gradient_accumulation_steps == 0:
                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                    # Suppress DDP gradient sync on all but the last micro-batch.
                    context = (
                        functools.partial(self.accelerator.no_sync, model=model)
                        if i != len(batch_samples) - 1
                        else contextlib.nullcontext
                    )
                    with context():
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                    if (
                        args.logging_nan_inf_filter
                        and not is_torch_xla_available()
                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                    ):
                        # if loss is nan or inf simply add the average of previous logged losses
                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                    else:
                        if tr_loss.device != tr_loss_step.device:
                            raise ValueError(f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}")
                        tr_loss = tr_loss + tr_loss_step
                    self.current_flos += float(self.floating_point_ops(inputs))
                    if do_sync_step:
                        # Since we perform prefetching, we need to manually set sync_gradients to True
                        self.accelerator.gradient_state._set_sync_gradients(True)
                        # Gradient clipping
                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
                            if self.use_apex:
                                _grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), args.max_grad_norm)
                            else:
                                _grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                            if (is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED):
                                grad_norm = model.get_global_grad_norm()
                                # In some cases the grad norm may not return a float
                                if hasattr(grad_norm, "item"):
                                    grad_norm = grad_norm.item()
                            else:
                                grad_norm = _grad_norm
                        self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
                        # Custom hook: log the teacher branch's gradient norm before stepping.
                        # Best-effort only — failures must never abort training.
                        try:
                            self._maybe_log_teacher_grad(model)
                        except Exception as e:
                            logger.warning(f"teacher grad log failed (ignored): {e}")
                        self.optimizer.step()
                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                        if optimizer_was_run:
                            # Delay optimizer scheduling until metrics are generated
                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                self.lr_scheduler.step()
                        model.zero_grad()
                        self.state.global_step += 1
                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                        self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time())
                    else:
                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
                    if self.control.should_epoch_stop or self.control.should_training_stop:
                        # PyTorch/XLA relies on the data loader to insert the mark_step for
                        # each step. Since we are breaking the loop early, we need to manually
                        # insert the mark_step here.
                        if is_torch_xla_available():
                            xm.mark_step()
                        break
                # We also need to break out of the nested loop
                if self.control.should_epoch_stop or self.control.should_training_stop:
                    if is_torch_xla_available():
                        xm.mark_step()
                    break
            if step < 0:
                logger.warning("There seems not to be a single sample in your epoch_iterator, stopping training.")
                self.control.should_training_stop = True
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, time.time())
            if self.control.should_training_stop:
                break
        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")
        logger.info("\n\nTraining completed.\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_xla_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            self._load_best_model()
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
        train_loss = self._total_loss_scalar / effective_global_step
        metrics = speed_metrics(
            "train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps, num_tokens=num_train_tokens,
        )
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss
        self.is_in_train = False
        self._memory_tracker.stop_and_update_metrics(metrics)
        self.log(metrics)
        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint, ignore_errors=True)
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        # Wait for the checkpoint to be uploaded.
        self._finish_current_push()
        # After training we make sure to retrieve back the original forward pass method
        # for the embedding layer by removing the forward post hook.
        if self.neftune_noise_alpha is not None:
            self._deactivate_neftune(self.model)
        return TrainOutput(self.state.global_step, train_loss, metrics)
class GradCacheLateProcessTrainer(MMEBTrainer):
    """MMEB trainer combining GradCache chunked contrastive training with a
    multi-layer CRD (contrastive representation distillation) loss.

    Adapted from the gradcache repo.

    * DDP: loss/backward run through :class:`GradCache` with the distributed
      multi-layer CRD loss.
    * Single GPU: a hand-written two-phase chunked step
      (:meth:`_single_gpu_chunked_step`) avoids a full-batch forward OOM.
    """

    def __init__(self, *args, **kwargs):
        # Pop trainer-specific kwargs before handing the rest to the HF Trainer.
        self.max_length = kwargs.pop("max_length", 512)
        self.model_args = kwargs.pop("model_args", None)
        super(GradCacheLateProcessTrainer, self).__init__(*args, **kwargs)
        # NOTE: is_ddp / _dist_loss_scale_factor are set by MMEBTrainer.__init__.
        loss_fn_cls = DistributedMultiLayerCRDLoss if self.is_ddp else MultiLayerCRDLoss
        # crd_layers may arrive as a comma-separated string; parse into ints.
        crd_layers = getattr(self.args, "crd_layers", None)
        if isinstance(crd_layers, str) and len(crd_layers.strip()) > 0:
            crd_layers = [int(x.strip()) for x in crd_layers.split(",") if x.strip() != ""]
        else:
            crd_layers = None
        # Allow detach_teacher to be configured from args.
        detach_teacher = getattr(self.args, "crd_detach_teacher", True)
        # Optional switches read from the environment (could later be moved
        # into TrainingArguments and read from there instead).
        crd_side = os.getenv("CRD_SIDE", "both")  # "both" | "qry" | "tgt"
        queue_size = int(os.getenv("CRD_QUEUE_SIZE", "0") or 0)
        self.loss_fn = loss_fn_cls(
            temperature=self.model.temperature,
            weights=getattr(self.model, "supervise_weights", None),
            crd_weight=getattr(self.args, "crd_weight", 0.2),
            crd_temperature=getattr(self.args, "crd_temperature", 0.07),
            crd_layers=crd_layers,
            detach_teacher=detach_teacher,
            crd_side=crd_side,
            queue_size=queue_size
        )
        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
            loss_fn=self.loss_fn,
            split_input_fn=split_and_process_vlm_inputs,
            get_rep_fn=get_dense_rep,  # returns [B, K, D]
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )
        # Cache for grad diagnostics (parameters of the last LM block).
        self._last_block_params_cache = None

    def _infer_device(self, model):
        """Best-effort device lookup: model.device, first parameter, or args.device."""
        if hasattr(model, "device") and model.device is not None:
            return model.device
        try:
            return next(model.parameters()).device
        except StopIteration:
            return self.args.device

    def _batch_size(self, batch: dict) -> int:
        """Infer the batch size from the first batched tensor or list value."""
        for k, v in batch.items():
            if torch.is_tensor(v) and v.dim() > 0:
                return v.size(0)
            if isinstance(v, list):
                return len(v)
        raise ValueError("Cannot infer batch size from batch keys.")

    def _slice_batch(self, batch: dict, size: int, offset: int = 0):
        """Slice every batched entry to ``[offset : offset+size]``.

        Values too short to cover the slice (e.g. scalars or shared metadata)
        are passed through unchanged.
        """
        out = {}
        end = offset + size
        for k, v in batch.items():
            if torch.is_tensor(v) and v.dim() > 0 and v.size(0) >= end:
                out[k] = v[offset:end]
            elif isinstance(v, list) and len(v) >= end:
                out[k] = v[offset:end]
            else:
                out[k] = v
        return out

    def _maybe_log_teacher_grad(self, model):
        """
        Log the gradient norm of the last LM block, to confirm whether the
        teacher branch is being updated. Disable with env var LOG_TEACHER_GRAD=0.
        """
        if os.getenv("LOG_TEACHER_GRAD", "1") not in ("1", "true", "True"):
            return
        try:
            params_last = self._get_last_block_params()
            tgn = _grad_norm(params_last)  # already a float
            # 1) console
            print_master(f"[teacher_grad] step={self.state.global_step} norm={tgn:.6f}")
            # 2) HF metrics (lands in log_history / W&B alongside loss and grad_norm)
            self.log({"teacher_grad_norm": tgn})
        except Exception as e:
            logger.warning(f"teacher grad log failed: {e}")

    def _forward_reps_in_chunks(self, model, batch: dict, side: str, chunk_size: int, device):
        """Forward one side ("qry"/"tgt") in chunks; return detached reps [B, K, D]."""
        B = self._batch_size(batch)
        outs = []
        use_bf16 = getattr(self.args, "bf16", False)
        dev_type = "cuda" if "cuda" in str(device) else "cpu"
        # Stay in train mode and don't use no_grad; instead detach right after
        # each chunked forward so no graph is retained.
        for s in range(0, B, max(1, chunk_size or B)):
            bs = min(chunk_size or B, B - s)
            sub = self._slice_batch(batch, size=bs, offset=s)
            with torch.autocast(device_type=dev_type, dtype=torch.bfloat16, enabled=use_bf16):
                if side == "qry":
                    o = model(qry=sub)
                    reps = o["qry_reps"]  # [bs, K, D]
                else:
                    o = model(tgt=sub)
                    reps = o["tgt_reps"]  # [bs, K, D]
            outs.append(reps.detach())  # key point: detach immediately
            del o, reps
            # If memory is very tight, occasionally empty the cache
            # (calling it every chunk is slow):
            # if s % (4 * (chunk_size or B)) == 0: torch.cuda.empty_cache()
        return torch.cat(outs, dim=0)  # [B, K, D]

    def _norm_weights(self, weights_list, K, device):
        """Normalize per-layer weights to sum to 1 (uniform when None)."""
        if weights_list is None:
            return torch.ones(K, device=device) / K
        w = torch.tensor(list(weights_list), dtype=torch.float32, device=device)
        w = torch.clamp(w, min=0)
        s = float(w.sum())
        return w / (s if s > 0 else 1.0)

    def _crd_indices(self, K: int):
        """Resolve args.crd_layers (0-based indices relative to supervise_layers).

        ``None``/empty -> all non-final layers; negative indices wrap; the last
        layer ``K-1`` (the teacher) is always excluded.
        """
        cl = getattr(self.args, "crd_layers", None)
        if cl is None or (isinstance(cl, str) and cl.strip() == ""):
            return list(range(0, max(0, K - 1)))
        if isinstance(cl, str):
            idxs = [int(x.strip()) for x in cl.split(",") if x.strip() != ""]
        else:
            idxs = list(cl)
        # Filter out the last layer K-1 (teacher).
        out = []
        for i in idxs:
            if i < 0: i = K + i
            if 0 <= i < K - 1:
                out.append(i)
        return sorted(set(out))

    def _single_gpu_chunked_step(self, model, queries: dict, targets: dict, device):
        """
        Hand-written chunked step with two gradient phases:
        A) chunk over queries, backward into q (p / teacher treated as constants)
        B) chunk over targets, backward into p (q / teacher treated as constants)
        Returns a gradient-free scalar loss for logging.
        """
        # First compute the "constant bank" without gradients: multi-layer
        # representations for the full batch (small memory footprint).
        q_all = self._forward_reps_in_chunks(model, queries, side="qry", chunk_size=self.args.gc_q_chunk_size, device=device)  # [B,K,D]
        p_all = self._forward_reps_in_chunks(model, targets, side="tgt", chunk_size=self.args.gc_p_chunk_size, device=device)  # [B,K,D]
        B, K, D = q_all.shape
        w_ret = self._norm_weights(getattr(self.model, "supervise_weights", None), K, device)
        crd_idxs = self._crd_indices(K)
        w_crd = self._norm_weights([w_ret[i].item() for i in crd_idxs] if len(crd_idxs) > 0 else [1.0], max(1, len(crd_idxs)), device)
        temp = float(getattr(self.model, "temperature", 0.02))
        beta = float(getattr(self.loss_fn, "runtime_beta", getattr(self.loss_fn, "crd_weight", getattr(self.args, "crd_weight", 0.2))))
        crd_temp = float(getattr(self.loss_fn, "crd_temperature", getattr(self.args, "crd_temperature", 0.07)))
        detach_teacher = bool(getattr(self.args, "crd_detach_teacher", True))
        use_bf16 = getattr(self.args, "bf16", False)
        dev_type = "cuda" if "cuda" in str(device) else "cpu"
        # Teacher (last-layer) constants.
        tq_all = q_all[:, K - 1, :].detach() if detach_teacher else q_all[:, K - 1, :]
        tp_all = p_all[:, K - 1, :].detach() if detach_teacher else p_all[:, K - 1, :]
        total_loss_scalar = 0.0
        # Phase-A: chunk over the query side, update q.
        q_chunk = max(1, self.args.gc_q_chunk_size or B)
        for s in range(0, B, q_chunk):
            bs = min(q_chunk, B - s)
            sub_q = self._slice_batch(queries, size=bs, offset=s)
            labels = torch.arange(s, s + bs, device=device, dtype=torch.long)
            with torch.autocast(device_type=dev_type, dtype=torch.bfloat16, enabled=use_bf16):
                o_q = model(qry=sub_q)  # [bs,K,D], with grad
                qk = o_q["qry_reps"]
                # L_ret (q_k vs p_all_k)
                L_ret_q = 0.0
                for k_idx in range(K):
                    logits = torch.matmul(qk[:, k_idx, :], p_all[:, k_idx, :].transpose(0, 1)) / temp  # [bs,B]
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_ret_q = L_ret_q + w_ret[k_idx] * Lk
                # L_crd_q (q_k vs tq_all)
                L_crd_q = 0.0
                for j, k_idx in enumerate(crd_idxs):
                    logits = torch.matmul(qk[:, k_idx, :], tq_all.transpose(0, 1)) / crd_temp
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_crd_q = L_crd_q + w_crd[j] * Lk
                L_q = L_ret_q + beta * L_crd_q
            # Backward updates the q branch only.
            L_q.backward()
            total_loss_scalar += float(L_q.detach()) * (bs / B)
            del qk, o_q, L_q, L_ret_q, L_crd_q
        # Phase-B: chunk over the candidate side, update p.
        p_chunk = max(1, self.args.gc_p_chunk_size or B)
        for s in range(0, B, p_chunk):
            bs = min(p_chunk, B - s)
            sub_p = self._slice_batch(targets, size=bs, offset=s)
            labels = torch.arange(s, s + bs, device=device, dtype=torch.long)
            with torch.autocast(device_type=dev_type, dtype=torch.bfloat16, enabled=use_bf16):
                o_p = model(tgt=sub_p)  # [bs,K,D], with grad
                pk = o_p["tgt_reps"]
                # L_ret (p_k vs q_all_k)
                L_ret_p = 0.0
                for k_idx in range(K):
                    logits = torch.matmul(pk[:, k_idx, :], q_all[:, k_idx, :].transpose(0, 1)) / temp
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_ret_p = L_ret_p + w_ret[k_idx] * Lk
                # L_crd_p (p_k vs tp_all)
                L_crd_p = 0.0
                for j, k_idx in enumerate(crd_idxs):
                    logits = torch.matmul(pk[:, k_idx, :], tp_all.transpose(0, 1)) / crd_temp
                    Lk = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
                    L_crd_p = L_crd_p + w_crd[j] * Lk
                L_p = L_ret_p + beta * L_crd_p
            L_p.backward()
            total_loss_scalar += float(L_p.detach()) * (bs / B)
            del pk, o_p, L_p, L_ret_p, L_crd_p
        # Return a constant tensor for logging (prevents HF from calling backward again).
        return torch.tensor(total_loss_scalar, device=device, dtype=torch.float32, requires_grad=False)

    # Dynamic beta warmup and grad diagnostics.
    def _apply_crd_warmup(self):
        """Linearly ramp the CRD weight over ``crd_warmup_steps`` update steps."""
        target_beta = getattr(self.args, "crd_weight", 0.2)
        warm_steps = getattr(self.args, "crd_warmup_steps", 0)
        # self.state.global_step still holds the previous value before
        # accumulation completes; that is close enough for warmup.
        step = max(0, getattr(self.state, "global_step", 0))
        if warm_steps and step < warm_steps:
            beta = target_beta * float(step + 1) / float(warm_steps)
        else:
            beta = target_beta
        # Write both runtime_beta and crd_weight to stay compatible with
        # different loss implementations.
        setattr(self.loss_fn, "runtime_beta", beta)
        if hasattr(self.loss_fn, "crd_weight"):
            self.loss_fn.crd_weight = beta

    def _get_last_block_params(self):
        """Return (and cache) the parameters of the last LM block for grad logging."""
        if self._last_block_params_cache is not None:
            return self._last_block_params_cache
        layers = _locate_lm_layers_modulelist(self.model.encoder)
        if layers is None or len(layers) == 0:
            logger.warning("Could not locate LM layers for grad debug; will use encoder parameters as fallback.")
            params = list(self.model.encoder.parameters())
        else:
            params = list(layers[-1].parameters())
        self._last_block_params_cache = params
        return params

    def _debug_teacher_grad(self, queries, targets, model):
        """Occasionally measure how much CRD adds to the last block's gradient.

        Runs two GradCache passes (CRD off, then on) and prints both gradient
        norms. ``model`` is accepted for signature compatibility but unused in
        the DDP path.
        """
        # Disabled on single GPU: two extra full-batch passes would OOM.
        if not self.is_ddp:
            return
        dbg_every = int(getattr(self.args, "crd_debug_every", 0) or 0)
        if dbg_every <= 0 or (self.state.global_step % dbg_every) != 0:
            return
        last_params = self._get_last_block_params()
        beta_saved = getattr(self.loss_fn, "crd_weight", 0.0)
        # DDP: two passes through GradCache.
        # (The original code also carried a single-GPU branch here, but it was
        # unreachable because of the early return above — removed.)
        self.model.zero_grad(set_to_none=True)
        self.loss_fn.crd_weight = 0.0
        _ = self.gc(queries, targets, no_sync_except_last=True)
        gn_ret = _grad_norm(last_params)
        self.model.zero_grad(set_to_none=True)
        self.loss_fn.crd_weight = beta_saved
        _ = self.gc(queries, targets, no_sync_except_last=True)
        gn_all = _grad_norm(last_params)
        print_master(f"[CRD-Debug] step={self.state.global_step} grad_norm(last-block): RET={gn_ret:.6f}, RET+CRD={gn_all:.6f}, delta={max(0.0, gn_all-gn_ret):.6f}")
        self.model.zero_grad(set_to_none=True)

    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
        """One training step; backward happens inside (GradCache or chunked loop)."""
        model.train()
        queries, targets = inputs
        device = self._infer_device(model)
        queries = batch_to_device(queries, device)
        targets = batch_to_device(targets, device)
        queries, targets = {'qry': queries}, {'tgt': targets}
        # Dynamic CRD warmup.
        self._apply_crd_warmup()
        # Optional gradient diagnostics (best-effort).
        try:
            self._debug_teacher_grad(queries, targets, model)
        except Exception as e:
            logger.warning(f"CRD grad debug failed (ignored): {e}")
        if self.is_ddp:
            # Multi-GPU: use GradCache (the model must already be DDP-wrapped;
            # HF handles that in _wrap_model + accelerator.prepare).
            self.gc.models = [model, model]
            loss = self.gc(queries, targets, no_sync_except_last=True)
        else:
            # Single GPU: hand-written two-phase chunking to avoid full-batch OOM.
            loss = self._single_gpu_chunked_step(model, queries["qry"], targets["tgt"], device)
        return loss / self._dist_loss_scale_factor

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save encoder weights (prefix stripped), tokenizer, args, and config.json."""
        print_master(f"Saving model to {output_dir}")
        os.makedirs(output_dir, exist_ok=True)
        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.encoder.config.to_json_file(os.path.join(output_dir, 'config.json'))