import torch
from datasets import load_dataset, DatasetDict, Audio
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate


# Functions
# Define a data collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have different lengths and need
        # different padding methods. First treat the audio inputs by simply
        # returning torch tensors.
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")
        # Get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        # Pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")
        # Replace padding with -100 so the padded positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)
        # If a bos token was appended in the previous tokenization step,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
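

# Illustrative sketch (hypothetical label ids, not part of the training flow):
# given two features with labels [7, 8, 9] and [7, 8], tokenizer.pad would
# return input_ids [[7, 8, 9], [7, 8, PAD]] with attention_mask
# [[1, 1, 1], [1, 1, 0]], and masked_fill then yields labels
# [[7, 8, 9], [7, 8, -100]], so the padded position is skipped by the loss.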


def main():
    # Metrics
    def compute_metrics(pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids
        # Replace -100 with the pad_token_id
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        # We do not want to group tokens when computing the metrics
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        wer = 100 * metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}
    # Prepare dataset
    def prepare_dataset(batch):
        # Load and resample the audio data from 48kHz to 16kHz
        audio = batch["audio"]
        # Compute log-Mel input features from the input audio array
        batch["input_features"] = feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
        # Encode the target text to label ids
        batch["labels"] = tokenizer(batch["sentence"]).input_ids
        return batch
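
    # Note: for openai/whisper-small the feature extractor pads or truncates
    # each example to 30 seconds of audio and returns a log-Mel spectrogram of
    # shape (80, 3000): 80 mel bins over 3000 frames.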
    # Whisper training script
    # Map the source and target columns.
    # Whisper expects these to be "audio" and "sentence"; change them if the dataset uses other names.
    source = "audio"
    target = "sentence"
    # Load a sample dataset
    speech_data = DatasetDict()
    # Examples:
    # speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
    # speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
    # speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
    # speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)
    # The smallest dataset I found
    speech_data["train"] = load_dataset(
        "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
    speech_data["test"] = load_dataset(
        "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
    # Rename columns
    if "audio" not in speech_data.column_names["train"]:
        speech_data = speech_data.rename_column(source, "audio")
    if "sentence" not in speech_data.column_names["train"]:
        speech_data = speech_data.rename_column(target, "sentence")
    # Remove columns that are not needed (possibly redundant)
    remove_list = [i for i in speech_data.column_names["train"]
                   if i not in ["audio", "sentence"]]
    speech_data = speech_data.remove_columns(remove_list)
    # Initialise
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "openai/whisper-small")
    tokenizer = WhisperTokenizer.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
    # Prepare data
    speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
    speech_data = speech_data.map(
        prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)
    # Metrics
    metric = evaluate.load("wer")
    # Initialise a pretrained model.
    # We need use_cache=False here because caching is incompatible with gradient checkpointing.
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small", use_cache=False)
    # Override generation arguments: no tokens are forced as decoder outputs (see
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)
    # and no tokens are suppressed during generation (see
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens):
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
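
    # For reference (hypothetical, not used in this script): the ids that would
    # otherwise be forced can be inspected with
    # processor.get_decoder_prompt_ids(language="norwegian", task="transcribe").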
    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="../whisper-testrun1",  # change to a repo name of your choice
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
        learning_rate=2e-5,
        warmup_steps=500,
        max_steps=5000,  # changed from 4000
        gradient_checkpointing=True,
        group_by_length=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=500,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=True,
    )
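
    # Note: push_to_hub=True assumes you are authenticated with the Hugging Face
    # Hub (e.g. via `huggingface-cli login`); set it to False for purely local runs.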
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=speech_data["train"],
        eval_dataset=speech_data["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )
    # Start training
    trainer.train()


def _mp_fn(index):
    # For xla_spawn (TPUs)
    print("XLA process initiated")
    main()


if __name__ == "__main__":
    main()
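
# Usage sketch (assuming the file is saved as run_whisper.py):
#   Single GPU / CPU:  python run_whisper.py
#   TPU via torch_xla: python xla_spawn.py --num_cores 8 run_whisper.py
# where xla_spawn.py is the launcher from the transformers examples, which
# calls _mp_fn on each TPU core.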