SeyedAli committed on
Commit
3283fcc
1 Parent(s): 6a07d86

Upload 2 files

Files changed (2)
  1. collator.py +59 -0
  2. trainer.py +64 -0
collator.py ADDED
@@ -0,0 +1,59 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ import torch
+ from transformers import Wav2Vec2FeatureExtractor
+
+
+ @dataclass
+ class DataCollatorCTCWithPadding:
+     """
+     Data collator that will dynamically pad the inputs received.
+
+     Args:
+         feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
+             The feature extractor used for processing the data.
+         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding
+             index) among:
+             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only
+               a single sequence is provided).
+             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to
+               the maximum acceptable input length for the model if that argument is not provided.
+             * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+               different lengths).
+         max_length (:obj:`int`, `optional`):
+             Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+         max_length_labels (:obj:`int`, `optional`):
+             Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
+         pad_to_multiple_of (:obj:`int`, `optional`):
+             If set, will pad the sequence to a multiple of the provided value. This is especially useful to
+             enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+     """
+
+     feature_extractor: Wav2Vec2FeatureExtractor
+     padding: Union[bool, str] = True
+     max_length: Optional[int] = None
+     max_length_labels: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     pad_to_multiple_of_labels: Optional[int] = None
+
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # Split inputs and labels: only the inputs go through the feature extractor's padding logic.
+         input_features = [{"input_values": feature["input_values"]} for feature in features]
+         label_features = [feature["labels"] for feature in features]
+
+         # Integer labels (classification) become long tensors; float targets (regression) become float tensors.
+         d_type = torch.long if isinstance(label_features[0], int) else torch.float
+
+         batch = self.feature_extractor.pad(
+             input_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+
+         batch["labels"] = torch.tensor(label_features, dtype=d_type)
+
+         return batch
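
A quick way to sanity-check the collator in isolation is to call it on a couple of synthetic examples. The sketch below is not part of this commit; the checkpoint name and the toy inputs are illustrative assumptions.

```python
# Illustrative smoke test for DataCollatorCTCWithPadding; the checkpoint and
# the synthetic examples are assumptions, not part of this commit.
from transformers import Wav2Vec2FeatureExtractor

from collator import DataCollatorCTCWithPadding

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
data_collator = DataCollatorCTCWithPadding(feature_extractor=feature_extractor, padding=True)

# Two examples of different lengths, each with a single integer class label.
features = [
    {"input_values": [0.1] * 16000, "labels": 3},
    {"input_values": [0.2] * 12000, "labels": 7},
]
batch = data_collator(features)
print(batch["input_values"].shape)  # torch.Size([2, 16000]): padded to the longest example
print(batch["labels"])              # tensor([3, 7])
```

Note that each example here carries a single per-utterance label, as in audio classification; variable-length CTC label sequences would need their own padding step before `torch.tensor` could stack them.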
trainer.py ADDED
@@ -0,0 +1,64 @@
+ from typing import Any, Dict, Union
+
+ import torch
+ from packaging import version
+ from torch import nn
+
+ from transformers import (
+     Trainer,
+     is_apex_available,
+ )
+
+ if is_apex_available():
+     from apex import amp
+
+ if version.parse(torch.__version__) >= version.parse("1.6"):
+     _is_native_amp_available = True
+     from torch.cuda.amp import autocast
+
+
+ class CTCTrainer(Trainer):
+     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+         """
+         Perform a training step on a batch of inputs.
+
+         Subclass and override to inject custom behavior.
+
+         Args:
+             model (:obj:`nn.Module`):
+                 The model to train.
+             inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                 The inputs and targets of the model.
+
+                 The dictionary will be unpacked before being fed to the model. Most models expect the targets
+                 under the argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+
+         Return:
+             :obj:`torch.Tensor`: The tensor with the training loss on this batch.
+         """
+         model.train()
+         inputs = self._prepare_inputs(inputs)
+
+         # Forward pass, under autocast when native mixed precision is enabled.
+         if self.use_amp:
+             with autocast():
+                 loss = self.compute_loss(model, inputs)
+         else:
+             loss = self.compute_loss(model, inputs)
+
+         # Average the loss over accumulation steps so accumulated gradients match a full batch.
+         if self.args.gradient_accumulation_steps > 1:
+             loss = loss / self.args.gradient_accumulation_steps
+
+         # Backward pass through whichever mixed-precision backend is active.
+         if self.use_amp:
+             self.scaler.scale(loss).backward()
+         elif self.use_apex:
+             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                 scaled_loss.backward()
+         elif self.deepspeed:
+             self.deepspeed.backward(loss)
+         else:
+             loss.backward()
+
+         return loss.detach()
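
For context, the two uploaded files are meant to be used together; a hypothetical end-to-end wiring might look like the sketch below. The checkpoint, label count, output path, and toy dataset are illustrative assumptions, and the `training_step` override reads the older `Trainer` internals (`use_amp`, `use_apex`, `scaler`, `deepspeed` attributes), so a transformers release from that era is assumed as well.

```python
# Hypothetical wiring of the two files above; checkpoint, labels, paths, and
# the tiny in-memory dataset are illustrative assumptions.
from transformers import (
    TrainingArguments,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForSequenceClassification,
)

from collator import DataCollatorCTCWithPadding
from trainer import CTCTrainer

model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base", num_labels=2
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
data_collator = DataCollatorCTCWithPadding(feature_extractor=feature_extractor)

# A toy two-example dataset; any mapping-style dataset yielding
# {"input_values", "labels"} works with the collator.
train_dataset = [
    {"input_values": [0.0] * 16000, "labels": 0},
    {"input_values": [0.1] * 8000, "labels": 1},
]

training_args = TrainingArguments(
    output_dir="wav2vec2-demo",      # illustrative path
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,   # exercises the loss-scaling branch
    num_train_epochs=1,
    fp16=False,                      # set True on CUDA to take the use_amp path
)

trainer = CTCTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()
```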