gvs committed
Commit: bca47bc
Parent: a54d8f3

Update README.md

Files changed (1): README.md (+51, −20)
README.md CHANGED
@@ -43,7 +43,7 @@ import torchaudio
 from datasets import load_dataset
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
-test_dataset = <load-test-split-of-combined-dataset> #TODO
+test_dataset = <load-test-split-of-combined-dataset> # Details on loading this dataset in the evaluation section
 
 processor = Wav2Vec2Processor.from_pretrained("gvs/wav2vec2-large-xlsr-malayalam")
 model = Wav2Vec2ForCTC.from_pretrained("gvs/wav2vec2-large-xlsr-malayalam")
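Note: `<load-test-split-of-combined-dataset>` is a placeholder that the evaluation section below fills in. For a quick try-out of the usage snippet, a minimal sketch that builds a compatible `datasets.Dataset` from local 48 kHz audio files (the paths and transcripts here are hypothetical):

```python
from datasets import Dataset

# Hypothetical local files; the usage snippet only needs a "path" column,
# and the resampler below assumes 48 kHz source audio
test_dataset = Dataset.from_dict({
    "path": ["/data/ml-audio/sample_0.wav", "/data/ml-audio/sample_1.wav"],
    "sentence": ["<transcript-0>", "<transcript-1>"],
})
```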
@@ -53,15 +53,15 @@ resampler = torchaudio.transforms.Resample(48_000, 16_000)
 # Preprocessing the datasets.
 # We need to read the audio files as arrays
 def speech_file_to_array_fn(batch):
-\tspeech_array, sampling_rate = torchaudio.load(batch["path"])
-\tbatch["speech"] = resampler(speech_array).squeeze().numpy()
-\treturn batch
+    speech_array, sampling_rate = torchaudio.load(batch["path"])
+    batch["speech"] = resampler(speech_array).squeeze().numpy()
+    return batch
 
 test_dataset = test_dataset.map(speech_file_to_array_fn)
 inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
 
 with torch.no_grad():
-\tlogits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
+    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
 
 predicted_ids = torch.argmax(logits, dim=-1)
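The resampler above is fixed at 48 kHz → 16 kHz, which matches the training data. If your audio comes at other rates, a hedged variant of the same function that builds the resampler from each file's actual rate:

```python
import torchaudio

def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    # Resample from the file's real rate rather than assuming 48 kHz
    resampler = torchaudio.transforms.Resample(sampling_rate, 16_000)
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch
```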
 
@@ -81,8 +81,39 @@ import torchaudio
 from datasets import load_dataset, load_metric
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 import re
-
-test_dataset = <load-test-split-of-combined-dataset> #TODO
+from datasets import concatenate_datasets
+from pathlib import Path
+
+data_dir = Path('<path-to-custom-dataset>')
+
+dataset_folders = {
+    'openslr': 'openslr',
+    'indic-tts': 'indic-tts-ml',
+}
+
+# Set directories for datasets
+openslr_male_dir = data_dir / dataset_folders['openslr'] / 'male'
+openslr_female_dir = data_dir / dataset_folders['openslr'] / 'female'
+indic_tts_male_dir = data_dir / dataset_folders['indic-tts'] / 'male'
+indic_tts_female_dir = data_dir / dataset_folders['indic-tts'] / 'female'
+
+# Load the datasets; the total sample counts are set manually
+openslr_male = load_dataset("json", data_files=[f"{str(openslr_male_dir.absolute())}/sample_{i}.json" for i in range(2023)], split="train")
+openslr_female = load_dataset("json", data_files=[f"{str(openslr_female_dir.absolute())}/sample_{i}.json" for i in range(2103)], split="train")
+indic_tts_male = load_dataset("json", data_files=[f"{str(indic_tts_male_dir.absolute())}/sample_{i}.json" for i in range(5649)], split="train")
+indic_tts_female = load_dataset("json", data_files=[f"{str(indic_tts_female_dir.absolute())}/sample_{i}.json" for i in range(2950)], split="train")
+
+# Create the test split as 20%, with a fixed random seed
+test_size = 0.2
+random_seed = 1
+openslr_male_splits = openslr_male.train_test_split(test_size=test_size, seed=random_seed)
+openslr_female_splits = openslr_female.train_test_split(test_size=test_size, seed=random_seed)
+indic_tts_male_splits = indic_tts_male.train_test_split(test_size=test_size, seed=random_seed)
+indic_tts_female_splits = indic_tts_female.train_test_split(test_size=test_size, seed=random_seed)
+
+# Get the combined test dataset
+split_list = [openslr_male_splits, openslr_female_splits, indic_tts_male_splits, indic_tts_female_splits]
+test_dataset = concatenate_datasets([split['test'] for split in split_list])
 
 wer = load_metric("wer")
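The per-split sample counts (2023, 2103, 5649 and 2950) are hard-coded above. Assuming the same `sample_{i}.json` layout, a sketch that discovers the files with `Path.glob` instead, so the counts cannot drift out of sync with the data on disk:

```python
from pathlib import Path
from datasets import load_dataset

def load_json_split(directory: Path):
    # Pick up every sample_*.json under the directory instead of
    # hard-coding how many samples it contains
    files = sorted(str(p) for p in directory.glob("sample_*.json"))
    return load_dataset("json", data_files=files, split="train")

# e.g. openslr_male = load_json_split(openslr_male_dir)
```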
 
@@ -90,33 +121,33 @@ processor = Wav2Vec2Processor.from_pretrained("gvs/wav2vec2-large-xlsr-malayalam")
 model = Wav2Vec2ForCTC.from_pretrained("gvs/wav2vec2-large-xlsr-malayalam")
 model.to("cuda")
 
-chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\"\\“\\%\\‘\\”\\�Utrnle\\_]'
-unicode_ignore_regex = r'[\\u200d\\u200c\\u200e]'
+chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�Utrnle\_]'
+unicode_ignore_regex = r'[\u200c\u200d\u200e]'
 
 resampler = torchaudio.transforms.Resample(48_000, 16_000)
 
 # Preprocessing the datasets.
 # We need to read the audio files as arrays
 def speech_file_to_array_fn(batch):
-\tbatch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"])
-batch["sentence"] = re.sub(unicode_ignore_regex, '', batch["sentence"])
-\tspeech_array, sampling_rate = torchaudio.load(batch["path"])
-\tbatch["speech"] = resampler(speech_array).squeeze().numpy()
-\treturn batch
+    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"])
+    batch["sentence"] = re.sub(unicode_ignore_regex, '', batch["sentence"])
+    speech_array, sampling_rate = torchaudio.load(batch["path"])
+    batch["speech"] = resampler(speech_array).squeeze().numpy()
+    return batch
 
 test_dataset = test_dataset.map(speech_file_to_array_fn)
 
 # Run the model on the preprocessed audio
 # and decode the argmax predictions
 def evaluate(batch):
-\tinputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
+    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
 
-\twith torch.no_grad():
-\t\tlogits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
+    with torch.no_grad():
+        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
 
-\tpred_ids = torch.argmax(logits, dim=-1)
-\tbatch["pred_strings"] = processor.batch_decode(pred_ids)
-\treturn batch
+    pred_ids = torch.argmax(logits, dim=-1)
+    batch["pred_strings"] = processor.batch_decode(pred_ids)
+    return batch
 
 result = test_dataset.map(evaluate, batched=True, batch_size=8)
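`result` now holds the decoded predictions alongside the cleaned reference sentences; the customary final step with the `wer` metric loaded earlier (shown as a sketch, since that line sits outside the changed hunks) is:

```python
# Word error rate over the combined test set, as a percentage
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
```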