## Test model

To test this model, run the following code:

```python
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC
import torchaudio
import torch

ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2_tiny_random")

def load_audio(batch):
    # torchaudio.load returns (waveform, sample_rate); keep only the waveform
    batch["samples"], _ = torchaudio.load(batch["file"])
    return batch

ds = ds.map(load_audio)

# batch the first 10 waveforms, zero-padding them to the longest sequence
input_values = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(x[0]) for x in ds["samples"][:10]], batch_first=True
)

# forward pass + greedy (argmax) decoding of the CTC logits
logits = model(input_values).logits
pred_ids = torch.argmax(logits, dim=-1)

# dummy loss: reuse the predictions as labels
dummy_labels = pred_ids.clone()
dummy_labels[dummy_labels == model.config.pad_token_id] = 1  # the CTC blank token (here the pad token) must not appear in the labels
dummy_labels = dummy_labels[:, -(dummy_labels.shape[1] // 4):]  # CTC requires the labels to be shorter than the logits' time dimension, otherwise the loss is "inf" (it can still happen, though)
loss = model(input_values, labels=dummy_labels).loss
```
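
As an optional sanity check, you can decode the greedy predictions back to strings. The sketch below assumes the repository also ships tokenizer and feature-extractor files so that `Wav2Vec2Processor` can be loaded; since the model is randomly initialized, the transcriptions will be gibberish:

```python
from transformers import Wav2Vec2Processor

# assumption: the checkpoint includes processor files (tokenizer + feature extractor)
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2_tiny_random")

# batch_decode collapses repeated ids, strips CTC blanks, and maps ids to characters
transcriptions = processor.batch_decode(pred_ids)
print(transcriptions)
```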