patrickvonplaten commited on
Commit
01e34b6
1 Parent(s): 4ecf80f

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +32 -0
README.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Test model
2
+
3
+ To test this model run the following code:
4
+
5
+ ```python
6
+ from datasets import load_dataset
7
+ from transformers import Wav2Vec2ForCTC
8
+ import torchaudio
9
+ import torch
10
+
11
+ ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
12
+
13
+ model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2_tiny_random")
14
+
15
+ def load_audio(batch):
16
+ batch["samples"], _ = torchaudio.load(batch["file"])
17
+ return batch
18
+
19
+ ds = ds.map(load_audio)
20
+
21
+ input_values = torch.nn.utils.rnn.pad_sequence([torch.tensor(x[0]) for x in ds["samples"][:10]], batch_first=True)
22
+
23
+ # forward
24
+ logits = model(input_values).logits
25
+ pred_ids = torch.argmax(logits, dim=-1)
26
+
27
+ # dummy loss
28
+ dummy_labels = pred_ids.clone()
29
+ dummy_labels[dummy_labels == model.config.pad_token_id] = 1 # can't have CTC blank token in label
30
+ dummy_labels = dummy_labels[:, -(dummy_labels.shape[1] // 4):] # make sure labels are shorter to avoid "inf" loss (can still happen though...)
31
+ loss = model(input_values, labels=dummy_labels).loss
32
+ ```