sanaweb commited on
Commit
4c2d473
·
verified ·
1 Parent(s): 79a69d3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ import librosa
4
+ from datasets import load_dataset, load_metric
5
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6
+
7
+ LANG_ID = "fa"
8
+ MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
9
+ DEVICE = "cuda"
10
+
11
+ CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
12
+ "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
13
+ "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
14
+ "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
15
+ "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
16
+
17
+ test_dataset = load_dataset("common_voice", LANG_ID, split="test")
18
+
19
+ wer = load_metric("wer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py
20
+ cer = load_metric("cer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py
21
+
22
+ chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
23
+
24
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
25
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
26
+ model.to(DEVICE)
27
+
28
+ # Preprocessing the datasets.
29
+ # We need to read the audio files as arrays
30
+ def speech_file_to_array_fn(batch):
31
+ with warnings.catch_warnings():
32
+ warnings.simplefilter("ignore")
33
+ speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
34
+ batch["speech"] = speech_array
35
+ batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper()
36
+ return batch
37
+
38
+ test_dataset = test_dataset.map(speech_file_to_array_fn)
39
+
40
+ # Preprocessing the datasets.
41
+ # We need to read the audio files as arrays
42
+ def evaluate(batch):
43
+ inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
44
+
45
+ with torch.no_grad():
46
+ logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits
47
+
48
+ pred_ids = torch.argmax(logits, dim=-1)
49
+ batch["pred_strings"] = processor.batch_decode(pred_ids)
50
+ return batch
51
+
52
+ result = test_dataset.map(evaluate, batched=True, batch_size=8)
53
+
54
+ predictions = [x.upper() for x in result["pred_strings"]]
55
+ references = [x.upper() for x in result["sentence"]]
56
+
57
+ print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
58
+ print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")