import os
import soundfile as sf
import torch
import evaluate
import numpy as np
from dataclasses import dataclass
from datasets import Dataset, DatasetDict
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)
from typing import Any, Dict, List, Union
# Load the feature extractor
def load_feature_extractor():
    return WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
def load_dataset(directory, train_ratio):
    def load_audio_data(dir):
        data_dict = {'audio': [], 'sentence': []}
        for filename in os.listdir(dir):
            if filename.endswith('.wav'):
                path = os.path.join(dir, filename)
                data, samplerate = sf.read(path)
                audio_dict = {'path': path, 'array': data, 'sampling_rate': samplerate}
                data_dict['audio'].append(audio_dict)
                sentence = filename.split('_')[0]  # take the part of the filename before the first underscore as the transcript
                data_dict['sentence'].append(sentence)
        return data_dict

    def split_dataset(data_dict, train_ratio):
        total_size = len(data_dict['audio'])
        train_size = int(total_size * train_ratio)
        indices = np.arange(total_size)
        np.random.shuffle(indices)
        train_indices, test_indices = indices[:train_size], indices[train_size:]
        train_dict = {key: [value[i] for i in train_indices] for key, value in data_dict.items()}
        test_dict = {key: [value[i] for i in test_indices] for key, value in data_dict.items()}
        return Dataset.from_dict(train_dict), Dataset.from_dict(test_dict)

    data_dict = load_audio_data(directory)
    train_dataset, test_dataset = split_dataset(data_dict, train_ratio)
    return DatasetDict({
        'train': train_dataset,
        'test': test_dataset
    })
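
# NOTE: the loader above expects each transcript to be encoded in the
# filename before the first underscore, e.g. a hypothetical file
# "audios/今天天气很好_001.wav" yields the sentence "今天天气很好".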
# Load the tokenizer
def load_tokenizer():
    return WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="zh", task="transcribe")
# Prepare the dataset: compute log-Mel input features and tokenized labels
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch
# Data collator: pads input features and labels into a uniform batch
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Collate the input features
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # Collate the labels, masking padding with -100 so it is ignored by the loss
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # Drop the leading BOS token if the tokenizer already prepended it
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
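
# A minimal, hedged smoke test for the collator (assumes `processor` is
# loaded as in the main block; the silent audio array is a made-up
# placeholder, not real speech):
#   collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
#   fake = prepare_dataset({"audio": {"array": np.zeros(16000), "sampling_rate": 16000},
#                           "sentence": "你好"})
#   batch = collator([fake, fake])
#   print(batch["input_features"].shape, batch["labels"].shape)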
metric = evaluate.load("cer")

# Compute metrics (character error rate, CER)
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # Replace -100 with the pad token id so the labels can be decoded
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    print("pred_str", pred_str)
    print("label_str", label_str)
    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
# Train the model
def train_model(train_dataset, eval_dataset, model, processor, output_dir):
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        warmup_steps=5,
        max_steps=50,
        gradient_checkpointing=True,
        fp16=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=10,
        eval_steps=10,
        logging_steps=5,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="cer",
        greater_is_better=False,
        push_to_hub=True
    )
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=processor),
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor
    )
    processor.save_pretrained(training_args.output_dir)
    trainer.train()
def load_my_dataset_with_cache():
    import os
    import pickle
    cache_file = 'dataset_cache.pkl'
    if os.path.exists(cache_file):
        # If the cache file exists, load the dataset directly from it
        print(f"WARN: loading dataset from cache: {cache_file}")
        with open(cache_file, 'rb') as f:
            dataset = pickle.load(f)
        return dataset
    else:
        # Otherwise, load and preprocess the dataset, then save it to the cache file
        dataset = load_dataset('audios', 0.8)  # the local loader defined above, not datasets.load_dataset
        dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])
        with open(cache_file, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
# Main program
if __name__ == "__main__":
    # Load the model and tooling
    feature_extractor = load_feature_extractor()
    tokenizer = load_tokenizer()
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="zh", task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    # Load the dataset
    dataset = load_my_dataset_with_cache()
    # Train the model
    train_model(dataset["train"], dataset["test"], model, processor, "./whisper-tiny-zh4")
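
    # A minimal post-training inference sketch (an assumption, not part of the
    # original script): reuse the in-memory model on one held-out example.
    sample = dataset["test"][0]
    input_features = torch.tensor(sample["input_features"]).unsqueeze(0).to(model.device)
    predicted_ids = model.generate(input_features, max_new_tokens=225)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    print("sample transcription:", transcription)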