asahi417 committed on
Commit f65a4d4
1 Parent(s): b20891c

Update README.md

Files changed (1)
  1. README.md +4 -2
README.md CHANGED
@@ -300,6 +300,7 @@ import torch
 from transformers import pipeline
 from datasets import load_dataset
 from evaluate import load
+from transformers.models.whisper.english_normalizer import BasicTextNormalizer
 
 # model config
 model_id = "kotoba-tech/kotoba-whisper-v1.0"
@@ -307,6 +308,7 @@ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
 generate_kwargs = {"language": "japanese", "task": "transcribe"}
+normalizer = BasicTextNormalizer()
 
 # data config
 dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
@@ -326,8 +328,8 @@ pipe = pipeline(
 # load the dataset and sample the audio with 16kHz
 dataset = load_dataset(dataset_name, split="test")
 transcriptions = pipe(dataset['audio'])
-transcriptions = [i['text'].replace(" ", "") for i in transcriptions]
-references = [i.replace(" ", "") for i in dataset['transcription']]
+transcriptions = [normalizer(i['text']).replace(" ", "") for i in transcriptions]
+references = [normalizer(i).replace(" ", "") for i in dataset['transcription']]
 
 # compute the CER metric
 cer_metric = load("cer")
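
In plain terms, the updated README script now runs both the pipeline transcriptions and the dataset references through Whisper's BasicTextNormalizer before stripping spaces, so punctuation and symbol differences no longer count toward the character error rate. Below is a minimal, self-contained sketch of that normalize-then-score step; the hypothesis/reference strings are made up for illustration, whereas the README scores pipeline output against the japanese-asr/ja_asr.reazonspeech_test references, and evaluate's "cer" metric assumes the jiwer package is installed.

from evaluate import load
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()
cer_metric = load("cer")  # the CER metric requires jiwer

# Hypothetical hypothesis/reference pair for illustration only; the README
# uses the ASR pipeline output and the ReazonSpeech test transcriptions.
hypothesis = "こんにちは、世界。"
reference = "こんにちは 世界"

# Normalize (lower-case, drop punctuation/symbols), then remove spaces,
# mirroring the two updated list comprehensions in the README.
prediction = normalizer(hypothesis).replace(" ", "")
target = normalizer(reference).replace(" ", "")

# Prints 0.0: the punctuation and spacing differences are normalized away
# before the character error rate is computed.
print(cer_metric.compute(predictions=[prediction], references=[target]))

Without the normalizer, the same pair would score a non-zero CER, because the punctuation marks in the hypothesis would be counted as character errors against the reference.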