# whisper-small-da-cv11 / run_multiple.py
from datasets import Audio, interleave_datasets, IterableDataset, load_dataset
from typing import List, Optional

# corpora to combine for Danish training, their configs, and candidate names
# for the transcription column across the corpora
dataset_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
dataset_config_names = ["da", "da_dk"]
text_column_names = ["sentence", "normalized_text", "text", "transcription"]
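# Minimal sketch of how the lists above could be combined: stream each corpus and
# interleave the results. The helper name is illustrative only; the training run
# below streams Common Voice alone, and in practice the text columns and any
# extra fields would need to be aligned first (e.g. renaming FLEURS'
# "transcription" column to "sentence").
def load_multiple_streaming_datasets(split: str = "train", **kwargs) -> IterableDataset:
    streamed = [
        load_dataset(name, config, split=split, streaming=True, **kwargs)
        for name, config in zip(dataset_names, dataset_config_names)
    ]
    return interleave_datasets(streamed)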
def load_streaming_dataset(dataset_name, dataset_config_name, split, **kwargs):
    if "+" in split:
        # load multiple splits separated by the `+` symbol *with* streaming mode
        dataset_splits = [
            load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
            for split_name in split.split("+")
        ]
        # interleave multiple splits to form one dataset
        interleaved_dataset = interleave_datasets(dataset_splits)
        return interleaved_dataset
    else:
        # load a single split *with* streaming mode
        dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
        return dataset
from datasets import IterableDatasetDict
raw_datasets = IterableDatasetDict()
raw_datasets["train"] = load_streaming_dataset("mozilla-foundation/common_voice_11_0", "da", split="train+validation", use_auth_token=True) # set split="train+validation" for low-resource
raw_datasets["test"] = load_streaming_dataset("mozilla-foundation/common_voice_11_0", "da", split="test", use_auth_token=True)
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Danish", task="transcribe")
from datasets import Audio
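# Whisper's feature extractor expects 16 kHz audio; casting the column makes the
# streaming dataset resample lazily from each corpus's native sampling rate as
# examples are accessed.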
raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()
def prepare_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # optional pre-processing steps
    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()

    # encode target text to label ids
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch
vectorized_datasets = raw_datasets.map(
    prepare_dataset,
    remove_columns=list(next(iter(raw_datasets.values())).features),
).with_format("torch")
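# for a streaming dataset, shuffle() keeps a buffer of `buffer_size` examples and
# samples from it, so the shuffle is approximate rather than global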
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
buffer_size=500,
seed=0,
)
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length

vectorized_datasets["train"] = vectorized_datasets["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in the previous tokenization step,
        # cut it here as it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
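# usage sketch: calling the collator on a list of prepared examples pads them
# into a single batch of tensors, e.g.
#   batch = data_collator([next(iter(vectorized_datasets["train"]))])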
import evaluate
metric = evaluate.load("wer")
# evaluate with the 'normalised' WER
do_normalize_eval = True
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-07,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=32,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    # optim="adamw_bnb_8bit",
)
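# with per_device_train_batch_size=64 and gradient_accumulation_steps=1 on a
# single device, the effective batch size is 64, so 5000 steps stream roughly
# 320k training examples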
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset
# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)
trainer.train()
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0, FLEURS",  # a 'pretty' name for the training dataset
    "language": "da",
    "model_name": "Whisper Small da - Common Voice+FLEURS",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}
trainer.push_to_hub(**kwargs)