sanchit-gandhi committed
Commit ba9afe1
Parent: c76aba4

update eval

Files changed (1):
  1. README.md +3 -2
README.md CHANGED
@@ -292,6 +292,7 @@ model_id = "distil-whisper/distil-medium.en"
 
 # load the model + processor
 model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
+model = model.to(device)
 processor = AutoProcessor.from_pretrained(model_id)
 
 # load the dataset with streaming mode
@@ -308,7 +309,7 @@ def inference(batch):
     input_features = input_features.to(device, dtype=torch_dtype)
 
     # 2. Auto-regressively generate the predicted token ids
-    pred_ids = model.generate(input_features, max_new_tokens=128, language="en", task="transcribe")
+    pred_ids = model.generate(input_features, max_new_tokens=128)
 
     # 3. Decode the token ids to the final transcription
     batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
@@ -336,7 +337,7 @@ print(wer)
 ```
 **Print Output:**
 ```
-2.983685535968466
+3.593196832001168
 ```
 
 ## Intended Use
 
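For context on how the two code changes fit together, below is a minimal end-to-end sketch of the updated evaluation snippet. The added `model.to(device)` moves the weights onto the same device as the input features, and the `language`/`task` arguments to `generate` are dropped because `distil-medium.en` is an English-only checkpoint. The dataset slice (first 16 validation samples), the lowercase normalization, and the metric setup are illustrative assumptions for this sketch, not the exact code from the README:

```python
import torch
from datasets import load_dataset
from evaluate import load
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# device/dtype setup, as in the earlier (unchanged) part of the README
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-medium.en"

# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True
)
model = model.to(device)  # the added line: weights must be on the same device as the inputs
processor = AutoProcessor.from_pretrained(model_id)

# stream a small LibriSpeech slice (config/split/sample count are assumptions for this sketch)
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)

wer_metric = load("wer")
predictions, references = [], []

for sample in dataset.take(16):
    # 1. Pre-process the audio data to log-mel spectrogram inputs
    input_features = processor(
        sample["audio"]["array"],
        sampling_rate=sample["audio"]["sampling_rate"],
        return_tensors="pt",
    ).input_features
    input_features = input_features.to(device, dtype=torch_dtype)

    # 2. Auto-regressively generate the predicted token ids
    # (no language/task args: the English-only checkpoint has no multilingual prompt tokens)
    pred_ids = model.generate(input_features, max_new_tokens=128)

    # 3. Decode the token ids to the final transcription
    predictions.append(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])
    references.append(sample["text"])

# crude lowercase normalization stands in for the text normalization assumed in the full README
wer = 100 * wer_metric.compute(
    predictions=[p.lower() for p in predictions],
    references=[r.lower() for r in references],
)
print(wer)
```

The `max_new_tokens=128` cap matches the README and keeps generation bounded for short LibriSpeech utterances; on GPU the sketch runs in float16, matching the `torch_dtype` setup above.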