jefffffff9 Claude Sonnet 4.6 committed on
Commit
6682858
·
1 Parent(s): 58f431a

Fix jiwer crash on post-normalisation empty refs; register SLR106/105 datasets

Browse files

compute_metrics: pre-apply lowercasing + punctuation removal before filtering
empty references — a label like "?" passes r.strip() but becomes "" after
RemovePunctuation, causing jiwer to raise ValueError mid-evaluation.
Pass already-normalised strings directly to jiwer.cer / jiwer.wer so no
second transform is applied.

web_harvester: add commented-in registry entries for OpenSLR SLR106
(10k clean Guinea Pular utterances, 49 speakers, CC BY-SA 4.0) and SLR105
(Guinea radio corpus, Pular-tagged validation split). Both should be uncommented
once the audio is uploaded to HF — best available Guinea Pular ASR data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

notebooks/kaggle_master_trainer.ipynb CHANGED
@@ -184,7 +184,7 @@
184
  "metadata": {},
185
  "outputs": [],
186
  "source": [
187
- "# -- Cell 14: Data collator + CER metric --------------------------------------\nimport jiwer\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List\n\ntransform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n jiwer.ReduceToListOfListOfWords(),\n])\n\n# CER transform (no word-split step needed)\n_cer_transform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n])\n\n\n@dataclass\nclass DataCollatorSpeechSeq2SeqWithPadding:\n processor: Any\n\n def __call__(self, features: List[Dict]) -> Dict:\n import torch\n input_feats = [{'input_features': f['input_features']} for f in features]\n batch = self.processor.feature_extractor.pad(input_feats, return_tensors='pt')\n\n # Leave features in fp32 -- AMP (fp16=True in TrainingArgs) handles casting\n\n label_feats = [{'input_ids': f['labels']} for f in features]\n labels_batch = self.processor.tokenizer.pad(label_feats, return_tensors='pt')\n labels = labels_batch['input_ids'].masked_fill(\n labels_batch.attention_mask.ne(1), -100\n )\n if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().item():\n labels = labels[:, 1:]\n batch['labels'] = labels\n return batch\n\n\ndef compute_metrics(pred):\n pred_ids = pred.predictions\n label_ids = pred.label_ids\n label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n\n # Filter out empty references (silent audio / all-padding labels) that\n # cause jiwer to raise ValueError (\"non-empty list of strings\" required).\n pairs = [(r, h) for r, h in zip(label_str, pred_str) if r.strip()]\n if not pairs:\n return {'cer': 0.0, 'wer': 0.0}\n label_str, pred_str = zip(*pairs)\n\n cer = jiwer.cer(\n list(label_str), list(pred_str),\n 
reference_transform=_cer_transform,\n hypothesis_transform=_cer_transform,\n )\n wer = jiwer.wer(list(label_str), list(pred_str),\n hypothesis_transform=transform,\n reference_transform=transform)\n return {'cer': round(cer, 4), 'wer': round(wer, 4)}\n\n\ncollator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\nprint('Collator and WER metric ready')"
188
  ]
189
  },
190
  {
 
184
  "metadata": {},
185
  "outputs": [],
186
  "source": [
187
+ "# -- Cell 14: Data collator + CER metric --------------------------------------\nimport jiwer\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List\n\ntransform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n jiwer.ReduceToListOfListOfWords(),\n])\n\n# CER transform (no word-split step needed)\n_cer_transform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n])\n\n\n@dataclass\nclass DataCollatorSpeechSeq2SeqWithPadding:\n processor: Any\n\n def __call__(self, features: List[Dict]) -> Dict:\n import torch\n input_feats = [{'input_features': f['input_features']} for f in features]\n batch = self.processor.feature_extractor.pad(input_feats, return_tensors='pt')\n\n # Leave features in fp32 -- AMP (fp16=True in TrainingArgs) handles casting\n\n label_feats = [{'input_ids': f['labels']} for f in features]\n labels_batch = self.processor.tokenizer.pad(label_feats, return_tensors='pt')\n labels = labels_batch['input_ids'].masked_fill(\n labels_batch.attention_mask.ne(1), -100\n )\n if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().item():\n labels = labels[:, 1:]\n batch['labels'] = labels\n return batch\n\n\ndef _apply_jiwer_transform(texts, t):\n \"\"\"Apply a jiwer Compose transform and return plain strings (not nested lists).\"\"\"\n import re as _re\n result = []\n for s in texts:\n s = s.lower()\n s = _re.sub(r'[^\\w\\s]', '', s) # RemovePunctuation equivalent\n s = ' '.join(s.split()) # RemoveMultipleSpaces + Strip\n result.append(s)\n return result\n\n\ndef compute_metrics(pred):\n pred_ids = pred.predictions\n label_ids = pred.label_ids\n label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n\n # Pre-normalise 
so we can filter empties AFTER the transform.\n # A reference like \"?\" or \"1\" decodes to non-empty but becomes empty\n # after punctuation/number removal -- jiwer crashes on empty references.\n norm_ref = _apply_jiwer_transform(label_str, _cer_transform)\n norm_hyp = _apply_jiwer_transform(pred_str, _cer_transform)\n pairs = [(r, h) for r, h in zip(norm_ref, norm_hyp) if r.strip()]\n if not pairs:\n return {'cer': 0.0, 'wer': 0.0}\n ref_clean, hyp_clean = zip(*pairs)\n\n cer = jiwer.cer(list(ref_clean), list(hyp_clean)) # already normalised\n wer = jiwer.wer(list(ref_clean), list(hyp_clean)) # already normalised\n return {'cer': round(cer, 4), 'wer': round(wer, 4)}\n\n\ncollator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\nprint('Collator and WER metric ready')"
188
  ]
189
  },
190
  {
src/data/web_harvester.py CHANGED
@@ -69,6 +69,35 @@ HF_DATASET_REGISTRY = {
69
  "adlam": True,
70
  "note": "51 Adlam-script audio rows — converted to Latin before training",
71
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  ],
73
  }
74
 
 
69
  "adlam": True,
70
  "note": "51 Adlam-script audio rows — converted to Latin before training",
71
  },
72
+ # OpenSLR SLR106 — West African Virtual Assistant ASR Corpus (Guinea, CC BY-SA 4.0)
73
+ # 10,083 clean utterances from 49 Guinea-native Pular speakers (read speech,
74
+ # multi-device, ages 5–76). Best available clean Guinea Pular ASR data.
75
+ # Download manually from https://openslr.org/106/ and upload to HF before training.
76
+ # Uncomment once the dataset repo is populated.
77
+ # {
78
+ # "repo": "ous-sow/slr106-pular",
79
+ # "config": "default",
80
+ # "split": "train",
81
+ # "audio_col": "audio",
82
+ # "text_col": "transcription",
83
+ # "max": 10_000,
84
+ # "license": "cc-by-sa-4.0",
85
+ # "note": "OpenSLR SLR106 Guinea Pular — 10k clean utterances, 49 speakers",
86
+ # },
87
+ # OpenSLR SLR105 — West African Radio Corpus (Guinea, CC BY-SA 4.0)
88
+ # ~142 hours raw radio audio from 6 Guinea stations; Pular-tagged validation
89
+ # set of 300 clips. Noisier than SLR106 but larger.
90
+ # Uncomment once uploaded to HF.
91
+ # {
92
+ # "repo": "ous-sow/slr105-pular",
93
+ # "config": "default",
94
+ # "split": "validation",
95
+ # "audio_col": "audio",
96
+ # "text_col": "transcription",
97
+ # "max": 300,
98
+ # "license": "cc-by-sa-4.0",
99
+ # "note": "OpenSLR SLR105 Guinea radio — Pular-tagged validation split",
100
+ # },
101
  ],
102
  }
103