TKU410410103 commited on
Commit
4a69c46
1 Parent(s): a284709

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +85 -0
README.md CHANGED
@@ -65,12 +65,97 @@ The training hyperparameters remained consistent throughout the fine-tuning proc
65
  - num_train_epochs: 30
66
  - lr_scheduler_type: linear
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ### Test results
69
  The final model was evaluated as follows:
70
 
71
  On common_voice_11_0:
72
  - WER: 27.511982%
73
  - CER: 11.699897%
 
74
  ### Framework versions
75
 
76
  - Transformers 4.39.1
 
65
  - num_train_epochs: 30
66
  - lr_scheduler_type: linear
67
 
68
+ ### How to evaluate the model
69
+
70
+ ```python
71
+ from transformers import HubertForCTC, Wav2Vec2Processor
72
+ from datasets import load_dataset
73
+ import torchaudio
74
+ import librosa
75
+ import numpy as np
76
+ import re
77
+ import MeCab
78
+ import pykakasi
79
+ from evaluate import load
80
+
81
+ model = HubertForCTC.from_pretrained('TKU410410103/hubert-base-japanese-asr')
82
+ processor = Wav2Vec2Processor.from_pretrained("TKU410410103/hubert-base-japanese-asr")
83
+
84
+ # load dataset
85
+ test_dataset = load_dataset('mozilla-foundation/common_voice_11_0', 'ja', split='test')
86
+ remove_columns = [col for col in test_dataset.column_names if col not in ['audio', 'sentence']]
87
+ test_dataset = test_dataset.remove_columns(remove_columns)
88
+
89
+ # resample
90
+ def process_waveforms(batch):
91
+ speech_arrays = []
92
+ sampling_rates = []
93
+
94
+ for audio_path in batch['audio']:
95
+ speech_array, _ = torchaudio.load(audio_path['path'])
96
+ speech_array_resampled = librosa.resample(np.asarray(speech_array[0].numpy()), orig_sr=48000, target_sr=16000)
97
+ speech_arrays.append(speech_array_resampled)
98
+ sampling_rates.append(16000)
99
+
100
+ batch["array"] = speech_arrays
101
+ batch["sampling_rate"] = sampling_rates
102
+
103
+ return batch
104
+
105
+ # hiragana
106
+ CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
107
+ "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
108
+ "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
109
+ "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
110
+ "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "'", "ʻ", "ˆ"]
111
+ chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
112
+
113
+ wakati = MeCab.Tagger("-Owakati")
114
+ kakasi = pykakasi.kakasi()
115
+ kakasi.setMode("J","H")
116
+ kakasi.setMode("K","H")
117
+ kakasi.setMode("r","Hepburn")
118
+ conv = kakasi.getConverter()
119
+
120
+ def prepare_char(batch):
121
+ batch["sentence"] = conv.do(wakati.parse(batch["sentence"]).strip())
122
+ batch["sentence"] = re.sub(chars_to_ignore_regex,'', batch["sentence"]).strip()
123
+ return batch
124
+
125
+
126
+ resampled_eval_dataset = test_dataset.map(process_waveforms, batched=True, batch_size=50, num_proc=4)
127
+ eval_dataset = resampled_eval_dataset.map(prepare_char, num_proc=4)
128
+
129
+ # begin the evaluation process
130
+ wer = load("wer")
131
+ cer = load("cer")
132
+
133
+ def evaluate(batch):
134
+ inputs = processor(batch["array"], sampling_rate=16_000, return_tensors="pt", padding=True)
135
+ with torch.no_grad():
136
+ logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
137
+ pred_ids = torch.argmax(logits, dim=-1)
138
+ batch["pred_strings"] = processor.batch_decode(pred_ids)
139
+ return batch
140
+
141
+ columns_to_remove = [column for column in eval_dataset.column_names if column != "sentence"]
142
+ batch_size = 16
143
+ result = eval_dataset.map(evaluate, remove_columns=columns_to_remove, batched=True, batch_size=batch_size)
144
+
145
+ wer_result = wer.compute(predictions=result["pred_strings"], references=result["sentence"])
146
+ cer_result = cer.compute(predictions=result["pred_strings"], references=result["sentence"])
147
+
148
+ print("WER: {:2f}%".format(100 * wer_result))
149
+ print("CER: {:2f}%".format(100 * cer_result))
150
+ ```
151
+
152
  ### Test results
153
  The final model was evaluated as follows:
154
 
155
  On common_voice_11_0:
156
  - WER: 27.511982%
157
  - CER: 11.699897%
158
+
159
  ### Framework versions
160
 
161
  - Transformers 4.39.1