mrm8488 commited on
Commit
62fc70b
1 Parent(s): 8b2bc20

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -77,14 +77,14 @@ from datasets import load_dataset, load_metric
77
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
78
  import re
79
 
80
- test_dataset = load_dataset("common_voice", "es", split="test")
81
  wer = load_metric("wer")
82
 
83
  processor = Wav2Vec2Processor.from_pretrained("mrm8488/wav2vec2-large-xlsr-53-ukrainian")
84
  model = Wav2Vec2ForCTC.from_pretrained("mrm8488/wav2vec2-large-xlsr-53-ukrainian")
85
  model.to("cuda")
86
 
87
- chars_to_ignore_regex = '[\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\,\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\?\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\.\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\!\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\-\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\;\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\:\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\“]'
88
  resampler = torchaudio.transforms.Resample(48_000, 16_000)
89
 
90
  # Preprocessing the datasets.
@@ -104,7 +104,7 @@ def evaluate(batch):
104
  with torch.no_grad():
105
  logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
106
 
107
- pred_ids = torch.argmax(logits, dim=-1)
108
  batch["pred_strings"] = processor.batch_decode(pred_ids)
109
  return batch
110
 
77
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
78
  import re
79
 
80
+ test_dataset = load_dataset("common_voice", "uk", split="test")
81
  wer = load_metric("wer")
82
 
83
  processor = Wav2Vec2Processor.from_pretrained("mrm8488/wav2vec2-large-xlsr-53-ukrainian")
84
  model = Wav2Vec2ForCTC.from_pretrained("mrm8488/wav2vec2-large-xlsr-53-ukrainian")
85
  model.to("cuda")
86
 
87
+ chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
88
  resampler = torchaudio.transforms.Resample(48_000, 16_000)
89
 
90
  # Preprocessing the datasets.
104
  with torch.no_grad():
105
  logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
106
 
107
+ pred_ids = torch.argmax(logits, dim=-1)
108
  batch["pred_strings"] = processor.batch_decode(pred_ids)
109
  return batch
110