elsayedissa committed on
Commit
f3f2922
1 Parent(s): ef96ccf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -113,7 +113,7 @@ processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-arabic-5
113
  model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-arabic-5k-steps")
114
 
115
  # dataset
116
- dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test", )#cache_dir=args.cache_dir
117
  dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
118
 
119
  #for debugging: it gets two examples
@@ -136,11 +136,11 @@ def normalize(batch):
136
  return batch
137
 
138
  def map_wer(batch):
139
- model.to(args.device)
140
- forced_decoder_ids = processor.get_decoder_prompt_ids(language = args.language, task = "transcribe")
141
  inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
142
  with torch.no_grad():
143
- generated_ids = model.generate(inputs=inputs.to(args.device), forced_decoder_ids=forced_decoder_ids)
144
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
145
  batch["predicted_text"] = clean_text(transcription)
146
  return batch
@@ -148,10 +148,10 @@ def map_wer(batch):
148
  # process GOLD text
149
  processed_dataset = dataset.map(normalize)
150
  # get predictions
151
- predicted_dataset = processed_dataset.map(map_wer)
152
 
153
  # word error rate
154
- wer = wer_metric.compute(references=predicted_dataset['gold_text'], predictions=predicted_dataset['predicted_text'])
155
  wer = round(100 * wer, 2)
156
  print("WER:", wer)
157
  ```
 
113
  model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-arabic-5k-steps")
114
 
115
  # dataset
116
+ dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test", ) #cache_dir=args.cache_dir
117
  dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
118
 
119
  #for debugging: it gets two examples
 
136
  return batch
137
 
138
  def map_wer(batch):
139
+ model.to(device)
140
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language = "ar", task = "transcribe")
141
  inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
142
  with torch.no_grad():
143
+ generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
144
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
145
  batch["predicted_text"] = clean_text(transcription)
146
  return batch
 
148
  # process GOLD text
149
  processed_dataset = dataset.map(normalize)
150
  # get predictions
151
+ predicted = processed_dataset.map(map_wer)
152
 
153
  # word error rate
154
+ wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
155
  wer = round(100 * wer, 2)
156
  print("WER:", wer)
157
  ```