ttop324 committed on
Commit
1958c90
1 Parent(s): 54d59f8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +101 -0
README.md CHANGED
@@ -41,10 +41,111 @@ Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav
41
 
42
  ## Inference
43
  ```python
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ```
45
 
46
  ## Evaluation
47
  ```python
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ```
49
 
50
 
 
41
 
42
  ## Inference
43
  ```python
44
+
45
+ #usage
46
+ import torch
47
+ import torchaudio
48
+ from datasets import load_dataset
49
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
50
+
51
+
52
+ model = Wav2Vec2ForCTC.from_pretrained("wav2vec2_large_xlsr_japanese_hiragana")
53
+ processor = Wav2Vec2Processor.from_pretrained("wav2vec2_large_xlsr_japanese_hiragana")
54
+ test_dataset = load_dataset("common_voice", "ja", split="test")
55
+
56
+
57
+
58
+ # Preprocessing the datasets.
59
+ # We need to read the audio files as arrays
60
+ def speech_file_to_array_fn(batch):
61
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
62
+ batch["speech"] = torchaudio.functional.resample(speech_array, sampling_rate, 16000)[0].numpy()
63
+ return batch
64
+
65
+
66
+ test_dataset = test_dataset.map(speech_file_to_array_fn)
67
+ inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
68
+
69
+ with torch.no_grad():
70
+ logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
71
+
72
+ predicted_ids = torch.argmax(logits, dim=-1)
73
+
74
+ print("Prediction:", processor.batch_decode(predicted_ids))
75
+ print("Reference:", test_dataset[:2]["sentence"])
76
  ```
77
 
78
  ## Evaluation
79
  ```python
80
+
81
+
82
+ import torch
83
+ import torchaudio
84
+ from datasets import load_dataset, load_metric
85
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
86
+ import re
87
+ import pykakasi
88
+ import MeCab
89
+
90
+
91
+ wer = load_metric("wer")
92
+ cer = load_metric("cer")
93
+
94
+ model = Wav2Vec2ForCTC.from_pretrained("ttop324/wav2vec2-live-japanese").to("cuda")
95
+ processor = Wav2Vec2Processor.from_pretrained("ttop324/wav2vec2-live-japanese")
96
+ test_dataset = load_dataset("common_voice", "ja", split="test")
97
+
98
+
99
+ chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\‘\”\�‘、。.!,・―─~「」『』\\\\※\[\]\{\}「」〇?…]'
100
+ wakati = MeCab.Tagger("-Owakati")
101
+ kakasi = pykakasi.kakasi()
102
+ kakasi.setMode("J","H") # kanji to hiragana
103
+ kakasi.setMode("K","H") # katakana to hiragana
104
+ conv = kakasi.getConverter()
105
+
106
+
107
+ FULLWIDTH_TO_HALFWIDTH = str.maketrans(
108
+ ' 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!゛#$%&()*+、ー。/:;〈=〉?@[]^_‘{|}~',
109
+ ' 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~',
110
+ )
111
+ def fullwidth_to_halfwidth(s):
112
+ return s.translate(FULLWIDTH_TO_HALFWIDTH)
113
+
114
+
115
+ def preprocessData(batch):
116
+ batch["sentence"] = fullwidth_to_halfwidth(batch["sentence"])
117
+ batch["sentence"] = re.sub(chars_to_ignore_regex,' ', batch["sentence"]).lower() #remove special char
118
+ batch["sentence"] = wakati.parse(batch["sentence"]) #add space
119
+ batch["sentence"] = conv.do(batch["sentence"]) #convert to hiragana
120
+ batch["sentence"] = " ".join(batch["sentence"].split())+" " #remove multiple space
121
+
122
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
123
+ batch["speech"] = torchaudio.functional.resample(speech_array, sampling_rate, 16000)[0].numpy()
124
+ return batch
125
+
126
+
127
+ test_dataset = test_dataset.map(preprocessData)
128
+
129
+
130
+
131
+ # Running the model on the preprocessed data.
131
+ # We feed the audio arrays to the model in batches and decode the predicted ids to text
133
+ def evaluate(batch):
134
+ inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
135
+
136
+ with torch.no_grad():
137
+ logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
138
+
139
+ pred_ids = torch.argmax(logits, dim=-1)
140
+ batch["pred_strings"] = processor.batch_decode(pred_ids)
141
+ return batch
142
+
143
+ result = test_dataset.map(evaluate, batched=True, batch_size=8)
144
+
145
+ print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
146
+ print("CER: {:2f}".format(100 * cer.compute(predictions=result["pred_strings"], references=result["sentence"])))
147
+
148
+
149
  ```
150
 
151