import torch
import torch.nn as nn
import deepspeed
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled

# The star import supplies the Trainer internals used below (os, logger, the typing
# aliases, unwrap_model, is_peft_available, WEIGHTS_NAME, TRAINING_ARGS_NAME, ...).
from transformers.trainer import *
from transformers.integrations import is_deepspeed_zero3_enabled


class CPMTrainer(Trainer):
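    """Trainer subclass for CPM-style models.

    Overrides `compute_loss` (which passes each batch to the model as a single
    `data` dict and supports LoRA runs via the PEFT forward hooks), along with
    `prediction_step`, `training_step`, and `_save`.
    """
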
    def compute_loss(self, model, inputs, return_outputs=False):
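        """Compute the loss, forwarding the whole batch as a single `data` dict.

        When `labels` are present they are popped from `inputs` and used for a
        flattened token-level cross entropy; otherwise the loss reported by the
        model itself is used.
        """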
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        if not self.args.use_lora:
            outputs = self.model(data=inputs, use_cache=False)
        else:
            # Under LoRA, route the forward pass through the PEFT-wrapped base
            # model so that the adapter forward hooks are applied.
            with self.model._enable_peft_forward_hooks(**inputs):
                outputs = self.model.base_model(data=inputs, use_cache=False)

        if labels is not None:
            # Token-level cross entropy over flattened logits; positions labeled
            # -100 (if any) are ignored by nn.CrossEntropyLoss's default ignore_index.
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # The model may return a plain tuple instead of a ModelOutput, so index
            # positionally when the output is not a dict.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )

        # Some models can return a loss even without explicit label columns; honor
        # an explicit `return_loss` input, otherwise fall back to `self.can_return_loss`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = (
            True if len(self.label_names) == 0 and return_loss else False
        )

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # Grab the labels before the forward pass, since `compute_loss` pops them
        # from `inputs`.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(
                            v
                            for k, v in raw_outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(
                            v for k, v in raw_outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(
                            model, inputs, return_outputs=True
                        )
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(
                            v
                            for k, v in outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(
                            v for k, v in outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits = outputs
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        # Free the batch tensors eagerly to reduce peak GPU memory between steps.
        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # average across GPUs under DataParallel

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
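        """Save the model (and tokenizer) to `output_dir`, falling back to a raw state dict when needed."""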
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

        # Prefer `save_pretrained()`; fall back to saving a raw state dict when the
        # (possibly wrapped) model is not a PreTrainedModel/PeftModel.
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
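

# Illustrative usage sketch (hypothetical names, not part of this module): the
# fine-tuning entry point is expected to wire the trainer up roughly like this,
# where `training_args` is a TrainingArguments subclass that adds a `use_lora` flag.
#
#     trainer = CPMTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         data_collator=data_collator,
#         tokenizer=tokenizer,
#     )
#     trainer.train()
#     trainer.save_model(training_args.output_dir)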