5roop commited on
Commit
77b03e4
1 Parent(s): 67a746a

Update README.md

Browse files

Updated the example script due to newer versions backward incompatibility.

Files changed (1) hide show
  1. README.md +9 -13
README.md CHANGED
@@ -26,11 +26,12 @@ Initial evaluation on partially noisy data showed the model to achieve a word er
26
 
27
  ```python
28
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
29
- from datasets import Audio
30
  import soundfile as sf
31
  import torch
32
  import os
33
 
 
 
34
  # load model and tokenizer
35
  processor = Wav2Vec2Processor.from_pretrained(
36
  "classla/wav2vec2-xls-r-parlaspeech-hr")
@@ -38,28 +39,23 @@ model = Wav2Vec2ForCTC.from_pretrained("classla/wav2vec2-xls-r-parlaspeech-hr")
38
 
39
 
40
  # download the example wav files:
41
- os.system("curl https://huggingface.co/classla/wav2vec2-xls-r-parlaspeech-hr/raw/main/00020570a.flac.wav")
42
 
43
- # read the wav file as datasets.Audio object
44
- audio = Audio(sampling_rate=16000).decode_example("00020570a.flac.wav")
 
45
 
46
  # remove the raw wav file
47
  os.system("rm 00020570a.flac.wav")
48
 
49
- # tokenize
50
- input_values = processor(
51
- audio["array"], return_tensors="pt", padding=True,
52
- sampling_rate=16000).input_values
53
-
54
  # retrieve logits
55
- logits = model(input_values).logits
56
 
57
  # take argmax and decode
58
  predicted_ids = torch.argmax(logits, dim=-1)
59
- transcription = processor.batch_decode(predicted_ids)
60
-
61
 
62
- # transcription: ['veliki broj poslovnih subjekata posluje sa minusom velik dio']
63
  ```
64
 
65
  ## Training hyperparameters
 
26
 
27
  ```python
28
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
29
  import soundfile as sf
30
  import torch
31
  import os
32
 
33
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
+
35
  # load model and tokenizer
36
  processor = Wav2Vec2Processor.from_pretrained(
37
  "classla/wav2vec2-xls-r-parlaspeech-hr")
 
39
 
40
 
41
  # download the example wav files:
42
+ os.system("wget https://huggingface.co/classla/wav2vec2-xls-r-parlaspeech-hr/raw/main/00020570a.flac.wav")
43
 
44
+ # read the wav file
45
+ speech, sample_rate = sf.read("00020570a.flac.wav")
46
+ input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.to(device)
47
 
48
  # remove the raw wav file
49
  os.system("rm 00020570a.flac.wav")
50
 
 
 
 
 
 
51
  # retrieve logits
52
+ logits = model.to(device)(input_values).logits
53
 
54
  # take argmax and decode
55
  predicted_ids = torch.argmax(logits, dim=-1)
56
+ transcription = processor.decode(predicted_ids[0]).lower()
 
57
 
58
+ # transcription: 'veliki broj poslovnih subjekata posluje sa minusom velik dio'
59
  ```
60
 
61
  ## Training hyperparameters