---
language: gsw
datasets:
- Yves/fhnw_swiss_parliament
metrics:
- wer
tags:
- audio
- speech
- wav2vec2
- gsw
- automatic-speech-recognition
- xlsr-fine-tuning-week
- PyTorch
license: apache-2.0
model-index:
- name: Yves XLSR Wav2Vec2 Large 53 Swiss German
  results:
  - task:
      name: Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Yves/fhnw_swiss_parliament
      type: Yves/fhnw_swiss_parliament
    metrics:
    - name: Test WER
      type: wer
      value: NA%
---
# wav2vec2-large-xlsr-53-swiss-german

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Swiss German, aiming for satisfactory Swiss German to German transcriptions.

## Dataset
Detailed information about the dataset used to train and validate the model is available at [Yves/fhnw_swiss_parliament](https://huggingface.co/datasets/Yves/fhnw_swiss_parliament).
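
As a quick orientation, the dataset can be loaded through the standard `datasets` API. The sketch below reuses the `data_dir` value from the usage example further down; adapt it to wherever the data lives on your machine.

```python
from datasets import load_dataset

# Minimal sketch: load the validation split and peek at one example.
# The data_dir value mirrors the usage example below and may differ on your setup.
ds = load_dataset("Yves/fhnw_swiss_parliament", data_dir="swiss_parliament", split="validation")

print(ds)                    # number of rows and available columns
print(ds[0]["path"])         # path to the audio file
print(ds[0]["sentence"])     # reference transcription
```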

## Usage
The model can be used directly (without a language model) as follows:

```python
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

test_dataset = load_dataset("Yves/fhnw_swiss_parliament", data_dir="swiss_parliament", split="validation")

processor = Wav2Vec2Processor.from_pretrained("Yves/wav2vec2-large-xlsr-53-swiss-german")
model = Wav2Vec2ForCTC.from_pretrained("Yves/wav2vec2-large-xlsr-53-swiss-german").cuda()

# Resample from the dataset's 48 kHz audio to the 16 kHz the model expects.
resampler = torchaudio.transforms.Resample(48_000, 16_000)

def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values.cuda(), attention_mask=inputs.attention_mask.cuda()).logits

predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"])
```
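
To transcribe a recording of your own rather than a dataset sample, the same processor and model can be applied to a single file. This is only a sketch: the file path is a placeholder, and it assumes `processor` and `model` are already loaded on the GPU as above.

```python
import torch
import torchaudio

# Sketch: transcribe one local audio file (placeholder path), reusing the
# `processor` and `model` objects loaded in the example above.
speech_array, sampling_rate = torchaudio.load("my_recording.wav")
speech = torchaudio.transforms.Resample(sampling_rate, 16_000)(speech_array).squeeze().numpy()

inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values.cuda(), attention_mask=inputs.attention_mask.cuda()).logits

print("Prediction:", processor.batch_decode(torch.argmax(logits, dim=-1)))
```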

## Evaluation
```python
import csv
import re

import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

model_name = "Yves/wav2vec2-large-xlsr-53-swiss-german"
device = "cuda"

# Characters stripped from the references before scoring.
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\_\²\…\˟\&\+\[\]\(\−\–\)\›\»\‹\@\«\*\ʼ\/\°\'\'\’\'̈]'

eval_batch_size = 16

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(model_name)

ds = load_dataset("Yves/fhnw_swiss_parliament", data_dir="container_0/swiss_parliament_dryrun", split="validation")

wer = load_metric("wer")
cer = load_metric("cer")
bleu = load_metric("sacrebleu")

# Resample from the dataset's 48 kHz audio to the 16 kHz the model expects.
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)

def map_to_array(batch):
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
    batch["sampling_rate"] = resampler.new_freq
    # Normalize the reference: remove special characters and lowercase.
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    return batch

ds = ds.map(map_to_array)

# Per-utterance scores are written to a TSV file for later inspection.
out_file = open('output.tsv', 'w', encoding='utf-8')
tsv_writer = csv.writer(out_file, delimiter='\t')
tsv_writer.writerow(["client_id", "reference", "prediction", "wer", "cer", "bleu"])

def map_to_pred(batch, idx):
    features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["predicted"] = processor.batch_decode(pred_ids)
    batch["target"] = batch["sentence"]
    # The first batch is skipped when it contains no more than two examples.
    if not (len(idx) <= 2 and idx[0] == 0):
        for x in range(0, len(idx)):
            tsv_writer.writerow([
                batch["client_id"][x],
                batch["target"][x],
                batch["predicted"][x],
                wer.compute(predictions=[batch["predicted"][x]], references=[batch["sentence"][x]]),
                cer.compute(predictions=[batch["predicted"][x]], references=[batch["sentence"][x]]),
                bleu.compute(predictions=[batch["predicted"][x]], references=[[batch["target"][x]]])["score"],
            ])
    return batch

result = ds.map(map_to_pred, batched=True, batch_size=eval_batch_size, with_indices=True, remove_columns=list(ds.features.keys()))

out_file.close()

# sacrebleu expects one list of references per prediction.
target_bleu = [[x] for x in result["target"]]

print("WER:", wer.compute(predictions=result["predicted"], references=result["target"]))
print("CER:", cer.compute(predictions=result["predicted"], references=result["target"]))
print("BLEU:", bleu.compute(predictions=result["predicted"], references=target_bleu))
```
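
The per-utterance TSV written by the script above can be used for quick error analysis. A minimal sketch (assuming `output.tsv` has already been produced) that lists the ten utterances with the highest WER:

```python
import csv

# Minimal sketch: show the ten utterances with the highest WER from output.tsv,
# assuming the evaluation script above has already been run.
with open("output.tsv", encoding="utf-8") as f:
    rows = list(csv.DictReader(f, delimiter="\t"))

rows.sort(key=lambda r: float(r["wer"]), reverse=True)

for r in rows[:10]:
    print(r["wer"], r["reference"], r["prediction"], sep="\t")
```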

## Scripts
The script used for training can be found on Google Colab: [TBD](https://huggingface.co/Yves/wav2vec2-large-xlsr-53-swiss-german)