sanchit-gandhi HF staff committed on
Commit
6652985
1 Parent(s): 7350459

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +24 -8
README.md CHANGED
@@ -56,19 +56,35 @@ This code snippet shows how to evaluate **Wav2Vec2-Large-Tedlium** on the TEDLIU
56
 
57
  ```python
58
  from datasets import load_dataset
59
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
60
  import torch
61
  from jiwer import wer
 
62
  tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
63
- model = Wav2Vec2ForCTC.from_pretrained("sanchit-gandhi/wav2vec2-large-tedlium").to("cuda")
64
- processor = Wav2Vec2Processor.from_pretrained("sanchit-gandhi/wav2vec2-large-tedlium")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def map_to_pred(batch):
66
  input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
67
  with torch.no_grad():
68
- logits = model(input_values.to("cuda")).logits
69
- predicted_ids = torch.argmax(logits, dim=-1)
70
- transcription = processor.batch_decode(predicted_ids)
71
- batch["transcription"] = transcription
72
  return batch
 
73
  result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
74
- print("WER:", wer(result["text"], result["transcription"]))
 
56
 
57
  ```python
58
  from datasets import load_dataset
59
+ from transformers import AutoProcessor, SpeechEncoderDecoderModel
60
  import torch
61
  from jiwer import wer
62
+
63
  tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
64
+
65
+ def filter_ds(text):
66
+ return text != "ignore_time_segment_in_scoring"
67
+
68
+ # remove samples ignored from scoring
69
+ tedlium_eval = tedlium_eval.filter(filter_ds, input_columns=["text"])
70
+
71
+ model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium").to("cuda")
72
+ processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
73
+
74
+ gen_kwargs = {
75
+ "max_length": 200,
76
+ "num_beams": 5,
77
+ "length_penalty": 1.2
78
+ }
79
+
80
  def map_to_pred(batch):
81
  input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
82
  with torch.no_grad():
83
+ generated = model.generate(input_values.to("cuda"), **gen_kwargs)
84
+ decoded = processor.batch_decode(generated, skip_special_tokens=True)
85
+ batch["transcription"] = decoded
 
86
  return batch
87
+
88
  result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
89
+ print("WER:", wer(result["text"], result["transcription"]))
90
+ ```