from typing import Any, Dict, Union

import torch
from packaging import version
from torch import nn

from transformers import (
    Trainer,
    is_apex_available,
)
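
# NVIDIA apex is an optional dependency; its amp module is only importable
# when apex is installed, hence the availability guard below.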
if is_apex_available():
    from apex import amp
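
# Native automatic mixed precision (torch.cuda.amp) shipped with PyTorch 1.6,
# so the autocast import is gated on the installed torch version.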
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast


class CTCTrainer(Trainer):
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
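        """
        Perform a training step on a batch of inputs: run the forward pass
        (under autocast when native AMP is enabled), scale the loss for
        gradient accumulation, and dispatch backward to the active
        mixed-precision or DeepSpeed backend.
        """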
        model.train()
        inputs = self._prepare_inputs(inputs)
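
        # Forward pass: compute the loss, inside autocast when native AMP is
        # enabled so activations run in mixed precision.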
        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)
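
        # Average the loss over accumulation steps so the accumulated
        # gradient matches that of a single larger batch.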
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
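
        # Backward pass: each backend has its own entry point (GradScaler for
        # native AMP, amp.scale_loss for apex, engine.backward for DeepSpeed);
        # self.use_amp, self.use_apex, self.scaler and self.deepspeed are all
        # set up by the transformers Trainer this class subclasses.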
        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
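

# A minimal usage sketch, not part of the original file. It assumes a CTC
# model (e.g. Wav2Vec2ForCTC), a processed `train_dataset`, and a padding
# data collator as in the wav2vec2 fine-tuning examples; every name besides
# CTCTrainer and TrainingArguments is illustrative.
#
#     from transformers import TrainingArguments
#
#     training_args = TrainingArguments(
#         output_dir="./wav2vec2-ctc",
#         per_device_train_batch_size=8,
#         gradient_accumulation_steps=2,
#         fp16=True,  # selects the native-AMP branch above on CUDA
#     )
#     trainer = CTCTrainer(
#         model=model,
#         args=training_args,
#         data_collator=data_collator,
#         train_dataset=train_dataset,
#         tokenizer=processor.feature_extractor,
#     )
#     trainer.train()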