|
import logging |
|
import os |
|
from typing import Optional, List, Dict, Union, Tuple, Any, NamedTuple, Mapping |
|
import time |
|
import math |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from torch.utils.data import Dataset, DataLoader |
|
import hydra |
|
from hydra.utils import instantiate |
|
from datasets import DatasetDict, load_dataset, IterableDatasetDict |
|
from omegaconf import DictConfig, OmegaConf |
|
from .data.transforms import SamCaptionerDataTransform |
|
from .data.collator import SamCaptionerDataCollator |
|
from .arguments import Arguments, global_setup, SAMCaptionerModelArguments, SCAModelArguments |
|
from .models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor |
|
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
from transformers import set_seed, Seq2SeqTrainer, GenerationConfig |
|
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow |
|
from transformers.trainer import ( |
|
speed_metrics, |
|
deepspeed_init, |
|
is_torch_tpu_available, |
|
has_length, |
|
find_batch_size, |
|
nested_concat, |
|
nested_numpify, |
|
IterableDatasetShard, |
|
EvalLoopOutput, |
|
denumpify_detensorize, |
|
is_sagemaker_mp_enabled, |
|
get_parameter_names, |
|
ALL_LAYERNORM_LAYERS, |
|
Trainer, |
|
EvalPrediction, |
|
TrainerState, |
|
deepspeed_load_checkpoint, |
|
get_model_param_count, |
|
TRAINER_STATE_NAME, |
|
skip_first_batches, |
|
sys, |
|
HPSearchBackend, |
|
hp_params, |
|
    RandomSampler,
    ParallelMode,
|
dist, |
|
shutil, |
|
TrainOutput, |
|
PREFIX_CHECKPOINT_DIR, |
|
SCHEDULER_NAME, |
|
SCALER_NAME, |
|
reissue_pt_warnings, |
|
) |
|
from functools import wraps |
|
from collections import defaultdict |
|
|
|
try: |
|
from transformers.trainer import xm, met, pl |
|
except ImportError: |
|
pass |
|
try: |
|
from transformers.trainer import amp |
|
except ImportError: |
|
pass |
|
try: |
|
from transformers.trainer import smp_forward_backward |
|
except ImportError: |
|
pass |
|
try: |
|
from transformers.trainer import smp |
|
except ImportError: |
|
pass |
|
try: |
|
from transformers.trainer import OSS |
|
except ImportError: |
|
pass |
|
|
|
try: |
|
from transformers.trainer import ( |
|
ShardedDDPOption, |
|
nested_truncate, |
|
tqdm, |
|
DistributedSampler, |
|
) |
|
except ImportError: |
|
pass |
|
try: |
|
from transformers.trainer_callback import TrainerCallback |
|
except ImportError: |
|
pass |
|
try: |
|
from transformers.trainer_seq2seq import is_deepspeed_zero3_enabled |
|
except ImportError: |
|
pass |
|
|
|
|
|
|
|
import importlib.util
import warnings


def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None
|
|
|
if is_deepspeed_available(): |
|
from accelerate.utils import DeepSpeedSchedulerWrapper |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
# Written into a checkpoint directory once saving has fully finished, so that resume
# logic and external tools can skip partially written checkpoints.
SAVING_FINISHED_FLAG = "saving_finished.flag"
|
|
|
|
|
class InferenceLoopOutput(NamedTuple):
    """Outputs of `SCASeq2SeqTrainer.inference_loop`, analogous to `EvalLoopOutput` but
    carrying region metadata and the per-batch (num_samples, num_regions) shapes."""

    logits: Optional[Dict]
    label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    metadata: Optional[Dict]
    batch_num_regions_shape: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]
|
|
|
|
|
class FunctionTimers:
    """Collect per-function wall-clock timings (in milliseconds) via a decorator."""

    def __init__(self):
        self.timers = defaultdict(list)

    def get_timer(self, f):
        @wraps(f)
        def _decorate(*args, **kwargs):
            start = time.perf_counter()
            ret = f(*args, **kwargs)
            end = time.perf_counter()
            self.timers[f.__name__].append((end - start) * 1000)
            return ret

        return _decorate

    def clear(self):
        for k in self.timers:
            self.timers[k] = []

    def report(self):
        return {f"{k}_in_ms": np.mean(v) for k, v in self.timers.items()}
|
|
|
|
|
class SCASeq2SeqTrainer(Seq2SeqTrainer): |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
        if self.args.use_legacy_prediction_loop is True:
            raise ValueError(
                f"The legacy prediction loop is not supported by {self.__class__.__name__}, "
                "since it is not overridden for the region-caption task."
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.function_timers = FunctionTimers() |
|
self._prepare_inputs = self.function_timers.get_timer(self._prepare_inputs) |
|
self.compute_loss = self.function_timers.get_timer(self.compute_loss) |
|
self._do_backward = self.function_timers.get_timer(self._do_backward) |
|
self.training_step = self.function_timers.get_timer(self.training_step) |
|
|
|
|
|
|
|
|
|
|
|
        # `compute_metrics` is used here as a boolean flag; the actual metric function
        # (METEOR from `evaluate`) is only loaded on the main process.
        if self.compute_metrics is True and self.is_world_process_zero():
            import evaluate

            self.compute_metrics_func = evaluate.load("meteor")
        else:
            self.compute_metrics_func = None
|
|
|
|
|
if not hasattr(self, "is_fsdp_xla_enabled"): |
|
self.is_fsdp_xla_enabled = False |
|
|
|
|
|
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): |
|
if self.control.should_log: |
|
if is_torch_tpu_available(): |
|
xm.mark_step() |
|
|
|
logs: Dict[str, float] = {} |
|
|
|
|
|
tr_loss_scalar = self._nested_gather(tr_loss).mean().item() |
|
|
|
|
|
            # reset tr_loss to zero in place so accumulation restarts for the next window
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            param_group_key = ["full"] + list(self.args.custom_param_lrs.keys())
            param_group_values = self._get_learning_rate()
            # each lr belongs to a (decay, no-decay) pair of param groups created by
            # `_create_grouped_parameters`, so keep only every other entry
            param_group_values = [v for idx, v in enumerate(param_group_values) if idx % 2 == 0]
            for k, v in zip(param_group_key, param_group_values):
                logs[f"learning_rate/{k}"] = v
|
|
|
self._total_loss_scalar += tr_loss_scalar |
|
self._globalstep_last_logged = self.state.global_step |
|
self.store_flos() |
|
logs.update(self.function_timers.report()) |
|
self.function_timers.clear() |
|
self.log(logs) |
|
|
|
metrics = None |
|
if self.control.should_evaluate: |
|
if isinstance(self.eval_dataset, dict): |
|
metrics = {} |
|
for eval_dataset_name, eval_dataset in self.eval_dataset.items(): |
|
dataset_metrics = self.evaluate( |
|
eval_dataset=eval_dataset, |
|
ignore_keys=ignore_keys_for_eval, |
|
metric_key_prefix=f"eval_{eval_dataset_name}", |
|
) |
|
metrics.update(dataset_metrics) |
|
|
|
metrics_loss = {k: v for k, v in metrics.items() if k.startswith("eval_") and k.endswith("_loss")} |
|
metrics["eval_loss"] = sum(metrics_loss.values()) / len(metrics_loss) |
|
else: |
|
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) |
|
self._report_to_hp_search(trial, self.state.global_step, metrics) |
|
|
|
|
|
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): |
|
metric_to_check = self.args.metric_for_best_model |
|
if not metric_to_check.startswith("eval_"): |
|
metric_to_check = f"eval_{metric_to_check}" |
|
self.lr_scheduler.step(metrics[metric_to_check]) |
|
|
|
if self.control.should_save: |
|
self._save_checkpoint(model, trial, metrics=metrics) |
|
self.control = self.callback_handler.on_save(self.args, self.state, self.control) |
|
|
|
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: |
|
""" |
|
Perform a training step on a batch of inputs. |
|
|
|
Subclass and override to inject custom behavior. |
|
|
|
Args: |
|
model (`nn.Module`): |
|
The model to train. |
|
inputs (`Dict[str, Union[torch.Tensor, Any]]`): |
|
The inputs and targets of the model. |
|
|
|
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the |
|
argument `labels`. Check your model's documentation for all accepted arguments. |
|
|
|
Return: |
|
`torch.Tensor`: The tensor with training loss on this batch. |
|
""" |
|
|
|
|
|
        if inputs is None:
            logger.error("The inputs shouldn't be None in training! Skipping this batch of data.")
            # return a NaN loss on the right device so the nan/inf filter in
            # `_inner_training_loop` can absorb the skipped step
            return torch.tensor(torch.nan, device=self.args.device)
|
|
|
model.train() |
|
inputs = self._prepare_inputs(inputs) |
|
|
|
if is_sagemaker_mp_enabled(): |
|
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) |
|
return loss_mb.reduce_mean().detach().to(self.args.device) |
|
|
|
with self.compute_loss_context_manager(): |
|
loss = self.compute_loss(model, inputs) |
|
|
|
if self.args.n_gpu > 1: |
|
loss = loss.mean() |
|
self._do_backward(loss) |
|
|
|
return loss.detach() / self.args.gradient_accumulation_steps |
|
|
|
def _do_backward(self, loss): |
|
|
|
|
|
|
|
|
|
if self.use_apex: |
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss: |
|
scaled_loss.backward() |
|
else: |
|
self.accelerator.backward(loss) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference( |
|
self, |
|
inference_dataset: Optional[Dataset] = None, |
|
ignore_keys: Optional[List[str]] = None, |
|
metric_key_prefix: str = "inference", |
|
**gen_kwargs, |
|
): |
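        # NOTE: unlike `evaluate`, this skips the extra loss pass after generation and,
        # instead of computing metrics, decodes the outputs and dumps one JSON record per
        # region to `<output_dir>/infer/infer-<metric_key_prefix>.json`.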
|
|
|
|
|
if self.tokenizer is None: |
|
raise ValueError("You need to specify a tokenizer in Trainer!") |
|
if self.tokenizer.unk_token_id is None: |
|
raise ValueError(f"Check the tokenizer! unk_token_id is None! {self.tokenizer}") |
|
|
|
gen_kwargs = gen_kwargs.copy() |
|
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: |
|
gen_kwargs["max_length"] = self.args.generation_max_length |
|
gen_kwargs["num_beams"] = ( |
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams |
|
) |
|
self._gen_kwargs = gen_kwargs |
|
|
|
|
|
self._memory_tracker.start() |
|
|
|
eval_dataloader = self.get_eval_dataloader(inference_dataset) |
|
start_time = time.time() |
|
|
|
output = self.inference_loop( |
|
eval_dataloader, |
|
description="Inference", |
|
|
|
prediction_loss_only=False, |
|
ignore_keys=ignore_keys, |
|
metric_key_prefix=metric_key_prefix, |
|
            skip_prediction_loss_after_generate=True,
|
) |
|
|
|
( |
|
batch_num_regions, |
|
gt_captions, |
|
pred_captions, |
|
metadata_with_num_regions_length, |
|
logits_with_num_regions_length, |
|
) = self._decode_inference_outputs(output) |
|
|
|
self._save_inference_json( |
|
metric_key_prefix, |
|
batch_num_regions, |
|
gt_captions, |
|
pred_captions, |
|
metadata_with_num_regions_length, |
|
logits_with_num_regions_length, |
|
        )

        self._memory_tracker.stop_and_update_metrics()
|
|
|
def _decode_inference_outputs(self, output): |
|
|
|
|
|
|
|
|
|
|
|
logits = output.logits |
|
label_ids = output.label_ids |
|
metadata = output.metadata |
|
|
|
|
|
        # `generated_tokens` is consumed here; the remaining entries in `logits`
        # (e.g. `iou_scores`) are kept and dumped alongside the captions
        generate_ids = logits.pop("generated_tokens")
|
|
|
|
|
|
|
|
|
|
|
generate_ids = self._change_loss_token_to_unk_token(generate_ids, unk_token_id=self.tokenizer.unk_token_id) |
|
label_ids = self._change_loss_token_to_unk_token(label_ids, unk_token_id=self.tokenizer.unk_token_id) |
|
|
|
batch_num_regions, num_heads, token_max_length = generate_ids.shape |
|
generate_ids = generate_ids.reshape(batch_num_regions * num_heads, token_max_length) |
|
pred_captions = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True) |
|
|
|
if batch_num_regions != label_ids.shape[0]: |
|
raise ValueError(f"batch_num_regions {batch_num_regions} != label_ids.shape[0] {label_ids.shape[0]}") |
|
gt_captions = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True) |
|
|
|
pred_captions = np.array(pred_captions, dtype=object).reshape(batch_num_regions, num_heads).tolist() |
|
|
|
gt_captions = np.array(gt_captions, dtype=object).reshape(batch_num_regions, 1).tolist() |
|
|
|
metadata_with_num_regions_length = {} |
|
for k, v in metadata.items(): |
|
if len(v) != batch_num_regions: |
|
logger.warning( |
|
f"metadata {k} has length {len(v)}, but batch_num_regions is {batch_num_regions}, so skip it" |
|
) |
|
else: |
|
metadata_with_num_regions_length[k] = v.tolist() |
|
logits_with_num_regions_length = {} |
|
for k, v in logits.items(): |
|
if len(v) != batch_num_regions: |
|
logger.warning(f"logits {k} has length {len(v)}, but batch_num_regions is {batch_num_regions}") |
|
else: |
|
logits_with_num_regions_length[k] = v.tolist() |
|
|
|
return ( |
|
batch_num_regions, |
|
gt_captions, |
|
pred_captions, |
|
metadata_with_num_regions_length, |
|
logits_with_num_regions_length, |
|
) |
|
|
|
def _save_inference_json( |
|
self, |
|
metric_key_prefix, |
|
batch_num_regions, |
|
gt_captions, |
|
pred_captions, |
|
metadata_with_num_regions_length, |
|
logits_with_num_regions_length, |
|
): |
|
|
|
output_json = [] |
|
for idx in range(batch_num_regions): |
|
output_json.append( |
|
{ |
|
"_id": idx, |
|
"split": "inference", |
|
"references": gt_captions[idx], |
|
"candidates": pred_captions[idx], |
|
"metadata": {k: v[idx] for k, v in metadata_with_num_regions_length.items()}, |
|
"logits": {k: v[idx] for k, v in logits_with_num_regions_length.items()}, |
|
} |
|
) |
|
|
|
import json |
|
|
|
infer_json_dir = os.path.join(self.args.output_dir, "infer") |
|
os.makedirs(infer_json_dir, exist_ok=True) |
|
infer_json_file = os.path.join(infer_json_dir, f"infer-{metric_key_prefix}.json") |
|
|
|
if self.is_world_process_zero(): |
|
with open(infer_json_file, "w") as f: |
|
json.dump(output_json, f, indent=4) |
|
|
|
@staticmethod |
|
def _change_loss_token_to_unk_token(tokens, unk_token_id, padding_index=-100): |
|
tokens[tokens == padding_index] = unk_token_id |
|
return tokens |
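
    # Example (illustrative): with unk_token_id=3, [[5, -100, -100]] becomes [[5, 3, 3]],
    # so `batch_decode` never sees the -100 loss-masking value. Note the tensor/array is
    # modified in place.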
|
|
|
def inference_loop( |
|
self, |
|
dataloader: DataLoader, |
|
description: str, |
|
prediction_loss_only: Optional[bool] = None, |
|
ignore_keys: Optional[List[str]] = None, |
|
metric_key_prefix: str = "eval", |
|
        skip_prediction_loss_after_generate: Optional[bool] = None,
|
) -> InferenceLoopOutput: |
|
""" |
|
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. |
|
|
|
Works both with or without labels. |
|
""" |
|
args = self.args |
|
|
|
|
|
if self.is_deepspeed_enabled and self.model_wrapped is self.model: |
|
_, _ = deepspeed_init(self, num_training_steps=0, inference=True) |
|
|
|
model = self._wrap_model(self.model, training=False, dataloader=dataloader) |
|
|
|
|
|
|
|
|
|
|
|
|
if len(self.accelerator._models) == 0 and model is self.model: |
|
model = ( |
|
self.accelerator.prepare(model) |
|
if self.is_deepspeed_enabled |
|
else self.accelerator.prepare_model(model, evaluation_mode=True) |
|
) |
|
|
|
if self.is_fsdp_enabled: |
|
self.model = model |
|
|
|
|
|
if model is not self.model: |
|
self.model_wrapped = model |
|
|
|
|
|
if self.is_deepspeed_enabled: |
|
self.deepspeed = self.model_wrapped |
|
|
|
|
|
|
|
        # if full fp16 or bf16 eval is wanted and this evaluation/predict isn't called while
        # `train` is running, cast the model to the right dtype first and then put it on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
|
|
|
batch_size = self.args.eval_batch_size |
|
|
|
logger.info(f"***** Running {description} *****") |
|
if has_length(dataloader): |
|
logger.info(f" Num examples = {self.num_examples(dataloader)}") |
|
else: |
|
logger.info(" Num examples: Unknown") |
|
logger.info(f" Batch size = {batch_size}") |
|
logger.info(f" Num examples for process ({self.args.process_index}) = {len(dataloader) * batch_size}") |
|
|
|
model.eval() |
|
|
|
self.callback_handler.eval_dataloader = dataloader |
|
|
|
eval_dataset = getattr(dataloader, "dataset", None) |
|
|
|
if is_torch_tpu_available(): |
|
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) |
|
|
|
if args.past_index >= 0: |
|
self._past = None |
|
|
|
|
|
|
|
losses_host = None |
|
preds_host = None |
|
labels_host = None |
|
inputs_host = None |
|
|
|
metadata_host = None |
|
batch_num_regions_shape_host = None |
|
|
|
|
|
all_losses = None |
|
all_preds = None |
|
all_labels = None |
|
all_inputs = None |
|
|
|
all_metadata = None |
|
all_batch_num_regions_shape = None |
|
|
|
|
|
observed_num_examples = 0 |
|
|
|
|
|
for step, inputs in enumerate(dataloader): |
|
|
|
observed_batch_size = find_batch_size(inputs) |
|
if observed_batch_size is not None: |
|
observed_num_examples += observed_batch_size |
|
|
|
if batch_size is None: |
|
batch_size = observed_batch_size |
|
|
|
|
|
metadata = None |
|
for k, v in inputs.items(): |
|
if k.startswith("metadata_") and isinstance(v, torch.Tensor): |
|
if metadata is None: |
|
metadata = {} |
|
|
|
metadata[k] = v |
|
metadata = self._prepare_input(metadata) |
|
|
|
|
|
|
|
loss, logits, batch_num_regions_shape, labels = self.inference_step( |
|
model, |
|
inputs, |
|
prediction_loss_only, |
|
ignore_keys=ignore_keys, |
|
                skip_prediction_loss_after_generate=skip_prediction_loss_after_generate,
|
) |
|
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None |
|
|
|
|
|
|
|
|
|
|
|
if is_torch_tpu_available(): |
|
xm.mark_step() |
|
|
|
|
|
if loss is not None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses = self._nested_gather(loss.repeat(batch_size)) |
|
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) |
|
if labels is not None: |
|
|
|
|
|
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) |
|
if inputs_decode is not None: |
|
|
|
|
|
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) |
|
inputs_decode = self._nested_gather(inputs_decode) |
|
inputs_host = ( |
|
inputs_decode |
|
if inputs_host is None |
|
else nested_concat(inputs_host, inputs_decode, padding_index=-100) |
|
) |
|
if logits is not None: |
|
|
|
|
|
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) |
|
if self.preprocess_logits_for_metrics is not None: |
|
logits = self.preprocess_logits_for_metrics(logits, labels) |
|
logits = self._nested_gather(logits) |
|
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) |
|
if labels is not None: |
|
labels = self._nested_gather(labels) |
|
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) |
|
|
|
if metadata is not None: |
|
|
|
|
|
metadata = self.accelerator.pad_across_processes(metadata, dim=1, pad_index=-100) |
|
metadata = self._nested_gather(metadata) |
|
metadata_host = ( |
|
metadata if metadata_host is None else nested_concat(metadata_host, metadata, padding_index=-100) |
|
) |
|
|
|
if batch_num_regions_shape is not None: |
|
batch_num_regions_shape = self._nested_gather(batch_num_regions_shape) |
|
batch_num_regions_shape_host = ( |
|
batch_num_regions_shape |
|
if batch_num_regions_shape_host is None |
|
else torch.concat((batch_num_regions_shape_host, batch_num_regions_shape), dim=0) |
|
) |
|
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) |
|
|
|
|
|
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: |
|
if losses_host is not None: |
|
losses = nested_numpify(losses_host) |
|
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) |
|
if preds_host is not None: |
|
logits = nested_numpify(preds_host) |
|
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) |
|
if inputs_host is not None: |
|
inputs_decode = nested_numpify(inputs_host) |
|
all_inputs = ( |
|
inputs_decode |
|
if all_inputs is None |
|
else nested_concat(all_inputs, inputs_decode, padding_index=-100) |
|
) |
|
if labels_host is not None: |
|
labels = nested_numpify(labels_host) |
|
all_labels = ( |
|
labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) |
|
) |
|
|
|
if metadata_host is not None: |
|
metadata = nested_numpify(metadata_host) |
|
all_metadata = ( |
|
metadata if all_metadata is None else nested_concat(all_metadata, metadata, padding_index=-100) |
|
) |
|
|
|
if batch_num_regions_shape_host is not None: |
|
batch_num_regions_shape = nested_numpify(batch_num_regions_shape_host) |
|
all_batch_num_regions_shape = ( |
|
batch_num_regions_shape |
|
if all_batch_num_regions_shape is None |
|
                        else nested_concat(all_batch_num_regions_shape, batch_num_regions_shape, padding_index=-100)
|
) |
|
|
|
|
|
losses_host, preds_host, inputs_host, labels_host = None, None, None, None |
|
|
|
metadata_host = None |
|
batch_num_regions_shape_host = None |
|
|
|
if args.past_index and hasattr(self, "_past"): |
|
|
|
delattr(self, "_past") |
|
|
|
|
|
if losses_host is not None: |
|
losses = nested_numpify(losses_host) |
|
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) |
|
if preds_host is not None: |
|
logits = nested_numpify(preds_host) |
|
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) |
|
if inputs_host is not None: |
|
inputs_decode = nested_numpify(inputs_host) |
|
all_inputs = ( |
|
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) |
|
) |
|
if labels_host is not None: |
|
labels = nested_numpify(labels_host) |
|
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) |
|
|
|
if metadata_host is not None: |
|
metadata = nested_numpify(metadata_host) |
|
all_metadata = ( |
|
metadata if all_metadata is None else nested_concat(all_metadata, metadata, padding_index=-100) |
|
) |
|
|
|
if batch_num_regions_shape_host is not None: |
|
batch_num_regions_shape = nested_numpify(batch_num_regions_shape_host) |
|
all_batch_num_regions_shape = ( |
|
batch_num_regions_shape |
|
if all_batch_num_regions_shape is None |
|
else nested_concat(all_batch_num_regions_shape, batch_num_regions_shape, padding_index=-100) |
|
) |
|
|
|
|
|
if has_length(eval_dataset): |
|
num_samples = len(eval_dataset) |
|
|
|
|
|
elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: |
|
num_samples = eval_dataset.num_examples |
|
else: |
|
if has_length(dataloader): |
|
|
|
                logger.warning(
                    "Your dataset doesn't implement `__len__`, so the dataloader length is used instead;"
                    " inference may not cover all elements."
                )
|
num_samples = self.num_examples(dataloader) |
|
else: |
|
|
|
                logger.warning(
                    "Your dataset doesn't implement `__len__`, so the number of examples observed on this"
                    " process is used; inference may not cover all elements."
                )
|
num_samples = observed_num_examples |
|
if num_samples == 0 and observed_num_examples > 0: |
|
num_samples = observed_num_examples |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if all_losses is not None: |
|
|
|
all_losses = all_losses[:num_samples] |
|
if all_preds is not None: |
|
all_preds = nested_two_dims_truncate_and_flatten(all_preds, all_batch_num_regions_shape, num_samples) |
|
if all_labels is not None: |
|
all_labels = nested_two_dims_truncate_and_flatten(all_labels, all_batch_num_regions_shape, num_samples) |
|
if all_inputs is not None: |
|
all_inputs = nested_two_dims_truncate_and_flatten(all_inputs, all_batch_num_regions_shape, num_samples) |
|
|
|
if all_metadata is not None: |
|
all_metadata = nested_two_dims_truncate_and_flatten(all_metadata, all_batch_num_regions_shape, num_samples) |
|
|
|
|
|
metrics = {} |
|
|
|
if all_losses is not None: |
|
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() |
|
if hasattr(self, "jit_compilation_time"): |
|
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time |
|
|
|
|
|
for key in list(metrics.keys()): |
|
if not key.startswith(f"{metric_key_prefix}_"): |
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) |
|
|
|
|
|
return InferenceLoopOutput( |
|
logits=all_preds, |
|
label_ids=all_labels, |
|
metadata=all_metadata, |
|
batch_num_regions_shape=all_batch_num_regions_shape, |
|
metrics=metrics, |
|
num_samples=num_samples, |
|
) |
|
|
|
def inference_step( |
|
self, |
|
model: nn.Module, |
|
inputs: Dict[str, Union[torch.Tensor, Any]], |
|
prediction_loss_only: bool, |
|
ignore_keys: Optional[List[str]] = None, |
|
        skip_prediction_loss_after_generate: Optional[bool] = None,
|
    ) -> Tuple[Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
""" |
|
Perform an evaluation step on `model` using `inputs`. |
|
|
|
Subclass and override to inject custom behavior. |
|
|
|
Args: |
|
model (`nn.Module`): |
|
The model to evaluate. |
|
inputs (`Dict[str, Union[torch.Tensor, Any]]`): |
|
The inputs and targets of the model. |
|
|
|
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the |
|
argument `labels`. Check your model's documentation for all accepted arguments. |
|
prediction_loss_only (`bool`): |
|
Whether or not to return the loss only. |
|
|
|
        Return:
            Tuple of `(loss, logits, batch_num_regions_shape, labels)`, each being optional. `logits`
            is a dict holding the generated tokens and IoU scores, and `batch_num_regions_shape`
            records the per-batch `(num_samples, num_regions)` shape.
|
""" |
|
|
|
|
|
if not self.args.predict_with_generate or prediction_loss_only: |
|
|
|
|
|
|
|
loss, logits, labels = super(Seq2SeqTrainer, self).prediction_step( |
|
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys |
|
) |
|
batch_num_regions_shape = torch.tensor(inputs["input_ids"].shape[:2]).unsqueeze(0).to(device=loss.device) |
|
return loss, logits, batch_num_regions_shape, labels |
|
|
|
|
|
|
|
|
|
has_labels = "labels" in inputs |
|
inputs = self._prepare_inputs(inputs) |
|
|
|
|
|
|
|
|
|
gen_kwargs = self._gen_kwargs.copy() |
|
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: |
|
gen_kwargs["max_length"] = self.model.config.max_length |
|
gen_kwargs["num_beams"] = ( |
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams |
|
) |
|
        default_synced_gpus = is_deepspeed_zero3_enabled()
|
gen_kwargs["synced_gpus"] = ( |
|
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus |
|
) |
|
|
|
|
|
|
|
if ( |
|
"labels" in inputs |
|
and "decoder_input_ids" in inputs |
|
and inputs["labels"].shape == inputs["decoder_input_ids"].shape |
|
): |
|
inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} |
|
|
|
|
|
|
|
|
|
|
|
inputs = self._prepare_input_dtype(inputs, self.model.dtype) |
|
generated_outputs = self._generate_in_inference_step(inputs, gen_kwargs) |
|
generated_tokens = generated_outputs.sequences |
|
iou_scores = generated_outputs.iou_scores |
|
pred_masks = generated_outputs.pred_masks |
|
|
|
|
|
|
|
|
|
        # mark the generation config as explicitly set so transformers does not re-derive
        # it from the model config (and warn) on subsequent generate calls
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False
|
|
|
|
|
gen_config = self.model.generation_config |
|
|
|
|
|
|
|
|
|
if generated_tokens.shape[-1] < gen_config.max_length: |
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) |
|
elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: |
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) |
|
|
|
|
|
with torch.no_grad(): |
|
            if has_labels and skip_prediction_loss_after_generate is not True:
|
with self.compute_loss_context_manager(): |
|
outputs = model(**inputs) |
|
if self.label_smoother is not None: |
|
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() |
|
else: |
|
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() |
|
else: |
|
loss = None |
|
|
|
|
|
|
|
batch_num_regions_shape = torch.tensor(generated_tokens.shape[:2]).unsqueeze(0).to(generated_tokens) |
|
|
|
if self.args.prediction_loss_only: |
|
return loss, None, batch_num_regions_shape, None |
|
|
|
if has_labels: |
|
labels = inputs["labels"] |
|
|
|
|
|
if labels.shape[-1] < gen_config.max_length: |
|
labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) |
|
elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: |
|
labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) |
|
else: |
|
labels = None |
|
|
|
|
|
|
|
|
|
logits = dict(generated_tokens=generated_tokens, iou_scores=iou_scores) |
|
return loss, logits, batch_num_regions_shape, labels |
|
|
|
PROMPT_TYPES_TO_ABLATE_ON_VG = ["center_point_in_box", "random_point_in_box", "random_point_in_mask", None] |
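    # Prompt-ablation modes for Visual Genome: replace the ground-truth box prompt with its
    # center point, a uniform random point inside the box, or a random point inside the
    # SAM-predicted mask; `None` keeps the original box prompt.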
|
SAM_IMAGE_PROCESSOR = None |
|
|
|
def _generate_in_inference_step(self, inputs, gen_kwargs): |
|
prompt_types_to_ablate_on_vg = getattr(self.args, "prompt_types_to_ablate_on_vg", None) |
|
|
|
if prompt_types_to_ablate_on_vg not in self.PROMPT_TYPES_TO_ABLATE_ON_VG: |
|
raise ValueError( |
|
f"prompt_types_to_ablate_on_vg is {prompt_types_to_ablate_on_vg}. It should be one of {self.PROMPT_TYPES_TO_ABLATE_ON_VG}" |
|
) |
|
|
|
if prompt_types_to_ablate_on_vg == "center_point_in_box": |
|
logger.debug("prompt types is [center_point_in_box] to ablate on VG") |
|
input_boxes = inputs["input_boxes"] |
|
|
|
center_points_x = input_boxes[:, :, [0, 2]].mean(dim=-1) |
|
center_points_y = input_boxes[:, :, [1, 3]].mean(dim=-1) |
|
center_points = torch.stack((center_points_x, center_points_y), dim=-1) |
|
center_points = center_points.unsqueeze(-2) |
|
|
|
inputs["input_points"] = center_points |
|
inputs["input_boxes"] = None |
|
|
|
elif prompt_types_to_ablate_on_vg == "random_point_in_box": |
|
logger.debug("prompt types is [random_point_in_box] to ablate on VG") |
|
input_boxes = inputs["input_boxes"] |
|
|
|
|
|
|
|
random_points = torch.rand(input_boxes.shape[:2] + (2,), device=input_boxes.device) |
|
|
|
random_points = input_boxes[:, :, [0, 1]] + random_points * ( |
|
input_boxes[:, :, [2, 3]] - input_boxes[:, :, [0, 1]] |
|
) |
|
random_points = random_points.unsqueeze(-2) |
|
|
|
inputs["input_points"] = random_points |
|
inputs["input_boxes"] = None |
|
|
|
elif prompt_types_to_ablate_on_vg == "random_point_in_mask": |
|
logger.debug("prompt types is [random_point_in_mask] to ablate on VG") |
|
if self.SAM_IMAGE_PROCESSOR is None: |
|
from src.models.sam.image_processing_sam import SamImageProcessor |
|
|
|
self.SAM_IMAGE_PROCESSOR = SamImageProcessor() |
|
|
|
|
|
            generated_outputs = self.model.generate(
                generate_chunk_size=getattr(self.args, "generate_chunk_size", None), **inputs, **gen_kwargs
            )
|
iou_scores = generated_outputs.iou_scores |
|
iou_scores_max_head = iou_scores.argmax(dim=-1) |
|
pred_masks = generated_outputs.pred_masks |
|
|
|
|
|
masks = self.SAM_IMAGE_PROCESSOR.post_process_masks( |
|
pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] |
|
) |
|
|
|
|
|
input_boxes = inputs["input_boxes"] |
|
dtype = input_boxes.dtype |
|
center_points_x = input_boxes[:, :, [0, 2]].mean(dim=-1) |
|
center_points_y = input_boxes[:, :, [1, 3]].mean(dim=-1) |
|
center_points = torch.stack((center_points_x, center_points_y), dim=-1) |
|
|
|
random_points = [] |
|
for batch_idx, batch_masks in enumerate(masks): |
|
resized_scale = inputs["reshaped_input_sizes"][batch_idx] / inputs["original_sizes"][batch_idx] |
|
|
|
                batch_iou_scores_max_head = iou_scores_max_head[batch_idx]
                # pick, per region, the mask from the head with the highest predicted IoU
                max_indices = batch_iou_scores_max_head.view(-1, 1, 1, 1).expand(
                    -1, 1, batch_masks.size(2), batch_masks.size(3)
                )
|
|
|
max_confidence_masks = batch_masks.gather(1, max_indices).squeeze(1) |
|
|
|
|
|
|
|
|
|
|
|
batch_random_points = [] |
|
for region_id, mask in enumerate(max_confidence_masks): |
|
|
|
|
|
                    # nonzero yields (row, col) = (y, x); flip to the (x, y) point convention
                    true_indices = mask.nonzero(as_tuple=False).to(dtype=dtype)
                    true_indices = torch.flip(true_indices, dims=[-1])

                    if len(true_indices) > 0:
                        selected_index = true_indices[torch.randint(0, len(true_indices), ())]
                        # map back to the resized input space; SAM's resize keeps the aspect
                        # ratio, so the height and width scale factors coincide
                        selected_index = selected_index * resized_scale
                        batch_random_points.append(selected_index)
|
else: |
|
|
|
logger.error("No True values in the mask!") |
|
batch_random_points.append(center_points[batch_idx, region_id]) |
|
batch_random_points = torch.stack(batch_random_points, dim=0) |
|
random_points.append(batch_random_points) |
|
|
|
random_points = torch.stack(random_points, dim=0) |
|
random_points = random_points.unsqueeze(-2) |
|
|
|
inputs["input_points"] = random_points |
|
inputs["input_boxes"] = None |
|
|
|
else: |
|
logger.debug("prompt types is [null] to ablate on VG") |
|
|
|
        generated_outputs = self.model.generate(
            generate_chunk_size=getattr(self.args, "generate_chunk_size", None), **inputs, **gen_kwargs
        )
|
return generated_outputs |
|
|
|
|
|
|
|
def _pad_tensors_to_max_len(self, tensor, max_length): |
|
|
|
|
|
if len(tensor.shape) < 1: |
|
raise ValueError("Cannot pad tensors with fewer than one dimension") |
|
|
|
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): |
|
|
|
pad_token_id = ( |
|
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id |
|
) |
|
else: |
|
if self.model.config.pad_token_id is not None: |
|
pad_token_id = self.model.config.pad_token_id |
|
else: |
|
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") |
|
|
|
|
|
|
|
|
|
|
|
tensor_shape = tensor.shape |
|
padded_tensor = pad_token_id * torch.ones( |
|
(*tensor_shape[:-1], max_length), dtype=tensor.dtype, device=tensor.device |
|
) |
|
padded_tensor[..., : tensor.shape[-1]] = tensor |
|
return padded_tensor |
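
    # Example (illustrative): a (2, 5) token tensor padded to max_length=8 becomes (2, 8),
    # with the trailing three positions filled by pad_token_id.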
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate( |
|
self, |
|
eval_dataset: Optional[Dataset] = None, |
|
ignore_keys: Optional[List[str]] = None, |
|
metric_key_prefix: str = "eval", |
|
**gen_kwargs, |
|
): |
|
|
|
|
|
if self.tokenizer is None: |
|
raise ValueError("You need to specify a tokenizer in Trainer!") |
|
if self.tokenizer.unk_token_id is None: |
|
raise ValueError(f"Check the tokenizer! unk_token_id is None! {self.tokenizer}") |
|
|
|
gen_kwargs = gen_kwargs.copy() |
|
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: |
|
gen_kwargs["max_length"] = self.args.generation_max_length |
|
gen_kwargs["num_beams"] = ( |
|
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams |
|
) |
|
self._gen_kwargs = gen_kwargs |
|
|
|
|
|
self._memory_tracker.start() |
|
|
|
eval_dataloader = self.get_eval_dataloader(eval_dataset) |
|
start_time = time.time() |
|
|
|
output = self.inference_loop( |
|
eval_dataloader, |
|
description="Evaluation", |
|
|
|
prediction_loss_only=True if self.compute_metrics is None else None, |
|
|
|
ignore_keys=ignore_keys, |
|
metric_key_prefix=metric_key_prefix, |
|
            skip_prediction_loss_after_generate=False,
|
) |
|
|
|
|
|
if self.compute_metrics is not None and output.logits is not None: |
|
( |
|
batch_num_regions, |
|
gt_captions, |
|
pred_captions, |
|
metadata_with_num_regions_length, |
|
logits_with_num_regions_length, |
|
) = self._decode_inference_outputs(output) |
|
num_heads = max(len(gt_captions[0]), len(pred_captions[0])) |
|
|
|
def _repeat_and_flatten(list_, num_heads): |
|
ret_list = [] |
|
for sub_list in list_: |
|
sub_list += [sub_list[-1]] * (num_heads - len(sub_list)) |
|
ret_list += sub_list |
|
return ret_list |
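
            # e.g. with num_heads=3: [["a"], ["b", "c"]] -> ["a", "a", "a", "b", "c", "c"]
            # (each region's caption list is padded by repeating its last entry, then flattened)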
|
|
|
gt_captions = _repeat_and_flatten(gt_captions, num_heads) |
|
pred_captions = _repeat_and_flatten(pred_captions, num_heads) |
|
|
|
if self.compute_metrics_func is not None: |
|
|
|
metrics = self.compute_metrics_func.compute(predictions=pred_captions, references=gt_captions) |
|
|
|
metrics = denumpify_detensorize(metrics) |
|
|
|
for key in list(metrics.keys()): |
|
if not key.startswith(f"{metric_key_prefix}_"): |
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) |
|
output.metrics.update(metrics) |
|
|
|
|
|
total_batch_size = self.args.eval_batch_size * self.args.world_size |
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: |
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] |
|
output.metrics.update( |
|
speed_metrics( |
|
metric_key_prefix, |
|
start_time, |
|
num_samples=output.num_samples, |
|
num_steps=math.ceil(output.num_samples / total_batch_size), |
|
) |
|
) |
|
|
|
self.log(output.metrics) |
|
|
|
if DebugOption.TPU_METRICS_DEBUG in self.args.debug: |
|
|
|
xm.master_print(met.metrics_report()) |
|
|
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) |
|
|
|
self._memory_tracker.stop_and_update_metrics(output.metrics) |
|
|
|
return output.metrics |
|
|
|
|
|
|
|
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        # deliberately ignore the `use_mtime` argument and always sort checkpoints by step
        super()._rotate_checkpoints(use_mtime=False, output_dir=output_dir)
|
|
|
def create_optimizer(self): |
|
""" |
|
Setup the optimizer. |
|
|
|
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the |
|
Trainer's init through `optimizers`, or subclass and override this method in a subclass. |
|
""" |
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model |
|
|
|
if self.optimizer is None: |
|
optimizer_grouped_parameters = [] |
|
|
|
custom_param_lrs = self.args.custom_param_lrs |
|
logger.debug(f"[Optimizer] default param ls: {self.args.learning_rate}") |
|
optimizer_grouped_parameters += self._create_grouped_parameters( |
|
opt_model, self.args.learning_rate, custom_param_lrs |
|
) |
|
for filtered_param, lr in custom_param_lrs.items(): |
|
logger.debug(f"[Optimizer] param {filtered_param} will use lr {lr}") |
|
optimizer_grouped_parameters += self._create_grouped_parameters( |
|
get_parameter_by_name(opt_model, filtered_param), lr |
|
) |
|
|
|
num_params_each_group = [len(g["params"]) for g in optimizer_grouped_parameters] |
|
all_optimizable_params = list(filter(lambda p: p.requires_grad, opt_model.parameters())) |
|
if sum(num_params_each_group) != len(all_optimizable_params): |
|
raise ValueError( |
|
f"num_params_each_group != all_optimizable_params ({sum(num_params_each_group)} vs. {len(all_optimizable_params)}), which should not happened." |
|
) |
|
|
|
logger.info( |
|
f"[Optimizer] num of param groups: {len(optimizer_grouped_parameters)}, these group has {num_params_each_group} params" |
|
) |
|
|
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) |
|
if optimizer_cls.__name__ == "Adam8bit": |
|
import bitsandbytes |
|
|
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance() |
|
|
|
skipped = 0 |
|
for module in opt_model.modules(): |
|
if isinstance(module, nn.Embedding): |
|
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) |
|
logger.info(f"skipped {module}: {skipped/2**20}M params") |
|
manager.register_module_override(module, "weight", {"optim_bits": 32}) |
|
logger.debug(f"bitsandbytes: will optimize {module} in fp32") |
|
logger.info(f"skipped: {skipped/2**20}M params") |
|
|
|
if is_sagemaker_mp_enabled(): |
|
self.optimizer = smp.DistributedOptimizer(self.optimizer) |
|
|
|
return self.optimizer |
|
|
|
def _create_grouped_parameters(self, opt_model, lr, filter_keys=None): |
|
full_parameters = list(opt_model.named_parameters()) |
|
if (filter_keys is None) or len(filter_keys) == 0: |
|
logger.debug(f"[Optimizer] no filter keys, using all {len(full_parameters)} params") |
|
filtered_parameters = [] |
|
else: |
|
filtered_parameters = get_parameters_names_by_keys(opt_model, filter_keys) |
|
logger.debug(f"[Optimizer] filtered out {len(filtered_parameters)} from {len(full_parameters)} params") |
|
|
|
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) |
|
decay_parameters = [name for name in decay_parameters if "bias" not in name] |
|
optimizer_grouped_parameters = [ |
|
{ |
|
"params": [ |
|
p |
|
for n, p in full_parameters |
|
if (n in decay_parameters and p.requires_grad and n not in filtered_parameters) |
|
], |
|
"weight_decay": self.args.weight_decay, |
|
"lr": lr, |
|
}, |
|
{ |
|
"params": [ |
|
p |
|
for n, p in full_parameters |
|
if (n not in decay_parameters and p.requires_grad and n not in filtered_parameters) |
|
], |
|
"weight_decay": 0.0, |
|
"lr": lr, |
|
}, |
|
] |
|
|
|
return optimizer_grouped_parameters |
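
    # Example (illustrative): with lr=1e-4 this returns two groups,
    #   [{"params": [...decaying params...], "weight_decay": args.weight_decay, "lr": 1e-4},
    #    {"params": [...bias/LayerNorm params...], "weight_decay": 0.0, "lr": 1e-4}]
    # The learning-rate logging in `_maybe_log_save_evaluate` relies on this pairing.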
|
|
|
def _get_learning_rate(self) -> List[float]: |
|
if self.is_deepspeed_enabled: |
|
|
|
|
|
|
|
try: |
|
last_lr = self.lr_scheduler.get_last_lr() |
|
except AssertionError as e: |
|
if "need to call step" in str(e): |
|
logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0") |
|
last_lr = 0 |
|
else: |
|
raise |
|
else: |
|
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): |
|
last_lr = [g["lr"] for g in self.optimizer.param_groups] |
|
else: |
|
last_lr = self.lr_scheduler.get_last_lr() |
|
if torch.is_tensor(last_lr): |
|
last_lr = last_lr.item() |
|
return last_lr |
|
|
|
def _prepare_input_dtype(self, data: Union[torch.Tensor, Any], dtype) -> Union[torch.Tensor, Any]: |
|
""" |
|
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. |
|
""" |
|
if isinstance(data, Mapping): |
|
return type(data)({k: self._prepare_input_dtype(v, dtype) for k, v in data.items()}) |
|
elif isinstance(data, (tuple, list)): |
|
return type(data)(self._prepare_input_dtype(v, dtype) for v in data) |
|
elif isinstance(data, torch.Tensor): |
|
kwargs = {"device": self.args.device} |
|
if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): |
|
|
|
|
|
|
|
kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) |
|
elif torch.is_floating_point(data) or torch.is_complex(data): |
|
kwargs.update({"dtype": dtype}) |
|
return data.to(**kwargs) |
|
return data |
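
    # Example (illustrative): _prepare_input_dtype({"pixel_values": t}, torch.float16) moves
    # `t` to `args.device` and casts floating tensors to fp16 (unless DeepSpeed dictates the
    # dtype), while integer tensors such as input_ids keep their original dtype.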
|
|
|
|
|
|
|
|
|
def _inner_training_loop( |
|
self, |
|
batch_size=None, |
|
args=None, |
|
resume_from_checkpoint=None, |
|
trial=None, |
|
ignore_keys_for_eval=None, |
|
): |
|
self.accelerator.free_memory() |
|
self._train_batch_size = batch_size |
|
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") |
|
|
|
train_dataloader = self.get_train_dataloader() |
|
|
|
|
|
|
|
|
|
|
|
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size |
|
|
|
len_dataloader = None |
|
if has_length(train_dataloader): |
|
len_dataloader = len(train_dataloader) |
|
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps |
|
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) |
|
num_examples = self.num_examples(train_dataloader) |
|
if args.max_steps > 0: |
|
max_steps = args.max_steps |
|
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( |
|
args.max_steps % num_update_steps_per_epoch > 0 |
|
) |
|
|
|
|
|
num_train_samples = args.max_steps * total_train_batch_size |
|
else: |
|
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) |
|
num_train_epochs = math.ceil(args.num_train_epochs) |
|
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs |
|
elif args.max_steps > 0: |
|
max_steps = args.max_steps |
|
|
|
num_train_epochs = sys.maxsize |
|
num_update_steps_per_epoch = max_steps |
|
num_examples = total_train_batch_size * args.max_steps |
|
num_train_samples = args.max_steps * total_train_batch_size |
|
else: |
|
raise ValueError( |
|
"args.max_steps must be set to a positive value if dataloader does not have a length, was" |
|
f" {args.max_steps}" |
|
) |
|
|
|
|
|
if args.logging_steps and args.logging_steps < 1: |
|
args.logging_steps = math.ceil(max_steps * args.logging_steps) |
|
if args.eval_steps and args.eval_steps < 1: |
|
args.eval_steps = math.ceil(max_steps * args.eval_steps) |
|
if args.save_steps and args.save_steps < 1: |
|
args.save_steps = math.ceil(max_steps * args.save_steps) |
|
|
|
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: |
|
if self.args.n_gpu > 1: |
|
|
|
|
|
raise ValueError( |
|
"Currently --debug underflow_overflow is not supported under DP. Please use DDP" |
|
" (torch.distributed.launch)." |
|
) |
|
else: |
|
debug_overflow = DebugUnderflowOverflow(self.model) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled |
|
|
|
|
|
|
|
if self.is_deepspeed_enabled: |
|
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) |
|
|
|
if not delay_optimizer_creation: |
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
|
|
|
self.state = TrainerState() |
|
self.state.is_hyper_param_search = trial is not None |
|
|
|
|
|
self.state.train_batch_size = self._train_batch_size |
|
|
|
|
|
if args.logging_steps is not None: |
|
if args.logging_steps < 1: |
|
self.state.logging_steps = math.ceil(max_steps * args.logging_steps) |
|
else: |
|
self.state.logging_steps = args.logging_steps |
|
if args.eval_steps is not None: |
|
if args.eval_steps < 1: |
|
self.state.eval_steps = math.ceil(max_steps * args.eval_steps) |
|
else: |
|
self.state.eval_steps = args.eval_steps |
|
if args.save_steps is not None: |
|
if args.save_steps < 1: |
|
self.state.save_steps = math.ceil(max_steps * args.save_steps) |
|
else: |
|
self.state.save_steps = args.save_steps |
|
|
|
|
|
|
|
if args.gradient_checkpointing: |
|
try: |
|
if args.gradient_checkpointing_kwargs is None: |
|
gradient_checkpointing_kwargs = {} |
|
else: |
|
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs |
|
|
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) |
|
except AttributeError: |
|
self.model.gradient_checkpointing_enable() |
|
|
|
model = self._wrap_model(self.model_wrapped) |
|
|
|
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: |
|
self._load_from_checkpoint(resume_from_checkpoint, model) |
|
|
|
|
|
|
|
|
|
        use_accelerator_prepare = model is self.model
|
|
|
if delay_optimizer_creation: |
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
|
|
|
|
|
if use_accelerator_prepare: |
|
if hasattr(self.lr_scheduler, "step"): |
|
if self.use_apex: |
|
model = self.accelerator.prepare(self.model) |
|
else: |
|
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) |
|
else: |
|
|
|
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( |
|
self.model, self.optimizer, self.lr_scheduler |
|
) |
|
|
|
if self.is_fsdp_enabled: |
|
self.model = model |
|
|
|
|
|
if model is not self.model: |
|
self.model_wrapped = model |
|
|
|
|
|
if self.is_deepspeed_enabled: |
|
self.deepspeed = self.model_wrapped |
|
|
|
|
|
if resume_from_checkpoint is not None and self.is_deepspeed_enabled: |
|
|
|
deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) |
|
|
|
|
|
self._load_optimizer_and_scheduler(resume_from_checkpoint) |
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("***** Running training *****") |
|
logger.info(f" Num examples = {num_examples:,}") |
|
logger.info(f" Num Epochs = {num_train_epochs:,}") |
|
logger.info(f" Instantaneous batch size per device = {self._train_batch_size:,}") |
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") |
|
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
|
logger.info(f" Total optimization steps = {max_steps:,}") |
|
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") |
|
|
|
self.state.epoch = 0 |
|
start_time = time.time() |
|
epochs_trained = 0 |
|
steps_trained_in_current_epoch = 0 |
|
steps_trained_progress_bar = None |
|
|
|
|
|
if resume_from_checkpoint is not None and os.path.isfile( |
|
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) |
|
): |
|
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) |
|
epochs_trained = self.state.global_step // num_update_steps_per_epoch |
|
if not args.ignore_data_skip: |
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) |
|
steps_trained_in_current_epoch *= args.gradient_accumulation_steps |
|
else: |
|
steps_trained_in_current_epoch = 0 |
|
|
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step") |
|
logger.info(f" Continuing training from epoch {epochs_trained}") |
|
logger.info(f" Continuing training from global step {self.state.global_step}") |
|
if not args.ignore_data_skip: |
|
if skip_first_batches is None: |
|
logger.info( |
|
f" Will skip the first {epochs_trained} epochs then the first" |
|
f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," |
|
" you can install the latest version of Accelerate with `pip install -U accelerate`.You can" |
|
" also add the `--ignore_data_skip` flag to your launch command, but you will resume the" |
|
" training on data already seen by your model." |
|
) |
|
else: |
|
logger.info( |
|
f" Will skip the first {epochs_trained} epochs then the first" |
|
f" {steps_trained_in_current_epoch} batches in the first epoch." |
|
) |
|
if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: |
|
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) |
|
steps_trained_progress_bar.set_description("Skipping the first batches") |
|
|
|
|
|
self.callback_handler.model = self.model |
|
self.callback_handler.optimizer = self.optimizer |
|
self.callback_handler.lr_scheduler = self.lr_scheduler |
|
self.callback_handler.train_dataloader = train_dataloader |
|
if self.hp_name is not None and self._trial is not None: |
|
|
|
|
|
self.state.trial_name = self.hp_name(self._trial) |
|
if trial is not None: |
|
assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial |
|
self.state.trial_params = hp_params(assignments) |
|
else: |
|
self.state.trial_params = None |
|
|
|
|
|
self.state.max_steps = max_steps |
|
self.state.num_train_epochs = num_train_epochs |
|
self.state.is_local_process_zero = self.is_local_process_zero() |
|
self.state.is_world_process_zero = self.is_world_process_zero() |
|
|
|
|
|
tr_loss = torch.tensor(0.0).to(args.device) |
|
|
|
self._total_loss_scalar = 0.0 |
|
self._globalstep_last_logged = self.state.global_step |
|
model.zero_grad() |
|
|
|
self.control = self.callback_handler.on_train_begin(args, self.state, self.control) |
|
|
|
|
|
if not args.ignore_data_skip: |
|
for epoch in range(epochs_trained): |
|
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( |
|
train_dataloader.sampler, RandomSampler |
|
) |
|
                # pinned to True: always take the legacy path below, which only begins an
                # iteration to reproduce the sampler's randomization state
                is_torch_less_than_1_11 = True
|
if is_torch_less_than_1_11 or not is_random_sampler: |
|
|
|
|
|
for _ in train_dataloader: |
|
break |
|
else: |
|
|
|
|
|
_ = list(train_dataloader.sampler) |
|
|
|
total_batched_samples = 0 |
|
for epoch in range(epochs_trained, num_train_epochs): |
|
|
|
|
|
|
|
|
|
|
|
epoch_iterator = train_dataloader |
|
if hasattr(epoch_iterator, "set_epoch"): |
|
epoch_iterator.set_epoch(epoch) |
|
|
|
if is_torch_tpu_available(): |
|
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) |
|
epoch_iterator = parallel_loader |
|
else: |
|
epoch_iterator = train_dataloader |
|
|
|
|
|
if args.past_index >= 0: |
|
self._past = None |
|
|
|
steps_in_epoch = ( |
|
len(epoch_iterator) |
|
if len_dataloader is not None |
|
else args.max_steps * args.gradient_accumulation_steps |
|
) |
|
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) |
|
|
|
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: |
|
self._load_rng_state(resume_from_checkpoint) |
|
|
|
rng_to_sync = False |
|
steps_skipped = 0 |
|
if skip_first_batches is not None and steps_trained_in_current_epoch > 0: |
|
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) |
|
steps_skipped = steps_trained_in_current_epoch |
|
steps_trained_in_current_epoch = 0 |
|
rng_to_sync = True |
|
|
|
step = -1 |
|
for step, inputs in enumerate(epoch_iterator): |
|
|
|
if inputs is None: |
|
logger.warning("The inputs shouldn't be None in training! Thus we skip this batch of data.") |
|
continue |
|
|
|
total_batched_samples += 1 |
|
if rng_to_sync: |
|
self._load_rng_state(resume_from_checkpoint) |
|
rng_to_sync = False |
|
|
|
|
|
if steps_trained_in_current_epoch > 0: |
|
steps_trained_in_current_epoch -= 1 |
|
if steps_trained_progress_bar is not None: |
|
steps_trained_progress_bar.update(1) |
|
if steps_trained_in_current_epoch == 0: |
|
self._load_rng_state(resume_from_checkpoint) |
|
continue |
|
elif steps_trained_progress_bar is not None: |
|
steps_trained_progress_bar.close() |
|
steps_trained_progress_bar = None |
|
|
|
if step % args.gradient_accumulation_steps == 0: |
|
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) |
|
|
|
with self.accelerator.accumulate(model): |
|
tr_loss_step = self.training_step(model, inputs) |
|
|
|
if ( |
|
args.logging_nan_inf_filter |
|
and not is_torch_tpu_available() |
|
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) |
|
): |
|
|
|
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) |
|
else: |
|
tr_loss += tr_loss_step |
|
|
|
self.current_flos += float(self.floating_point_ops(inputs)) |
|
|
|
|
|
|
|
|
|
if total_batched_samples % args.gradient_accumulation_steps == 0 or ( |
|
|
|
steps_in_epoch <= args.gradient_accumulation_steps |
|
and (step + 1) == steps_in_epoch |
|
): |
|
|
|
if args.max_grad_norm is not None and args.max_grad_norm > 0: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_sagemaker_mp_enabled() and args.fp16: |
|
self.optimizer.clip_master_grads(args.max_grad_norm) |
|
elif hasattr(self.optimizer, "clip_grad_norm"): |
|
|
|
self.optimizer.clip_grad_norm(args.max_grad_norm) |
|
elif hasattr(model, "clip_grad_norm_"): |
|
|
|
model.clip_grad_norm_(args.max_grad_norm) |
|
elif self.use_apex: |
|
|
|
nn.utils.clip_grad_norm_( |
|
amp.master_params(self.optimizer), |
|
args.max_grad_norm, |
|
) |
|
else: |
|
self.accelerator.clip_grad_norm_( |
|
model.parameters(), |
|
args.max_grad_norm, |
|
) |
|
|
|
|
|
optimizer_was_run = True |
|
if is_torch_tpu_available(): |
|
|
|
|
|
|
|
|
|
|
|
|
|
xm.optimizer_step(self.optimizer) |
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
self.optimizer.step() |
|
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped |
|
|
|
if optimizer_was_run: |
|
|
|
if not isinstance( |
|
self.lr_scheduler, |
|
torch.optim.lr_scheduler.ReduceLROnPlateau, |
|
): |
|
self.lr_scheduler.step() |
|
|
|
model.zero_grad() |
|
self.state.global_step += 1 |
|
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch |
|
self.control = self.callback_handler.on_step_end(args, self.state, self.control) |
|
|
|
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) |
|
else: |
|
self.control = self.callback_handler.on_substep_end(args, self.state, self.control) |
|
|
|
if self.control.should_epoch_stop or self.control.should_training_stop: |
|
break |
|
if step < 0: |
|
logger.warning( |
|
"There seems to be not a single sample in your epoch_iterator, stopping training at step" |
|
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" |
|
f" num_steps ({max_steps}) higher than the number of available samples." |
|
) |
|
self.control.should_training_stop = True |
|
|
|
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) |
|
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) |
|
|
|
if DebugOption.TPU_METRICS_DEBUG in self.args.debug: |
|
if is_torch_tpu_available(): |
|
|
|
xm.master_print(met.metrics_report()) |
|
else: |
|
logger.warning( |
|
"You enabled PyTorch/XLA debug metrics but you don't have a TPU " |
|
"configured. Check your training configuration if this is unexpected." |
|
) |
|
if self.control.should_training_stop: |
|
break |
|
|
|
if args.past_index and hasattr(self, "_past"): |
|
|
|
delattr(self, "_past") |
|
|
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") |
|
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: |
|
|
|
if is_torch_tpu_available(): |
|
xm.rendezvous("load_best_model_at_end") |
|
elif args.parallel_mode == ParallelMode.DISTRIBUTED: |
|
dist.barrier() |
|
elif is_sagemaker_mp_enabled(): |
|
smp.barrier() |
|
|
|
self._load_best_model() |
|
|
|
|
|
self._total_loss_scalar += tr_loss.item() |
|
train_loss = self._total_loss_scalar / self.state.global_step |
|
|
|
metrics = speed_metrics( |
|
"train", |
|
start_time, |
|
num_samples=num_train_samples, |
|
num_steps=self.state.max_steps, |
|
) |
|
self.store_flos() |
|
metrics["total_flos"] = self.state.total_flos |
|
metrics["train_loss"] = train_loss |
|
|
|
self.is_in_train = False |
|
|
|
self._memory_tracker.stop_and_update_metrics(metrics) |
|
|
|
self.log(metrics) |
|
|
|
run_dir = self._get_output_dir(trial) |
|
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) |
|
|
|
|
|
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: |
|
for checkpoint in checkpoints_sorted: |
|
if checkpoint != self.state.best_model_checkpoint: |
|
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") |
|
shutil.rmtree(checkpoint) |
|
|
|
self.control = self.callback_handler.on_train_end(args, self.state, self.control) |
|
|
|
return TrainOutput(self.state.global_step, train_loss, metrics) |
|
|
|
    def _save_checkpoint(self, model, trial, metrics=None):
        # `super()._save_checkpoint` may raise FileNotFoundError (e.g. when checkpoint
        # rotation races with saving); tolerate it and still write the flag file below
        try:
            super()._save_checkpoint(model, trial, metrics=metrics)
        except FileNotFoundError:
            pass

        # touch a flag file so resume logic and external tools can tell that this
        # checkpoint directory was fully written (see SAVING_FINISHED_FLAG)
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        open(os.path.join(output_dir, SAVING_FINISHED_FLAG), "a").close()
|
|
|
|
|
def _load_optimizer_and_scheduler(self, checkpoint): |
|
if checkpoint is None: |
|
return |
|
|
|
if self.is_deepspeed_enabled: |
|
|
|
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): |
|
with warnings.catch_warnings(record=True) as caught_warnings: |
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) |
|
reissue_pt_warnings(caught_warnings) |
|
return |
|
|
|
super()._load_optimizer_and_scheduler(checkpoint) |
|
|
|
|
|
def nested_two_dims_truncate_and_flatten(tensors, batch_num_regions_shape, limits):
    """Truncate each gathered batch in `tensors` to its actual number of regions, flatten the
    leading (samples, regions) dims, and keep at most `limits` batches. Nested lists/tuples/dicts
    of tensors are handled recursively and returned as the same container type."""
|
if isinstance(tensors, (list, tuple)): |
|
return type(tensors)(nested_two_dims_truncate_and_flatten(t, batch_num_regions_shape, limits) for t in tensors) |
|
if isinstance(tensors, Mapping): |
|
return type(tensors)( |
|
{k: nested_two_dims_truncate_and_flatten(t, batch_num_regions_shape, limits) for k, t in tensors.items()} |
|
) |
|
|
|
if len(batch_num_regions_shape.shape) != 2: |
|
raise ValueError(f"batch_num_regions_shape should have two dims, got {batch_num_regions_shape.shape}") |
|
if batch_num_regions_shape[:, 0].sum() != len(tensors): |
|
raise ValueError( |
|
f"batch_num_regions_shape[:, 0].sum() should be equal to the length of tensors, " |
|
f"got {batch_num_regions_shape[:, 0].sum()} and {len(tensors)}" |
|
) |
|
list_tensors = [] |
|
sample_start_idx = 0 |
|
for num_samples, num_regions in batch_num_regions_shape: |
|
tensor = tensors[sample_start_idx : sample_start_idx + num_samples, :num_regions] |
|
tensor = tensor.reshape(-1, *tensor.shape[2:]) |
|
list_tensors.append(tensor) |
|
sample_start_idx += num_samples |
|
|
|
return np.concatenate(list_tensors[:limits], axis=0) |
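
# Example (illustrative): with batch_num_regions_shape = [[2, 3], [1, 3]] and `tensors` padded
# to shape (3, 3, L), each gathered batch is cut to its true region count and flattened,
# giving an array of shape (2 * 3 + 1 * 3, L) before the final truncation by `limits`.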
|
|
|
|
|
def get_parameter_by_name(model, parameter_name): |
|
""" |
|
Get the parameter object in a PyTorch model given its name. |
|
|
|
Args: |
|
model (nn.Module): The PyTorch model containing the parameter. |
|
parameter_name (str): The name of the parameter as a string, with dot notation. |
|
|
|
Returns: |
|
nn.Parameter: The parameter object. |
|
""" |
|
parameter_name_parts = parameter_name.split(".") |
|
parameter_obj = model |
|
|
|
for part in parameter_name_parts: |
|
if part == "": |
|
continue |
|
parameter_obj = getattr(parameter_obj, part) |
|
|
|
return parameter_obj |
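
# Example (illustrative): get_parameter_by_name(model, "language_model.lm_head") walks the
# dotted path via getattr; an empty string returns `model` itself.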
|
|
|
|
|
def get_parameters_names_by_keys(opt_model, keys): |
|
return [name for name, _ in opt_model.named_parameters() if any(key in name for key in keys)] |
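
# Example (illustrative): get_parameters_names_by_keys(model, ["lm_head", "mask_decoder"])
# returns every parameter name containing at least one of the given substrings.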
|
|