"""CTC-specific ``Trainer`` subclass.

The only customization is :meth:`CTCTrainer.training_step`, which runs one
forward/backward pass using whichever mixed-precision / distributed backend
the parent ``Trainer`` has enabled (native AMP, apex, DeepSpeed, or plain
PyTorch).
"""
from typing import Any, Dict, Union

import torch
from packaging import version
from torch import nn
from transformers import (
    Trainer,
    is_apex_available,
)

if is_apex_available():
    from apex import amp

# Native AMP (torch.cuda.amp) ships with PyTorch >= 1.6. Assign the flag on
# both branches so that reading it never raises NameError on older PyTorch.
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast
else:
    _is_native_amp_available = False


class CTCTrainer(Trainer):
    """``Trainer`` with an overridden ``training_step`` for CTC models."""

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Perform one training step (forward + backward) on a batch.

        Args:
            model: The model to train; it is switched to train mode first.
            inputs: The batch as a dict of named inputs. It is passed through
                ``self._prepare_inputs`` before the forward pass.

        Returns:
            The detached loss for this step. When ``gradient_accumulation_steps``
            is > 1 and DeepSpeed is not in use, it has already been divided by
            that number of steps.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # Under nn.DataParallel each replica returns its own loss, so `loss`
            # is a vector. Reduce it to a scalar so backward() can seed the gradient.
            loss = loss.mean()

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # DeepSpeed's engine.backward() already scales the loss by the
            # gradient-accumulation steps. Dividing here as well would double-scale it.
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()