whisper-ru
Collection: STT (speech-to-text) models • 2 items • Updated
def load_asr_pipeline(model_id="wyluilipe/whisper-ru-v2.0", device=None):
    """Build an automatic-speech-recognition pipeline for the fine-tuned model.

    Args:
        model_id: Hugging Face Hub id (or local path) of the checkpoint.
        device: torch device string such as "cuda:0" or "cpu". Defaults to
            the first CUDA GPU when one is available, otherwise the CPU.

    Returns:
        A ``transformers`` pipeline for task "automatic-speech-recognition".
    """
    # Imported here so that merely importing this file does not require
    # transformers or trigger any model download.
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

    if device is None:
        # Use the first GPU: a fixed higher index (e.g. "cuda:2") fails on
        # machines that have fewer GPUs.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Half precision only when the model will actually run on a CUDA device.
    torch_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    # Return the pipeline so callers can keep and reuse it.
    return pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


def transcribe(pipe, filename, lang="russian"):
    """Transcribe one audio file.

    Args:
        pipe: a pipeline as returned by ``load_asr_pipeline``.
        filename: path to the audio file.
        lang: language passed to generation via ``generate_kwargs``.

    Returns:
        Whatever the pipeline returns for this file (for ASR pipelines,
        presumably a dict with a "text" key — confirm against transformers).
    """
    generate_kwargs = {"language": lang}
    return pipe(filename, generate_kwargs=generate_kwargs)
from datasets import load_from_disk
from huggingface_hub import login
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers import WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
# Authenticate with the Hugging Face Hub (the training arguments below use
# push_to_hub=True). The token is read from the HF_TOKEN environment variable
# instead of being pasted into source code; when it is unset, None is passed
# so login() asks for the token interactively rather than submitting an
# empty string.
import os

HF_TOKEN = os.environ.get("HF_TOKEN") or None
login(HF_TOKEN)
def prepare_dataset(batch):
    """Turn one raw example into Whisper model inputs.

    Uses the module-level ``feature_extractor`` to compute ``input_features``
    from ``batch["audio"]`` (its "array" and "sampling_rate" entries), and the
    module-level ``tokenizer`` to encode ``batch["transcription"]`` into
    ``labels``. The example is updated in place and returned, as expected by
    ``Dataset.map``.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["transcription"])
    batch["labels"] = encoded.input_ids
    return batch
def filter_bad_transcriptions(batch):
    """Return True only for examples whose transcription has visible text.

    Examples whose "transcription" is None, empty, or whitespace-only are
    rejected (for use with ``Dataset.filter``).
    """
    text = batch["transcription"]
    if text is None:
        return False
    return len(text.strip()) > 0
# Checkpoint being fine-tuned; feature extractor, tokenizer, processor and
# model must all come from the same one.
BASE_MODEL = "openai/whisper-large-v3-turbo"

device = torch.device("cpu")

train_dataset = load_from_disk('train_dataset')
test_dataset = load_from_disk('test_dataset')

feature_extractor = WhisperFeatureExtractor.from_pretrained(BASE_MODEL)
tokenizer = WhisperTokenizer.from_pretrained(BASE_MODEL, language="Russian", task="transcribe")
processor = WhisperProcessor.from_pretrained(BASE_MODEL, language="Russian", task="transcribe")

# Drop missing/blank transcriptions from BOTH splits before preprocessing:
# prepare_dataset hands the text straight to the tokenizer, so a None or
# empty transcription in the training split would otherwise reach it too.
train_dataset = train_dataset.filter(filter_bad_transcriptions)
train_dataset = train_dataset.map(prepare_dataset, remove_columns=train_dataset.column_names)
test_dataset = test_dataset.filter(filter_bad_transcriptions)
test_dataset = test_dataset.map(prepare_dataset, remove_columns=test_dataset.column_names)

model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL).to(device)
# Generate Russian transcriptions. forced_decoder_ids is cleared so that
# generation is steered by language/task above (presumably avoids a conflict
# with legacy forced ids — confirm against the transformers version in use).
model.generation_config.language = "russian"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of preprocessed Whisper examples into one batch.

    Attributes are documented as comments below. The collator returns CPU
    tensors and does not depend on any module-level device setting; moving
    the batch to the training device is left to the Trainer.
    """

    # Object exposing ``feature_extractor`` and ``tokenizer`` (e.g. a
    # WhisperProcessor); both are used only for padding.
    processor: Any
    # Decoder start token id; stripped from labels when every sequence begins with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``input_features`` and ``labels`` of ``features`` separately.

        Returns:
            A dict with the padded audio features from the feature extractor
            and a ``labels`` tensor whose padding positions are set to -100.
        """
        # Audio features and token labels have different lengths, so each is
        # padded by its own component.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # -100 marks padded label positions (PyTorch's default ignore_index for
        # cross-entropy), so padding does not contribute to the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If the tokenizer already prepended the decoder start token to every
        # sequence, drop it; presumably the model re-adds it when building
        # decoder inputs from the labels — confirm for the model in use.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return dict(batch)
# One collator shared by training and evaluation; the decoder start token id
# is read from the config of the model loaded above.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)
def wer(predictions, references):
    """Compute corpus-level word error rate.

    Every prediction/reference pair is split on whitespace and their
    word-level Levenshtein distance (insertion, deletion and substitution
    each cost 1) is summed over all pairs, then divided by the total number
    of reference words. Pairing uses ``zip``, so surplus items in the longer
    sequence are ignored.

    Returns:
        The error ratio, or 0 when the references contain no words at all.
    """

    def edit_distance(hyp_tokens, ref_tokens):
        # Rolling-row form of the classic DP table: `previous` holds row i-1,
        # `current` is built for row i.
        previous = list(range(len(ref_tokens) + 1))
        for i, hyp_word in enumerate(hyp_tokens, start=1):
            current = [i]
            for j, ref_word in enumerate(ref_tokens, start=1):
                if hyp_word == ref_word:
                    current.append(previous[j - 1])
                else:
                    current.append(1 + min(previous[j], current[j - 1], previous[j - 1]))
            previous = current
        return previous[-1]

    error_count = 0
    reference_word_count = 0
    for hypothesis, reference in zip(predictions, references):
        hyp_tokens = hypothesis.split()
        ref_tokens = reference.split()
        reference_word_count += len(ref_tokens)
        error_count += edit_distance(hyp_tokens, ref_tokens)

    if reference_word_count == 0:
        return 0
    return error_count / reference_word_count
def compute_metrics(pred):
    """Compute WER (in percent) for a Seq2SeqTrainer evaluation step.

    Args:
        pred: evaluation output exposing ``predictions`` (generated token ids,
            since ``predict_with_generate=True`` is used) and ``label_ids``.

    Returns:
        ``{"wer": value}`` where value is 100 times the corpus WER computed
        by ``wer`` on the decoded strings.
    """
    import numpy as np

    pad_id = tokenizer.pad_token_id
    # -100 marks ignored label positions (set by the data collator). The
    # Trainer may also use -100 to pad generated sequences of different
    # lengths when concatenating batches. Neither can be decoded, so both
    # arrays are mapped to the pad token. np.where returns copies, so
    # ``pred`` is left unmodified.
    pred_ids = np.asarray(pred.predictions)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer_value = 100 * wer(pred_str, label_str)
    return {"wer": wer_value}
# Fine-tuning configuration: 400 steps, evaluated and checkpointed every 100
# steps; the checkpoint with the lowest WER is kept at the end.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-ru",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=80,
    max_steps=400,
    gradient_checkpointing=False,
    # Mixed precision only when a CUDA GPU is present; on CPU it stays off.
    fp16=torch.cuda.is_available(),
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,  # evaluate WER on generated text, not logits
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,  # must divide save_steps for load_best_model_at_end
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=True,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Feature extractor passed as the "tokenizer" — presumably so it is saved
    # alongside checkpoints; confirm for the transformers version in use.
    tokenizer=processor.feature_extractor,
)

# Save the processor next to the checkpoints before training starts.
processor.save_pretrained(training_args.output_dir)
trainer.train()