import logging
import os
from typing import Optional, List, Dict, Union, Tuple, Any, NamedTuple, Mapping
import time
import math
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import hydra
from hydra.utils import instantiate
from datasets import DatasetDict, load_dataset, IterableDatasetDict
from omegaconf import DictConfig, OmegaConf
from .data.transforms import SamCaptionerDataTransform
from .data.collator import SamCaptionerDataCollator
from .arguments import Arguments, global_setup, SAMCaptionerModelArguments, SCAModelArguments
from .models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
from transformers.trainer_utils import get_last_checkpoint
from transformers import set_seed, Seq2SeqTrainer, GenerationConfig
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.trainer import (
speed_metrics,
deepspeed_init,
is_torch_tpu_available,
has_length,
find_batch_size,
nested_concat,
nested_numpify,
IterableDatasetShard,
EvalLoopOutput,
denumpify_detensorize,
is_sagemaker_mp_enabled,
get_parameter_names,
ALL_LAYERNORM_LAYERS,
Trainer,
EvalPrediction,
TrainerState,
deepspeed_load_checkpoint,
get_model_param_count,
TRAINER_STATE_NAME,
skip_first_batches,
sys,
HPSearchBackend,
hp_params,
RandomSampler,
# is_torch_less_than_1_11,
ParallelMode,
dist,
shutil,
TrainOutput,
PREFIX_CHECKPOINT_DIR,
SCHEDULER_NAME,
SCALER_NAME,
reissue_pt_warnings,
)
from functools import wraps
from collections import defaultdict
try:
from transformers.trainer import xm, met, pl
except ImportError:
pass
try:
from transformers.trainer import amp
except ImportError:
pass
try:
from transformers.trainer import smp_forward_backward
except ImportError:
pass
try:
from transformers.trainer import smp
except ImportError:
pass
try:
from transformers.trainer import OSS
except ImportError:
pass
# NOTE: bump transformers from 4.30.2 to 4.36.2
try:
from transformers.trainer import (
ShardedDDPOption,
nested_truncate,
tqdm,
DistributedSampler,
)
except ImportError:
pass
try:
from transformers.trainer_callback import TrainerCallback
except ImportError:
pass
try:
from transformers.trainer_seq2seq import is_deepspeed_zero3_enabled
except ImportError:
pass
# NOTE: Fix the resume of DS optimizer + HF scheduler. https://github.com/huggingface/transformers/pull/25863/files
import importlib.util
import warnings
def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None
if is_deepspeed_available():
    from accelerate.utils import DeepSpeedSchedulerWrapper
logger = logging.getLogger(__name__)
SAVING_FINISHED_FLAG = "saving_finished.flag"
class InferenceLoopOutput(NamedTuple):
logits: Optional[Dict]
label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
metadata: Optional[Dict]
batch_num_regions_shape: Optional[np.ndarray]
metrics: Optional[Dict[str, float]]
num_samples: Optional[int]
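    # Field shapes as consumed by `_decode_inference_outputs` (after truncation/flattening in `inference_loop`):
    #   logits:                  Dict[str, (batch_num_regions, num_heads, ...)], incl. "generated_tokens" and "iou_scores"
    #   label_ids:               (batch_num_regions, token_max_length)
    #   metadata:                Dict[str, (batch_num_regions, ...)]
    #   batch_num_regions_shape: (num_batches, 2), one (batch_size, num_regions) pair per batch step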
class FunctionTimers:
    def __init__(self):
        self.timers = defaultdict(list)
    def get_timer(self, f):
        @wraps(f)
        def _decorate(*args, **kwargs):
            start = time.perf_counter()
            ret = f(*args, **kwargs)
            end = time.perf_counter()
            # `self.timers` is a defaultdict(list), so missing keys are created on first access.
            self.timers[f.__name__].append((end - start) * 1000)
return ret
return _decorate
def clear(self):
for k in self.timers:
self.timers[k] = []
def report(self):
return {f"{k}_in_ms": np.mean(v) for k, v in self.timers.items()}
class SCASeq2SeqTrainer(Seq2SeqTrainer):
# NOTE(xiaoke): Modified. Based on transformers v4.30.2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.use_legacy_prediction_loop is True:
raise ValueError(
f"Not support legacy `prediction loop` for {self.__class__.__name__}! "
"As I do not override it for region caption task."
)
        # NOTE: change length_penalty, and apply it to both the model and the language model.
        # NOTE: with 10 samples, length_penalty leads to a loss of CIDEr.
# generation_config = self.model.generation_config.to_dict()
# generation_config["length_penalty"] = 0.0
# generation_config.pop("_from_model_config")
# self.model.generation_config = GenerationConfig(**generation_config)
# self.model.language_model.generation_config = self.model.generation_config
# logger.info(f"generation_config: {self.model.generation_config}")
self.function_timers = FunctionTimers()
self._prepare_inputs = self.function_timers.get_timer(self._prepare_inputs)
self.compute_loss = self.function_timers.get_timer(self.compute_loss)
self._do_backward = self.function_timers.get_timer(self._do_backward)
self.training_step = self.function_timers.get_timer(self.training_step)
# NOTE: define the compute_metric_func
# NOTE: compute_metrics = None triggers the default `prediction_loss_only=True`
        # NOTE: compute_metrics is normally a function, but we define it inside the trainer, so a bool is used here as a flag.
        # NOTE: only the world process zero computes the metrics; otherwise it may lead to download errors.
if self.compute_metrics is True and self.is_world_process_zero():
import evaluate
self.compute_metrics_func = evaluate.load("meteor")
else:
self.compute_metrics_func = None
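        # Illustrative: when enabled, `evaluate()` later calls
        #   self.compute_metrics_func.compute(predictions=pred_captions, references=gt_captions)
        # on the world-process-zero rank only; all other ranks keep compute_metrics_func = None.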
# NOTE: bump transformers from 4.30.2 to 4.36.2
if not hasattr(self, "is_fsdp_xla_enabled"):
self.is_fsdp_xla_enabled = False
# Copied from `Trainer`
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log:
if is_torch_tpu_available():
xm.mark_step()
logs: Dict[str, float] = {}
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
# reset tr_loss to zero
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
param_group_key = ["full"] + list(self.args.custom_param_lrs.keys())
param_group_values = self._get_learning_rate()
            # NOTE: only keep the even indices, because each group is split into two sub-groups, i.e., (lr_w_wd, lr_wo_wd).
param_group_values = [v for idx, v in enumerate(param_group_values) if idx % 2 == 0]
for k, v in zip(param_group_key, param_group_values):
logs[f"learning_rate/{k}"] = v
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
logs.update(self.function_timers.report())
self.function_timers.clear()
self.log(logs)
metrics = None
if self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
metrics = {}
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
# NOTE: add metric loss for best ckpt saving.
metrics_loss = {k: v for k, v in metrics.items() if k.startswith("eval_") and k.endswith("_loss")}
metrics["eval_loss"] = sum(metrics_loss.values()) / len(metrics_loss)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
        # NOTE: handle empty batches during training caused by LSJ augmentation.
        # `inputs` is set to None when the batch is empty, so this step is skipped.
if inputs is None:
logger.error("The inputs shouldn't be None in training! Thus we skip this batch of data.")
return torch.tensor(torch.nan)
model.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
self._do_backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
def _do_backward(self, loss):
# NOTE: bump transformers from 4.30.2 to 4.36.2
# sharded_ddp for fairseq was deprecated.
# if self.do_grad_scaling:
# self.scaler.scale(loss).backward()
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)
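        # NOTE: backward is factored into `_do_backward` so that `FunctionTimers` (see __init__) can time
        # the backward pass separately from `compute_loss` and the full `training_step`.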
# NOTE: START OF INFERENCE CODE
# Call order:
# 1. inference
# 2. inference_loop
# 3. inference_step
# use generate and save the outputs
def inference(
self,
inference_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "inference",
**gen_kwargs,
):
# NOTE(xiaoke): Modified. Check the tokenizer and the unk_token_id first
        # We do not want to hit the error only after all the predictions have been generated.
if self.tokenizer is None:
raise ValueError("You need to specify a tokenizer in Trainer!")
if self.tokenizer.unk_token_id is None:
raise ValueError(f"Check the tokenizer! unk_token_id is None! {self.tokenizer}")
gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
self._gen_kwargs = gen_kwargs
# memory metrics - must set up as early as possible
self._memory_tracker.start()
eval_dataloader = self.get_eval_dataloader(inference_dataset)
start_time = time.time()
output = self.inference_loop(
eval_dataloader,
description="Inference",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
prediction_loss_only=False,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
skip_predcition_loss_after_generate=True,
)
(
batch_num_regions,
gt_captions,
pred_captions,
metadata_with_num_regions_length,
logits_with_num_regions_length,
) = self._decode_inference_outputs(output)
self._save_inference_json(
metric_key_prefix,
batch_num_regions,
gt_captions,
pred_captions,
metadata_with_num_regions_length,
logits_with_num_regions_length,
)
def _decode_inference_outputs(self, output):
# NOTE(xiaoke): Modified. Dispatch the logits.
# - `generated_tokens`: (batch_size, num_regions, num_heads, token_max_length)
# - `iou_scores`: (batch_size, num_regions, num_heads)
# Remove metrics update
logits = output.logits # Dict[str, (batch_num_regions, num_heads, ...)]
label_ids = output.label_ids # (batch_num_regions, token_max_length)
metadata = output.metadata # Dict[str, (batch_num_regions, ...)]
# NOTE: generated_tokens is removed from logits, we only have `iou_scores` left
generate_ids = logits.pop("generated_tokens") # (batch_num_regions, num_heads, token_max_length)
# NOTE(xiaoke): since we pad the labels with -100, we need to cast them back to unk_token_id
# we believe there is always a tokenizer.unk_token_id in the tokenizer
# Avoid error OverflowError: out of range integral type conversion attempted
# https://github.com/huggingface/transformers/issues/22634#issuecomment-1500429811
generate_ids = self._change_loss_token_to_unk_token(generate_ids, unk_token_id=self.tokenizer.unk_token_id)
label_ids = self._change_loss_token_to_unk_token(label_ids, unk_token_id=self.tokenizer.unk_token_id)
# NOTE(xiaoke): process generate_ids
batch_num_regions, num_heads, token_max_length = generate_ids.shape
generate_ids = generate_ids.reshape(batch_num_regions * num_heads, token_max_length)
pred_captions = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
# NOTE(xiaoke): process label_ids
if batch_num_regions != label_ids.shape[0]:
raise ValueError(f"batch_num_regions {batch_num_regions} != label_ids.shape[0] {label_ids.shape[0]}")
gt_captions = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
pred_captions = np.array(pred_captions, dtype=object).reshape(batch_num_regions, num_heads).tolist()
        # NOTE(xiaoke): we assume there is only ONE gt caption for each region
gt_captions = np.array(gt_captions, dtype=object).reshape(batch_num_regions, 1).tolist()
metadata_with_num_regions_length = {}
for k, v in metadata.items():
if len(v) != batch_num_regions:
logger.warning(
f"metadata {k} has length {len(v)}, but batch_num_regions is {batch_num_regions}, so skip it"
)
else:
metadata_with_num_regions_length[k] = v.tolist() # json does not support numpy type object
logits_with_num_regions_length = {}
for k, v in logits.items():
if len(v) != batch_num_regions:
logger.warning(f"logits {k} has length {len(v)}, but batch_num_regions is {batch_num_regions}")
else:
logits_with_num_regions_length[k] = v.tolist() # json does not support numpy type object
return (
batch_num_regions,
gt_captions,
pred_captions,
metadata_with_num_regions_length,
logits_with_num_regions_length,
)
def _save_inference_json(
self,
metric_key_prefix,
batch_num_regions,
gt_captions,
pred_captions,
metadata_with_num_regions_length,
logits_with_num_regions_length,
):
# NOTE(xiaoke): the output json follows the format of https://github.com/CannyLab/vdtk
output_json = []
for idx in range(batch_num_regions):
output_json.append(
{
"_id": idx,
"split": "inference",
"references": gt_captions[idx],
"candidates": pred_captions[idx],
"metadata": {k: v[idx] for k, v in metadata_with_num_regions_length.items()},
"logits": {k: v[idx] for k, v in logits_with_num_regions_length.items()},
}
)
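        # Each record roughly looks like (illustrative, hypothetical values; `num_heads` candidates per region, one reference):
        #   {"_id": 0, "split": "inference",
        #    "references": ["a dog lying on the grass"],
        #    "candidates": ["a dog", "a brown dog", "a dog on the grass"],
        #    "metadata": {...}, "logits": {"iou_scores": [...]}}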
import json
infer_json_dir = os.path.join(self.args.output_dir, "infer")
os.makedirs(infer_json_dir, exist_ok=True)
infer_json_file = os.path.join(infer_json_dir, f"infer-{metric_key_prefix}.json")
# TODO: only the very first process will write the file
if self.is_world_process_zero():
with open(infer_json_file, "w") as f:
json.dump(output_json, f, indent=4)
@staticmethod
def _change_loss_token_to_unk_token(tokens, unk_token_id, padding_index=-100):
tokens[tokens == padding_index] = unk_token_id
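        # Illustrative: with unk_token_id=0, tensor([[5, 7, -100]]) becomes tensor([[5, 7, 0]]),
        # so `batch_decode` no longer hits the OverflowError mentioned in `_decode_inference_outputs`.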
return tokens
def inference_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
skip_predcition_loss_after_generate: Optional[bool] = None,
) -> InferenceLoopOutput:
"""
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
args = self.args
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
# NOTE: otherwise we will get the OOM due to fp32.
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
if len(self.accelerator._models) == 0 and model is self.model:
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = self.args.eval_batch_size
logger.info(f"***** Running {description} *****")
if has_length(dataloader):
logger.info(f" Num examples = {self.num_examples(dataloader)}")
else:
logger.info(" Num examples: Unknown")
logger.info(f" Batch size = {batch_size}")
logger.info(f" Num examples for process ({self.args.process_index}) = {len(dataloader) * batch_size}")
model.eval()
self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
eval_dataset = getattr(dataloader, "dataset", None)
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
if args.past_index >= 0:
self._past = None
# Initialize containers
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
losses_host = None
preds_host = None
labels_host = None
inputs_host = None
# NOTE(xiaoke): Modified. We need to save the inputs for ids
metadata_host = None
batch_num_regions_shape_host = None
# losses/preds/labels on CPU (final containers)
all_losses = None
all_preds = None
all_labels = None
all_inputs = None
# NOTE(xiaoke): Modified. We need to save the inputs for ids
all_metadata = None
all_batch_num_regions_shape = None
# Will be useful when we have an iterable dataset so don't know its length.
observed_num_examples = 0
# Main inference loop
for step, inputs in enumerate(dataloader):
# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size
# NOTE(xiaoke): Modified. We need to save the inputs for ids
metadata = None
for k, v in inputs.items():
if k.startswith("metadata_") and isinstance(v, torch.Tensor):
if metadata is None:
metadata = {}
# metadata[k] = v.flatten(0, 1) if len(v.shape) > 1 else v
metadata[k] = v
metadata = self._prepare_input(metadata)
# NOTE: skip_predcition_loss_after_generate=True
# Prediction step
loss, logits, batch_num_regions_shape, labels = self.inference_step(
model,
inputs,
prediction_loss_only,
ignore_keys=ignore_keys,
skip_predcition_loss_after_generate=skip_predcition_loss_after_generate,
)
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
            # NOTE(xiaoke): our outputs have four dims `(PADDED_batch_size, PADDED_num_regions, token_length, ...)`,
            # and we squash the first two dims. The `gather` function assumes the tensors have the same shape,
            # so we `unsqueeze` one dim ahead and recover the results with the (batch_size, num_regions) pair.
if is_torch_tpu_available():
xm.mark_step()
# Update containers on host
if loss is not None:
# # NOTE(xiaoke): Modified. PRETEND its shape is (batch_size, token_length, ...) which the trainer expects.
# # NOTE(xiaoke): we do not add the `num_heads` dim, since they are the same across the batch.
# # Thus taking their mean is the same as with multiple heads.
# assert len(batch_num_regions_shape) == 1
# losses = loss.repeat(batch_num_regions_shape[0].tolist())
# losses = self._pad_across_processes(losses)
# losses = self._nested_gather(losses)
# # NOTE(xiaoke): Modified. We need to pad the `token_length` dim, since they may be different across batches
# losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
# NOTE: compute batch-wise average loss
losses = self._nested_gather(loss.repeat(batch_size))
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if labels is not None:
# NOTE: bump transformers from 4.30.2 to 4.36.2
# labels = self._pad_across_processes(labels)
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
if inputs_decode is not None:
# NOTE: bump transformers from 4.30.2 to 4.36.2
# inputs_decode = self._pad_across_processes(inputs_decode)
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
inputs_decode = self._nested_gather(inputs_decode)
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
if logits is not None:
# NOTE: bump transformers from 4.30.2 to 4.36.2
# logits = self._pad_across_processes(logits)
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels)
logits = self._nested_gather(logits)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None:
labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
# NOTE(xiaoke): Modified. We need to save the inputs for ids
if metadata is not None:
# NOTE: bump transformers from 4.30.2 to 4.36.2
# metadata = self._pad_across_processes(metadata)
metadata = self.accelerator.pad_across_processes(metadata, dim=1, pad_index=-100)
metadata = self._nested_gather(metadata)
metadata_host = (
metadata if metadata_host is None else nested_concat(metadata_host, metadata, padding_index=-100)
)
# NOTE(xiaoke): Modified. We need to save the batch-num_regions shape to recover the results
if batch_num_regions_shape is not None:
batch_num_regions_shape = self._nested_gather(batch_num_regions_shape)
batch_num_regions_shape_host = (
batch_num_regions_shape
if batch_num_regions_shape_host is None
else torch.concat((batch_num_regions_shape_host, batch_num_regions_shape), dim=0)
)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode
if all_inputs is None
else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = (
labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
)
# NOTE(xiaoke): Modified. We need to save the inputs for ids
if metadata_host is not None:
metadata = nested_numpify(metadata_host)
all_metadata = (
metadata if all_metadata is None else nested_concat(all_metadata, metadata, padding_index=-100)
)
# NOTE(xiaoke): Modified. We need to save the batch-num_regions shape to recover the results
if batch_num_regions_shape_host is not None:
batch_num_regions_shape = nested_numpify(batch_num_regions_shape_host)
all_batch_num_regions_shape = (
batch_num_regions_shape
if all_batch_num_regions_shape is None
                        else nested_concat(all_batch_num_regions_shape, batch_num_regions_shape, padding_index=-100)
)
# Set back to None to begin a new accumulation
losses_host, preds_host, inputs_host, labels_host = None, None, None, None
# NOTE(xiaoke): Modified. We need to save the inputs for ids
metadata_host = None
batch_num_regions_shape_host = None
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
# Gather all remaining tensors and put them back on the CPU
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
# NOTE(xiaoke): Modified. We need to save the inputs for ids
if metadata_host is not None:
metadata = nested_numpify(metadata_host)
all_metadata = (
metadata if all_metadata is None else nested_concat(all_metadata, metadata, padding_index=-100)
)
# NOTE(xiaoke): Modified. We need to save the batch-num_regions shape to recover the results
if batch_num_regions_shape_host is not None:
batch_num_regions_shape = nested_numpify(batch_num_regions_shape_host)
all_batch_num_regions_shape = (
batch_num_regions_shape
if all_batch_num_regions_shape is None
else nested_concat(all_batch_num_regions_shape, batch_num_regions_shape, padding_index=-100)
)
# Number of samples
if has_length(eval_dataset):
num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute.
elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
num_samples = eval_dataset.num_examples
else:
if has_length(dataloader):
# NOTE(xiaoke): Modified. Log wrong number of samples
logger.warning(
f"Your dataset doesn't implement `__len__`. Use dataloader instead, Inference will not check all elements."
)
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
# NOTE(xiaoke): Modified. Log wrong number of samples
logger.warning(
f"Your dataset doesn't implement `__len__`. Use one process observed data. Inference will not check all elements."
)
num_samples = observed_num_examples
if num_samples == 0 and observed_num_examples > 0:
num_samples = observed_num_examples
# NOTE(xiaoke): Modified. In region caption task, the prediction has two batch dimensions,
# the first one is the batch size, the second one is the number of regions.
# we need to truncate the results based on both dims.
# - all_batch_num_regions_shape: (batch_steps, 2), one batch_step has a batch of data
# - all_losses: (PADDED_batch_size, PADDED_num_regions)
# - all_preds: (PADDED_batch_size, PADDED_num_regions, num_heads, PADDED_token_length), a.k.a., all_generate_ids
# - all_labels: (PADDED_batch_size, PADDED_num_regions, PADDED_token_length)
# - all_metadata (PADDED_batch_size, PADDED_num_regions, ...)
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# samplers has been rounded to a multiple of batch_size, so we truncate.
# NOTE(xiaoke): Modified. We truncate the results based both the batch size and the number of regions
# (batch_size, PADDED_num_regions)
if all_losses is not None:
# all_losses = nested_two_dims_truncate_and_flatten(all_losses, all_batch_num_regions_shape, num_samples)
all_losses = all_losses[:num_samples]
if all_preds is not None:
all_preds = nested_two_dims_truncate_and_flatten(all_preds, all_batch_num_regions_shape, num_samples)
if all_labels is not None:
all_labels = nested_two_dims_truncate_and_flatten(all_labels, all_batch_num_regions_shape, num_samples)
if all_inputs is not None:
all_inputs = nested_two_dims_truncate_and_flatten(all_inputs, all_batch_num_regions_shape, num_samples)
# NOTE(xiaoke): Modified. We need to save the inputs for ids
if all_metadata is not None:
all_metadata = nested_two_dims_truncate_and_flatten(all_metadata, all_batch_num_regions_shape, num_samples)
# Metrics!
metrics = {}
if all_losses is not None:
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "jit_compilation_time"):
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
# NOTE(xiaoke): Modified. Skip the metric computation
return InferenceLoopOutput(
logits=all_preds,
label_ids=all_labels,
metadata=all_metadata,
batch_num_regions_shape=all_batch_num_regions_shape,
metrics=metrics,
num_samples=num_samples,
)
def inference_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
skip_predcition_loss_after_generate: Optional[bool] = None,
    ) -> Tuple[Optional[float], Optional[Dict[str, torch.Tensor]], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
Return:
            Tuple[Optional[float], Optional[Dict[str, torch.Tensor]], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple
            with the loss, the logits dict, the (batch_size, num_regions) shape tensor, and the labels (each being optional).
"""
# NOTE: `prediction_loss_only` is always False in inference
if not self.args.predict_with_generate or prediction_loss_only:
# TODO(xiaoke): replace `super().inference_step` with batch-region `region_caption_prediction_step`
# we need `batch_num_regions_shape` to truncate the results
# remember to add a todo in loss computation, as the mask loss is not added!
loss, logits, labels = super(Seq2SeqTrainer, self).prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
batch_num_regions_shape = torch.tensor(inputs["input_ids"].shape[:2]).unsqueeze(0).to(device=loss.device)
return loss, logits, batch_num_regions_shape, labels
# return super().prediction_step(
# model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
# )
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# gen_kwargs > model.generation_config > default GenerationConfig()
gen_kwargs = self._gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.model.config.max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
)
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
gen_kwargs["synced_gpus"] = (
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
)
# If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
# (otherwise, it would continue generating from the padded `decoder_input_ids`)
if (
"labels" in inputs
and "decoder_input_ids" in inputs
and inputs["labels"].shape == inputs["decoder_input_ids"].shape
):
inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
# TODO(xiaoke): the generate should return both the generated tokens and the masks
# We need to change both this `*_step` and the `*_loop`
        # FIXME(xiaoke): the generate call is not wrapped by self.compute_loss_context_manager(),
        # which could cause problems in sharded distributed inference. The `prediction_step` used in `*_loop` is affected too.
# NOTE(xiaoke): Modified. Adapt for region caption task and chunk inference to reduce memory consumption.
inputs = self._prepare_input_dtype(inputs, self.model.dtype) # NOTE: for fp16 inference
generated_outputs = self._generate_in_inference_step(inputs, gen_kwargs)
generated_tokens = generated_outputs.sequences
iou_scores = generated_outputs.iou_scores
pred_masks = generated_outputs.pred_masks
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# TODO: remove this hack when the legacy code that initializes generation_config from a model config is
# removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
if self.model.generation_config._from_model_config:
self.model.generation_config._from_model_config = False
# Retrieves GenerationConfig from model.generation_config
gen_config = self.model.generation_config
# in case the batch is shorter than max length, the output should be padded
# NOTE(xiaoke): Modified. For region caption task, the shape of the generated tokens
# is (batch_size, num_regions, num_heads, token_max_length), we use the modified `_pad_tensors_to_max_len`
if generated_tokens.shape[-1] < gen_config.max_length:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
# NOTE: Compute loss after generate
with torch.no_grad():
if has_labels and skip_predcition_loss_after_generate is not True:
with self.compute_loss_context_manager():
outputs = model(**inputs)
if self.label_smoother is not None:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
else:
loss = None
# NOTE(xiaoke): Modified. We record the batch size and num_regions
# for truncation of distributed evaluation.
batch_num_regions_shape = torch.tensor(generated_tokens.shape[:2]).unsqueeze(0).to(generated_tokens)
if self.args.prediction_loss_only:
return loss, None, batch_num_regions_shape, None
if has_labels:
labels = inputs["labels"]
# NOTE(xiaoke): Modified. For region caption task, the shape of
# the labels is (batch_size, num_regions, token_max_length)
if labels.shape[-1] < gen_config.max_length:
labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
else:
labels = None
# TODO(xiaoke): compute and return the metrics for vision tasks here.
# the dense logits have shape of (batch_size, num_regions, num_heads, ...)
# which is very memory consuming. e.g., 230k*3*256*256=45GB
logits = dict(generated_tokens=generated_tokens, iou_scores=iou_scores)
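        # Downstream, `_decode_inference_outputs` pops "generated_tokens" from this dict and keeps the
        # remaining entries of length batch_num_regions (e.g. "iou_scores") for the inference JSON dump.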
return loss, logits, batch_num_regions_shape, labels
PROMPT_TYPES_TO_ABLATE_ON_VG = ["center_point_in_box", "random_point_in_box", "random_point_in_mask", None]
SAM_IMAGE_PROCESSOR = None
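    # Illustrative: setting `args.prompt_types_to_ablate_on_vg` (e.g. to "center_point_in_box") makes
    # `_generate_in_inference_step` replace the box prompts in `inputs` with point prompts before calling
    # `self.model.generate`, so the same checkpoint can be probed with different prompt types on VG.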
def _generate_in_inference_step(self, inputs, gen_kwargs):
prompt_types_to_ablate_on_vg = getattr(self.args, "prompt_types_to_ablate_on_vg", None)
if prompt_types_to_ablate_on_vg not in self.PROMPT_TYPES_TO_ABLATE_ON_VG:
raise ValueError(
f"prompt_types_to_ablate_on_vg is {prompt_types_to_ablate_on_vg}. It should be one of {self.PROMPT_TYPES_TO_ABLATE_ON_VG}"
)
if prompt_types_to_ablate_on_vg == "center_point_in_box":
logger.debug("prompt types is [center_point_in_box] to ablate on VG")
input_boxes = inputs["input_boxes"]
center_points_x = input_boxes[:, :, [0, 2]].mean(dim=-1)
center_points_y = input_boxes[:, :, [1, 3]].mean(dim=-1)
center_points = torch.stack((center_points_x, center_points_y), dim=-1)
center_points = center_points.unsqueeze(-2)
inputs["input_points"] = center_points
inputs["input_boxes"] = None
elif prompt_types_to_ablate_on_vg == "random_point_in_box":
logger.debug("prompt types is [random_point_in_box] to ablate on VG")
input_boxes = inputs["input_boxes"]
            # NOTE: uniformly sample a point in each box; the box shape is (batch_size, num_regions, 4) with xyxy coordinates.
            # NOTE: after the unsqueeze below, the point shape is (batch_size, num_regions, 1, 2) with xy coordinates.
random_points = torch.rand(input_boxes.shape[:2] + (2,), device=input_boxes.device)
# NOTE: the shape of the point is (batch_size, num_regions, 2)
random_points = input_boxes[:, :, [0, 1]] + random_points * (
input_boxes[:, :, [2, 3]] - input_boxes[:, :, [0, 1]]
)
random_points = random_points.unsqueeze(-2)
inputs["input_points"] = random_points
inputs["input_boxes"] = None
elif prompt_types_to_ablate_on_vg == "random_point_in_mask":
logger.debug("prompt types is [random_point_in_mask] to ablate on VG")
if self.SAM_IMAGE_PROCESSOR is None:
from src.models.sam.image_processing_sam import SamImageProcessor
self.SAM_IMAGE_PROCESSOR = SamImageProcessor()
# NOTE: generate the binary mask
generated_outputs = self.model.generate(
generate_chunk_size=getattr(self.args, "generate_chunk_size"), **inputs, **gen_kwargs
)
iou_scores = generated_outputs.iou_scores # (batch_size, num_regions, 3)
iou_scores_max_head = iou_scores.argmax(dim=-1) # (batch_size, num_regions)
pred_masks = generated_outputs.pred_masks # (batch_size, num_regions, 3, H, W)
# NOTE: A list of binary masks, List[torch.Tensor]]: list shape (batch_size), bool tensor shape (num_regions, num_heads, H, W)
masks = self.SAM_IMAGE_PROCESSOR.post_process_masks(
pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
# NOTE: Sample random point in each masks
input_boxes = inputs["input_boxes"]
dtype = input_boxes.dtype
center_points_x = input_boxes[:, :, [0, 2]].mean(dim=-1)
center_points_y = input_boxes[:, :, [1, 3]].mean(dim=-1)
center_points = torch.stack((center_points_x, center_points_y), dim=-1)
random_points = []
for batch_idx, batch_masks in enumerate(masks):
resized_scale = inputs["reshaped_input_sizes"][batch_idx] / inputs["original_sizes"][batch_idx]
batch_iou_scores_max_head = iou_scores_max_head[batch_idx] # (num_regions)
                # batch_masks: (num_regions, num_heads, H, W)
max_indices = batch_iou_scores_max_head.view(-1, 1, 1, 1).expand(
-1, 1, batch_masks.size(2), batch_masks.size(3)
)
# NOTE: gather do not support multi-dim indexing, so we need to flatten the first dim
max_confidence_masks = batch_masks.gather(1, max_indices).squeeze(1)
# NOTE: for debug
# for i in range(len(max_confidence_masks)):
# assert torch.allclose(max_confidence_masks[i], batch_masks[i, batch_iou_scores_max_head[i]])
batch_random_points = []
for region_id, mask in enumerate(max_confidence_masks):
# NOTE: Find the indices of all True values in the mask
# NOTE: the index is yx, we need to flip it
true_indices = mask.nonzero(as_tuple=False).to(dtype=dtype) # Shape: [num_true_points, 2]
true_indices = torch.flip(true_indices, dims=[-1]) # Shape: [num_true_points, 2]
if len(true_indices) > 0:
selected_index = true_indices[torch.randint(0, len(true_indices), ())]
# NOTE: scale it as `input_boxes` and `input_points` are scaled to 1024 in the image preprocessor of SAM.
selected_index = selected_index * resized_scale
batch_random_points.append(selected_index)
else:
                        # In case there are no True values in the mask, fall back to the box center point.
logger.error("No True values in the mask!")
batch_random_points.append(center_points[batch_idx, region_id])
batch_random_points = torch.stack(batch_random_points, dim=0)
random_points.append(batch_random_points)
random_points = torch.stack(random_points, dim=0)
random_points = random_points.unsqueeze(-2)
inputs["input_points"] = random_points
inputs["input_boxes"] = None
else:
logger.debug("prompt types is [null] to ablate on VG")
generated_outputs = self.model.generate(
generate_chunk_size=getattr(self.args, "generate_chunk_size"), **inputs, **gen_kwargs
)
return generated_outputs
# NOTE: END OF INFERENCE CODE
def _pad_tensors_to_max_len(self, tensor, max_length):
# NOTE(xiaoke): Modified. Check the shape, at least 1D
        # FIXME(xiaoke): using `torch.atleast_1d` might be better
if len(tensor.shape) < 1:
raise ValueError("Cannot pad tensors with fewer than one dimension")
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
        # NOTE(xiaoke): Modified. The original function takes a tensor of shape (batch_size, number of tokens).
        # However, for the region caption task, the tensor has shape (batch_size, num_regions, num_heads, number of tokens),
        # so we pad the tensor along the last dim.
# NOTE(xiaoke): This is a GENERALIZED version of the original function
tensor_shape = tensor.shape
padded_tensor = pad_token_id * torch.ones(
(*tensor_shape[:-1], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[..., : tensor.shape[-1]] = tensor
return padded_tensor
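        # Shape example (illustrative): a (2, 5, 3, 7) `generated_tokens` tensor padded to max_length=20
        # becomes (2, 5, 3, 20); only the last (token) dim is padded, unlike the stock Seq2SeqTrainer
        # helper, which assumes a 2D (batch_size, num_tokens) tensor.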
# NOTE: START OF EVALUATION CODE
# `Seq2SeqTrainer` is mostly about **`predict_with_generate`**. We set `predict_with_generate=True` in config by default.
# The call order:
# 1. `Seq2SeqTrainer.evaluate` add generate args like `max_length` and `num_beams`
    # 2. `Trainer.evaluate`: `prediction_loss_only=True` if `self.compute_metrics` is None, else `None`, which defers to `self.args.prediction_loss_only` (False).
    # 3. `Seq2SeqTrainer.prediction_step`. To reduce the call stack, use `super(Seq2SeqTrainer, self).prediction_step`, which is `Trainer.prediction_step`.
# 4. `Trainer.prediction_step`, due to `prediction_loss_only=True`.
# NOTE: START OF CUSTOM EVALUATION CODE WITH INFERENCE
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
**gen_kwargs,
):
# NOTE(xiaoke): Modified. Check the tokenizer and the unk_token_id first
        # We do not want to hit the error only after all the predictions have been generated.
if self.tokenizer is None:
raise ValueError("You need to specify a tokenizer in Trainer!")
if self.tokenizer.unk_token_id is None:
raise ValueError(f"Check the tokenizer! unk_token_id is None! {self.tokenizer}")
gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
self._gen_kwargs = gen_kwargs
# memory metrics - must set up as early as possible
self._memory_tracker.start()
eval_dataloader = self.get_eval_dataloader(eval_dataset)
start_time = time.time()
output = self.inference_loop(
eval_dataloader,
description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
# prediction_loss_only=False,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
skip_predcition_loss_after_generate=False,
)
# Metrics!
if self.compute_metrics is not None and output.logits is not None:
(
batch_num_regions,
gt_captions,
pred_captions,
metadata_with_num_regions_length,
logits_with_num_regions_length,
) = self._decode_inference_outputs(output)
num_heads = max(len(gt_captions[0]), len(pred_captions[0]))
def _repeat_and_flatten(list_, num_heads):
ret_list = []
for sub_list in list_:
sub_list += [sub_list[-1]] * (num_heads - len(sub_list))
ret_list += sub_list
return ret_list
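            # Illustrative: with num_heads=3, _repeat_and_flatten([["a"], ["b"]], 3) -> ["a", "a", "a", "b", "b", "b"],
            # so the single GT reference per region lines up one-to-one with the flattened per-head predictions.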
gt_captions = _repeat_and_flatten(gt_captions, num_heads)
pred_captions = _repeat_and_flatten(pred_captions, num_heads)
if self.compute_metrics_func is not None:
# NOTE: only the world process zero evaluate the metrics
metrics = self.compute_metrics_func.compute(predictions=pred_captions, references=gt_captions)
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
output.metrics.update(metrics)
# Copy from: Trainer.evaluate
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
self.log(output.metrics)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
self._memory_tracker.stop_and_update_metrics(output.metrics)
return output.metrics
# NOTE: END OF EVALUATION CODE
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        # NOTE(xiaoke): Modified. mtimes are all the same when running on Azure Singularity.
        # On Azure AMLK8s, they work correctly.
super()._rotate_checkpoints(use_mtime=False, output_dir=output_dir)
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None:
optimizer_grouped_parameters = []
custom_param_lrs = self.args.custom_param_lrs
logger.debug(f"[Optimizer] default param ls: {self.args.learning_rate}")
optimizer_grouped_parameters += self._create_grouped_parameters(
opt_model, self.args.learning_rate, custom_param_lrs
)
for filtered_param, lr in custom_param_lrs.items():
logger.debug(f"[Optimizer] param {filtered_param} will use lr {lr}")
optimizer_grouped_parameters += self._create_grouped_parameters(
get_parameter_by_name(opt_model, filtered_param), lr
)
num_params_each_group = [len(g["params"]) for g in optimizer_grouped_parameters]
all_optimizable_params = list(filter(lambda p: p.requires_grad, opt_model.parameters()))
if sum(num_params_each_group) != len(all_optimizable_params):
raise ValueError(
f"num_params_each_group != all_optimizable_params ({sum(num_params_each_group)} vs. {len(all_optimizable_params)}), which should not happened."
)
logger.info(
f"[Optimizer] num of param groups: {len(optimizer_grouped_parameters)}, these group has {num_params_each_group} params"
)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
# NOTE: bump transformers from 4.30.2 to 4.36.2
# NOTE: deprecate fairscale's ShardedDDP, https://github.com/huggingface/transformers/pull/24825
# if self.sharded_ddp == ShardedDDPOption.SIMPLE:
# self.optimizer = OSS(
# params=optimizer_grouped_parameters,
# optim=optimizer_cls,
# **optimizer_kwargs,
# )
# else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
return self.optimizer
def _create_grouped_parameters(self, opt_model, lr, filter_keys=None):
full_parameters = list(opt_model.named_parameters())
if (filter_keys is None) or len(filter_keys) == 0:
logger.debug(f"[Optimizer] no filter keys, using all {len(full_parameters)} params")
filtered_parameters = []
else:
filtered_parameters = get_parameters_names_by_keys(opt_model, filter_keys)
logger.debug(f"[Optimizer] filtered out {len(filtered_parameters)} from {len(full_parameters)} params")
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in full_parameters
if (n in decay_parameters and p.requires_grad and n not in filtered_parameters)
],
"weight_decay": self.args.weight_decay,
"lr": lr,
},
{
"params": [
p
for n, p in full_parameters
if (n not in decay_parameters and p.requires_grad and n not in filtered_parameters)
],
"weight_decay": 0.0,
"lr": lr,
},
]
return optimizer_grouped_parameters
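        # Illustrative: with custom_param_lrs = {"language_model": 1e-5} (hypothetical key) and a default lr
        # of 1e-4, `create_optimizer` produces four groups in order:
        #   [default w/ weight decay, default w/o decay, language_model w/ decay, language_model w/o decay],
        # which is why `_maybe_log_save_evaluate` logs only the even-indexed learning rates.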
def _get_learning_rate(self) -> List[float]:
if self.is_deepspeed_enabled:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()
except AssertionError as e:
if "need to call step" in str(e):
logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
last_lr = [g["lr"] for g in self.optimizer.param_groups]
else:
last_lr = self.lr_scheduler.get_last_lr()
if torch.is_tensor(last_lr):
last_lr = last_lr.item()
return last_lr
def _prepare_input_dtype(self, data: Union[torch.Tensor, Any], dtype) -> Union[torch.Tensor, Any]:
"""
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
"""
if isinstance(data, Mapping):
return type(data)({k: self._prepare_input_dtype(v, dtype) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input_dtype(v, dtype) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = {"device": self.args.device}
if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
# NLP models inputs are int/uint and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
elif torch.is_floating_point(data) or torch.is_complex(data):
kwargs.update({"dtype": dtype})
return data.to(**kwargs)
return data
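        # Illustrative (hypothetical keys): _prepare_input_dtype({"pixel_values": fp32_tensor, "input_ids": int_tensor}, torch.float16)
        # moves both tensors to `self.args.device` and casts only the floating-point tensor to float16
        # (or to the DeepSpeed config dtype when DeepSpeed is enabled); integer ids keep their dtype.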
    # NOTE: handle empty batches during training caused by LSJ augmentation.
    # `inputs` is set to None in `training_step` when the batch is empty.
    # When None inputs are encountered, the step counter is still advanced.
def _inner_training_loop(
self,
batch_size=None,
args=None,
resume_from_checkpoint=None,
trial=None,
ignore_keys_for_eval=None,
):
self.accelerator.free_memory()
self._train_batch_size = batch_size
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
len_dataloader = None
if has_length(train_dataloader):
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
max_steps = args.max_steps
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
f" {args.max_steps}"
)
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps and args.logging_steps < 1:
args.logging_steps = math.ceil(max_steps * args.logging_steps)
if args.eval_steps and args.eval_steps < 1:
args.eval_steps = math.ceil(max_steps * args.eval_steps)
if args.save_steps and args.save_steps < 1:
args.save_steps = math.ceil(max_steps * args.save_steps)
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
# nn.DataParallel(model) replicates the model, creating new variables and module
# references registered here no longer work on other gpus, breaking the module
raise ValueError(
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
" (torch.distributed.launch)."
)
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
# NOTE: bump transformers from 4.30.2 to 4.36.2
# NOTE: deprecate fairscale's ShardedDDP, https://github.com/huggingface/transformers/pull/24825
# delay_optimizer_creation = (
# self.sharded_ddp is not None
# and self.sharded_ddp != ShardedDDPOption.SIMPLE
# or is_sagemaker_mp_enabled()
# or self.fsdp is not None
# )
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
# print(f'is_deepspeed_enabled:{self.is_deepspeed_enabled}')
# import time
# time.sleep(100)
if self.is_deepspeed_enabled:
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
# NOTE: bump transformers from 4.30.2 to 4.36.2
# The arguments for on_step_end are moved from `args` to `state`
self.state.train_batch_size = self._train_batch_size
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
# NOTE: bump transformers from 4.30.2 to 4.36.2
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
try:
if args.gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
else:
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
except AttributeError:
self.model.gradient_checkpointing_enable()
model = self._wrap_model(self.model_wrapped)
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = model is self.model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# prepare using `accelerator` prepare
if use_accelerator_prepare:
if hasattr(self.lr_scheduler, "step"):
if self.use_apex:
model = self.accelerator.prepare(self.model)
else:
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# deepspeed ckpt loading
if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {self._train_batch_size:,}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
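# global_step counts optimizer updates while the dataloader yields micro-batches, so the
# number of batches to skip in the resumed epoch is the remaining update count times
# gradient_accumulation_steps (unless data skipping is disabled via --ignore_data_skip).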
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
if skip_first_batches is None:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
" you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
" also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
" training on data already seen by your model."
)
else:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
steps_trained_progress_bar.set_description("Skipping the first batches")
# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
if self.hp_name is not None and self._trial is not None:
# use self._trial because the SigOpt/Optuna HPO only calls `_hp_search_setup(trial)` instead of passing the trial
# parameter to `train` when using DDP.
self.state.trial_name = self.hp_name(self._trial)
if trial is not None:
assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
self.state.trial_params = hp_params(assignments)
else:
self.state.trial_params = None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
train_dataloader.sampler, RandomSampler
)
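# NOTE: bump transformers from 4.30.2 to 4.36.2
# `is_torch_less_than_1_11` can no longer be imported from transformers.trainer (see the
# commented-out import above), so it is hardcoded to True: resuming always takes the
# pre-1.11 branch that only begins one dataloader iteration to seed the sampler's randomization.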
is_torch_less_than_1_11 = True
if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
# That was before PyTorch 1.11 however...
for _ in train_dataloader:
break
else:
# Otherwise we need to iterate through the whole sampler, because a random operation
# is added at the very end!
_ = list(train_dataloader.sampler)
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
# NOTE: bump transformers from 4.30.2 to 4.36.2
# if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
# train_dataloader.sampler.set_epoch(epoch)
# elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
# train_dataloader.dataset.set_epoch(epoch)
epoch_iterator = train_dataloader
if hasattr(epoch_iterator, "set_epoch"):
epoch_iterator.set_epoch(epoch)
if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
epoch_iterator = parallel_loader
else:
epoch_iterator = train_dataloader
# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None
steps_in_epoch = (
len(epoch_iterator)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
steps_skipped = 0
if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
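# After fast-forwarding with skip_first_batches, rng_to_sync defers reloading the
# checkpoint's RNG state until the first real training step, so dropout/augmentation
# randomness resumes from where the checkpoint left off.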
step = -1
for step, inputs in enumerate(epoch_iterator):
# NOTE: Modified here. We set `inputs` to None when the batch is empty. Skip this step.
if inputs is None:
logger.warning("The inputs shouldn't be None in training! Thus we skip this batch of data.")
continue
total_batched_samples += 1
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
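# accelerator.accumulate(model) skips gradient synchronization on micro-steps that are not
# at an accumulation boundary, so the all-reduce only happens once per optimizer update.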
with self.accelerator.accumulate(model):
tr_loss_step = self.training_step(model, inputs)
if (
args.logging_nan_inf_filter
and not is_torch_tpu_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
tr_loss += tr_loss_step
self.current_flos += float(self.floating_point_ops(inputs))
# should this be under the accumulate context manager?
# the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
# in accelerate
if total_batched_samples % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping
# NOTE: bump transformers from 4.30.2 to 4.36.2
# sharded_ddp for fairseq was deprecated.
# if self.do_grad_scaling:
# # Reduce gradients first for XLA
# if is_torch_tpu_available():
# gradients = xm._fetch_gradients(self.optimizer)
# xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
# # AMP: gradients need unscaling
# self.scaler.unscale_(self.optimizer)
if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
elif hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(args.max_grad_norm)
elif self.use_apex:
# Revert to normal clipping otherwise, handling Apex or full precision
nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
)
else:
self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
# Optimizer step
optimizer_was_run = True
if is_torch_tpu_available():
# NOTE: bump transformers from 4.30.2 to 4.36.2
# sharded_ddp for fairseq was deprecated.
# if self.do_grad_scaling:
# self.scaler.step(self.optimizer)
# self.scaler.update()
# else:
xm.optimizer_step(self.optimizer)
# elif self.do_grad_scaling:
# scale_before = self.scaler.get_scale()
# self.scaler.step(self.optimizer)
# self.scaler.update()
# scale_after = self.scaler.get_scale()
# optimizer_was_run = scale_before <= scale_after
else:
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
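# With mixed precision (or DeepSpeed), the optimizer step may be skipped when inf/nan
# gradients are detected; in that case the LR scheduler is not stepped either, keeping the
# schedule aligned with the number of actual parameter updates.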
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(
self.lr_scheduler,
torch.optim.lr_scheduler.ReduceLROnPlateau,
):
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if step < 0:
logger.warning(
"There seems to be not a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
self.control.should_training_stop = True
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_tpu_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
else:
logger.warning(
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
if self.control.should_training_stop:
break
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available():
xm.rendezvous("load_best_model_at_end")
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()
self._load_best_model()
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step
metrics = speed_metrics(
"train",
start_time,
num_samples=num_train_samples,
num_steps=self.state.max_steps,
)
self.store_flos()
metrics["total_flos"] = self.state.total_flos
metrics["train_loss"] = train_loss
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
self.log(metrics)
run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and the process is allowed to save.
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
for checkpoint in checkpoints_sorted:
if checkpoint != self.state.best_model_checkpoint:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
return TrainOutput(self.state.global_step, train_loss, metrics)
def _save_checkpoint(self, model, trial, metrics=None):
# NOTE: Temporary fix for multi-node saving bugs: https://github.com/huggingface/transformers/issues/27925#issuecomment-1869331349
try:
super()._save_checkpoint(model, trial, metrics=metrics)
except FileNotFoundError:
pass
# NOTE: a partially written checkpoint cannot be read, so we write a flag file to mark that saving has finished.
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
open(os.path.join(output_dir, SAVING_FINISHED_FLAG), "a").close()
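# The empty flag file marks this checkpoint directory as completely written. A resume helper
# (hypothetical, not part of this class) could filter out partially written checkpoints with:
#   valid = [c for c in candidates if os.path.exists(os.path.join(c, SAVING_FINISHED_FLAG))]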
# NOTE: Fix the resume of DS optimizer + HF scheduler. https://github.com/huggingface/transformers/pull/25863/files
def _load_optimizer_and_scheduler(self, checkpoint):
if checkpoint is None:
return
if self.is_deepspeed_enabled:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
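# A DeepSpeed-managed scheduler (wrapped in DeepSpeedSchedulerWrapper) is restored together
# with the engine checkpoint; a HuggingFace scheduler paired with a DeepSpeed optimizer is
# not, so it has to be reloaded manually from SCHEDULER_NAME below.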
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
return
super()._load_optimizer_and_scheduler(checkpoint)
def nested_two_dims_truncate_and_flatten(tensors, batch_num_regions_shape, limits) -> Union[np.ndarray, List, Tuple, Mapping]:
# NOTE(xiaoke): Modified. In region caption task, the prediction has two batch dimensions,
# the first one is the batch size, the second one is the number of regions.
# we need to truncate the results based on both dims.
# - all_batch_num_regions_shape: (batch_steps, 2), one batch_step has a batch of data
# - all_losses: (PADDED_batch_size, PADDED_num_regions)
# - all_preds: (PADDED_batch_size, PADDED_num_regions, num_heads, PADDED_token_length), a.k.a., all_generate_ids
# - all_labels: (PADDED_batch_size, PADDED_num_regions, PADDED_token_length)
# - all_metadata (PADDED_batch_size, PADDED_num_regions, ...)
"Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_two_dims_truncate_and_flatten(t, batch_num_regions_shape, limits) for t in tensors)
if isinstance(tensors, Mapping):
return type(tensors)(
{k: nested_two_dims_truncate_and_flatten(t, batch_num_regions_shape, limits) for k, t in tensors.items()}
)
if len(batch_num_regions_shape.shape) != 2:
raise ValueError(f"batch_num_regions_shape should have two dims, got {batch_num_regions_shape.shape}")
if batch_num_regions_shape[:, 0].sum() != len(tensors):
raise ValueError(
f"batch_num_regions_shape[:, 0].sum() should be equal to the length of tensors, "
f"got {batch_num_regions_shape[:, 0].sum()} and {len(tensors)}"
)
list_tensors = []
sample_start_idx = 0
for num_samples, num_regions in batch_num_regions_shape:
tensor = tensors[sample_start_idx : sample_start_idx + num_samples, :num_regions]
tensor = tensor.reshape(-1, *tensor.shape[2:])
list_tensors.append(tensor)
sample_start_idx += num_samples
return np.concatenate(list_tensors[:limits], axis=0)
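# Illustrative example (hypothetical shapes): with
#   batch_num_regions_shape = np.array([[2, 5], [2, 3]])
# and `tensors` padded to shape (4, 5, ...), the first batch step contributes
# tensors[0:2, :5] flattened to (10, ...) and the second contributes tensors[2:4, :3]
# flattened to (6, ...); with limits >= 2 the result is their concatenation of shape (16, ...).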
def get_parameter_by_name(model, parameter_name):
"""
Get the parameter object in a PyTorch model given its name.
Args:
model (nn.Module): The PyTorch model containing the parameter.
parameter_name (str): The name of the parameter as a string, with dot notation.
Returns:
nn.Parameter: The parameter object.
"""
parameter_name_parts = parameter_name.split(".")
parameter_obj = model
for part in parameter_name_parts:
if part == "":
continue
parameter_obj = getattr(parameter_obj, part)
return parameter_obj
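# Example (hypothetical parameter name):
#   weight = get_parameter_by_name(model, "decoder.embed_tokens.weight")
# walks model.decoder.embed_tokens.weight via getattr. Numeric parts such as "layers.0" also
# resolve for nn.ModuleList / nn.Sequential, whose children are registered under string keys.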
def get_parameters_names_by_keys(opt_model, keys):
return [name for name, _ in opt_model.named_parameters() if any(key in name for key in keys)]
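# Example (hypothetical keys):
#   names = get_parameters_names_by_keys(opt_model, keys=["lora_", "mask_decoder"])
# returns the qualified names of all parameters containing any of the substrings, e.g. for
# building optimizer parameter groups with different learning rates or weight decay.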