elgeish commited on
Commit
633e52d
1 Parent(s): c2f6f90

update model card

Browse files
Files changed (1) hide show
  1. README.md +3 -2
README.md CHANGED
@@ -26,12 +26,12 @@ from datasets import load_dataset
26
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
27
 
28
  model_name = "elgeish/wav2vec2-base-timit-asr"
29
- processor = Wav2Vec2Processor.from_pretrained(model_name, do_lower_case=True)
30
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
31
  model.eval()
32
 
33
  dataset = load_dataset("timit_asr", split="test").shuffle().select(range(10))
34
- char_translations = str.maketrans({"-": " ", ".": "", "?": ""})
35
 
36
  def prepare_example(example):
37
  example["speech"], _ = sf.read(example["file"])
@@ -47,6 +47,7 @@ with torch.no_grad():
47
  predicted_ids = torch.argmax(model(inputs.input_values).logits, dim=-1)
48
  predicted_ids[predicted_ids == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
49
  predicted_transcripts = processor.tokenizer.batch_decode(predicted_ids)
 
50
  for reference, predicted in zip(dataset["text"], predicted_transcripts):
51
  print("reference:", reference)
52
  print("predicted:", predicted)
26
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
27
 
28
  model_name = "elgeish/wav2vec2-base-timit-asr"
29
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
30
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
31
  model.eval()
32
 
33
  dataset = load_dataset("timit_asr", split="test").shuffle().select(range(10))
34
+ char_translations = str.maketrans({"-": " ", ",": "", ".": "", "?": ""})
35
 
36
  def prepare_example(example):
37
  example["speech"], _ = sf.read(example["file"])
47
  predicted_ids = torch.argmax(model(inputs.input_values).logits, dim=-1)
48
  predicted_ids[predicted_ids == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
49
  predicted_transcripts = processor.tokenizer.batch_decode(predicted_ids)
50
+
51
  for reference, predicted in zip(dataset["text"], predicted_transcripts):
52
  print("reference:", reference)
53
  print("predicted:", predicted)