elsayedissa commited on
Commit
dc098aa
1 Parent(s): 94065aa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -126,7 +126,7 @@ def normalize(batch):
126
  return batch
127
 
128
  def map_wer(batch):
129
- model.to(args.device)
130
  forced_decoder_ids = processor.get_decoder_prompt_ids(language = "es", task = "transcribe")
131
  inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
132
  with torch.no_grad():
@@ -138,10 +138,10 @@ def map_wer(batch):
138
  # process GOLD text
139
  processed_dataset = dataset.map(normalize)
140
  # get predictions
141
- predicted_dataset = processed_dataset.map(map_wer)
142
 
143
  # word error rate
144
- wer = wer_metric.compute(references=predicted_dataset['gold_text'], predictions=predicted_dataset['predicted_text'])
145
  wer = round(100 * wer, 2)
146
  print("WER:", wer)
147
 
 
126
  return batch
127
 
128
  def map_wer(batch):
129
+ model.to(device)
130
  forced_decoder_ids = processor.get_decoder_prompt_ids(language = "es", task = "transcribe")
131
  inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
132
  with torch.no_grad():
 
138
  # process GOLD text
139
  processed_dataset = dataset.map(normalize)
140
  # get predictions
141
+ predicted = processed_dataset.map(map_wer)
142
 
143
  # word error rate
144
+ wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
145
  wer = round(100 * wer, 2)
146
  print("WER:", wer)
147