Vu Minh Chien commited on
Commit
b04a51b
·
1 Parent(s): b176118

update readme

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -98,9 +98,9 @@ test_dataset = test_dataset.map(speech_file_to_array_fn)
98
  def evaluate(batch):
99
  inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
100
  with torch.no_grad():
101
- logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
102
- pred_ids = torch.argmax(logits, dim=-1)
103
- batch["pred_strings"] = processor.batch_decode(pred_ids)
104
  return batch
105
  result = test_dataset.map(evaluate, batched=True, batch_size=8)
106
  print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
 
98
  def evaluate(batch):
99
  inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
100
  with torch.no_grad():
101
+ logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
102
+ pred_ids = torch.argmax(logits, dim=-1)
103
+ batch["pred_strings"] = processor.batch_decode(pred_ids)
104
  return batch
105
  result = test_dataset.map(evaluate, batched=True, batch_size=8)
106
  print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))