Ryan Ford committed
Commit 058f50a
Parent: bb82e89

updated model

Files changed (1): README.md (+29 -1)

README.md CHANGED
@@ -7,4 +7,32 @@ metrics:
  - wer
  license: mit
  ---
- We took `facebook/wav2vec2-large-960h` and fine tuned it using 800 audio clips (around 8-10 seconds each) from various cryptocurrency related podcasts. To label the data, we downloaded cryptocurrency podcasts from youtube with their subtitle data and split the clips up by sentence. We then compared the youtube transcription with `facebook/wav2vec2-large-960h` to correct many mistakes in the youtube transcriptions. We can probably achieve better results with more data clean up.
+ We took `facebook/wav2vec2-large-960h` and fine-tuned it using 1400 audio clips (around 10-15 seconds each) from various cryptocurrency-related podcasts. To label the data, we downloaded cryptocurrency podcasts from YouTube along with their subtitle data and split the clips up by sentence. We then compared the YouTube transcriptions against the output of `facebook/wav2vec2-large-960h` to correct many of the mistakes in the YouTube transcriptions. We could probably achieve better results with more data cleanup.
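The transcript-correction step described above can be sketched with Python's standard-library `difflib`, which aligns the two word sequences and surfaces where they disagree. The example sentences below are hypothetical, not taken from the actual dataset:

```python
import difflib

# Hypothetical example: a YouTube auto-caption vs. the wav2vec2 output.
youtube = "we bought a theorem at the dip".split()
wav2vec = "we bought ethereum at the dip".split()

# Align the two word sequences and report every disagreement.
matcher = difflib.SequenceMatcher(a=youtube, b=wav2vec)
for op, i1, i2, j1, j2 in matcher.get_opcodes():
    if op != "equal":
        print(op, youtube[i1:i2], "->", wav2vec[j1:j2])
```

A human (or a heuristic) can then decide which side of each disagreement is the correct label for that clip.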
+
+ On our data we achieved a WER of 13.1%, while the base `facebook/wav2vec2-large-960h` only reached a WER of 27% on the same data.
+
+ ## Usage
+ ```python
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+ import soundfile as sf
+ import torch
+
+ # load model and processor
+ processor = Wav2Vec2Processor.from_pretrained("distractedm1nd/wav2vec-en-finetuned-on-cryptocurrency")
+ model = Wav2Vec2ForCTC.from_pretrained("distractedm1nd/wav2vec-en-finetuned-on-cryptocurrency")
+
+ filename = "INSERT_FILENAME"
+ audio, sampling_rate = sf.read(filename)
+
+ input_values = processor(audio, return_tensors="pt", padding="longest", sampling_rate=sampling_rate).input_values  # Batch size 1
+
+ # retrieve logits
+ logits = model(input_values).logits
+
+ # take argmax and decode
+ predicted_ids = torch.argmax(logits, dim=-1)
+ transcription = processor.batch_decode(predicted_ids)
+ ```
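The underlying wav2vec2 model was pretrained on 16 kHz audio, so clips at other sample rates should be resampled before being passed to the processor. Libraries such as `librosa` or `torchaudio` provide proper resamplers; as a dependency-free illustration, here is a crude linear-interpolation sketch in NumPy (the function name is ours, and real resampling should also low-pass filter to avoid aliasing):

```python
import numpy as np

def resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int = 16_000) -> np.ndarray:
    """Resample a mono signal by linear interpolation (crude, illustrative only)."""
    if orig_sr == target_sr:
        return audio
    duration = len(audio) / orig_sr
    n_out = int(round(duration * target_sr))
    old_t = np.arange(len(audio)) / orig_sr   # timestamps of input samples
    new_t = np.arange(n_out) / target_sr      # timestamps of output samples
    return np.interp(new_t, old_t, audio)

# 1 second of a 440 Hz tone at 44.1 kHz, resampled down to 16 kHz
tone = np.sin(2 * np.pi * 440 * np.arange(44_100) / 44_100)
resampled = resample_linear(tone, 44_100, 16_000)
print(len(resampled))  # 16000
```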