File size: 6,303 Bytes
5fbd01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import soundfile as sf
import torch
import evaluate
import numpy as np

from dataclasses import dataclass
from datasets import Dataset, DatasetDict
import evaluate
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)
from typing import Any, Dict, List, Union

# Load the feature extractor.
def load_feature_extractor():
    """Return the pretrained ``WhisperFeatureExtractor`` for ``openai/whisper-tiny``."""
    checkpoint = "openai/whisper-tiny"
    extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)
    return extractor

def load_dataset(directory, train_ratio, seed=None):
    """Load ``.wav`` clips from *directory* and split them into train/test sets.

    Each clip's transcript is taken from its filename: the extension is
    removed and the part before the first ``_`` is kept, e.g.
    ``hello_001.wav`` -> ``hello`` and ``hello.wav`` -> ``hello``.

    Args:
        directory: Folder scanned (non-recursively) for files ending in ``.wav``.
        train_ratio: Fraction of clips placed in the ``train`` split; the train
            size is ``int(total * train_ratio)`` and the rest go to ``test``.
        seed: Optional seed for the shuffle. ``None`` (the default) keeps the
            original behaviour of shuffling with NumPy's global RNG.

    Returns:
        A ``DatasetDict`` with ``train`` and ``test`` splits, each holding an
        ``audio`` column (dicts with ``path``/``array``/``sampling_rate``) and a
        ``sentence`` column.
    """

    def load_audio_data(audio_dir):
        """Read every ``.wav`` file in *audio_dir* into column lists."""
        data_dict = {'audio': [], 'sentence': []}
        # Sort so the split depends only on the seed, not on the arbitrary
        # order in which os.listdir returns directory entries.
        for filename in sorted(os.listdir(audio_dir)):
            if not filename.endswith('.wav'):
                continue
            path = os.path.join(audio_dir, filename)
            data, samplerate = sf.read(path)
            # sf.read returns a (frames, channels) array for multi-channel
            # files; downmix to mono so downstream code gets a 1-D waveform.
            if data.ndim > 1:
                data = data.mean(axis=1)
            # NOTE(review): no resampling happens here; confirm the clips are
            # already at the rate the feature extractor expects.
            audio_dict = {'path': path, 'array': data, 'sampling_rate': samplerate}
            data_dict['audio'].append(audio_dict)
            # Drop the extension before splitting so a name without "_" does
            # not keep ".wav" in its transcript.
            stem = os.path.splitext(filename)[0]
            data_dict['sentence'].append(stem.split('_')[0])
        return data_dict

    def split_dataset(data_dict, train_ratio):
        """Randomly partition the column lists into train/test ``Dataset``s."""
        total_size = len(data_dict['audio'])
        train_size = int(total_size * train_ratio)
        if seed is None:
            indices = np.arange(total_size)
            np.random.shuffle(indices)
        else:
            indices = np.random.default_rng(seed).permutation(total_size)
        train_indices, test_indices = indices[:train_size], indices[train_size:]

        train_dict = {key: [value[i] for i in train_indices] for key, value in data_dict.items()}
        test_dict = {key: [value[i] for i in test_indices] for key, value in data_dict.items()}

        return Dataset.from_dict(train_dict), Dataset.from_dict(test_dict)

    data_dict = load_audio_data(directory)
    train_dataset, test_dataset = split_dataset(data_dict, train_ratio)

    return DatasetDict({
        'train': train_dataset,
        'test': test_dataset
    })

# Load the tokenizer.
def load_tokenizer():
    """Return the ``openai/whisper-tiny`` tokenizer set up for Chinese (``zh``) transcription."""
    options = {"language": "zh", "task": "transcribe"}
    return WhisperTokenizer.from_pretrained("openai/whisper-tiny", **options)

# Prepare the dataset.
def prepare_dataset(batch):
    """Turn one raw example into model inputs (used with ``Dataset.map``).

    Adds ``input_features`` (computed from ``batch["audio"]``) and ``labels``
    (token ids of ``batch["sentence"]``) to *batch* and returns it.

    Reads the module-level ``feature_extractor`` and ``tokenizer`` globals,
    which are assigned in the ``__main__`` section of this script.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids
    return batch

# Data collator: pads a list of examples into a batch.
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad prepared examples into a PyTorch batch for seq2seq training.

    ``input_features`` are padded with ``processor.feature_extractor.pad`` and
    ``labels`` with ``processor.tokenizer.pad``. Padded label positions are
    replaced by -100 (the index PyTorch's cross-entropy ignores by default).
    """

    processor: Any  # must expose .feature_extractor and .tokenizer (e.g. a WhisperProcessor)

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        extractor = self.processor.feature_extractor
        tok = self.processor.tokenizer

        # Audio side: pad the input features into one tensor batch.
        audio_items = [{"input_features": item["input_features"]} for item in features]
        batch = extractor.pad(audio_items, return_tensors="pt")

        # Text side: pad the label ids, then mask the padding with -100.
        text_items = [{"input_ids": item["labels"]} for item in features]
        padded = tok.pad(text_items, return_tensors="pt")
        is_padding = padded["attention_mask"].ne(1)
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading BOS column when every label sequence starts with it.
        first_is_bos = labels[:, 0] == tok.bos_token_id
        if first_is_bos.all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


# Character error rate (CER) metric, loaded once at import time.
metric = evaluate.load("cer")


# Compute evaluation metrics.
def compute_metrics(pred):
    """Decode generated and reference token ids and return CER as a percentage.

    Args:
        pred: Evaluation output from the trainer; ``pred.predictions`` holds the
            generated token ids (``predict_with_generate`` is enabled) and
            ``pred.label_ids`` the reference ids, padded with -100.

    Returns:
        ``{"cer": 100 * character_error_rate}``.

    Reads the module-level ``tokenizer`` global assigned in ``__main__``.
    The input arrays are not modified.
    """
    pred_ids = pred.predictions
    # Some model outputs arrive as a tuple whose first element is the ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    pad_id = tokenizer.pad_token_id
    # -100 marks padding (the collator masks label padding with it, and the
    # trainer may also use it when concatenating eval batches of different
    # lengths). It is not a real token id, so map it to the pad id before
    # decoding. np.where builds new arrays instead of overwriting the caller's.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    print("pred_str", pred_str)
    print("label_str", label_str)
    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}


# Train the model.
def train_model(train_dataset, eval_dataset, model, processor, output_dir):
    """Fine-tune *model* with ``Seq2SeqTrainer`` and save the best model.

    Runs 50 optimizer steps, evaluating and checkpointing every 10 steps, and
    keeps the checkpoint with the lowest CER (``load_best_model_at_end``).

    Args:
        train_dataset: Prepared training split (``input_features``/``labels``).
        eval_dataset: Prepared evaluation split with the same columns.
        model: The seq2seq model to fine-tune.
        processor: Processor used by the data collator; it is also saved to
            *output_dir*.
        output_dir: Directory for checkpoints, the processor, and the final model.

    Returns:
        The ``Seq2SeqTrainer`` after training (previously ``None``), so callers
        can run further evaluation or saving.
    """
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        warmup_steps=5,
        max_steps=50,
        gradient_checkpointing=True,
        # fp16 mixed precision needs a CUDA device; fall back to fp32 elsewhere
        # instead of failing during argument validation.
        fp16=torch.cuda.is_available(),
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; confirm against the installed version.
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=10,
        eval_steps=10,
        logging_steps=5,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="cer",
        greater_is_better=False,
        push_to_hub=True
    )

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=processor),
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor
    )

    processor.save_pretrained(training_args.output_dir)
    trainer.train()
    # load_best_model_at_end only reloads the best checkpoint into memory;
    # persist it to output_dir so it survives after this function returns.
    trainer.save_model()
    return trainer

    
def load_my_dataset_with_cache(directory='audios', train_ratio=0.8, cache_file='dataset_cache.pkl'):
    """Return the prepared train/test ``DatasetDict``, using a pickle cache.

    On a cache hit, the object pickled in *cache_file* is returned as-is. On a
    miss, the dataset is loaded from *directory* with ``load_dataset``, mapped
    through ``prepare_dataset`` (original columns removed), pickled to
    *cache_file*, and returned.

    Args:
        directory: Folder of ``.wav`` clips used on a cache miss.
        train_ratio: Train fraction passed to ``load_dataset`` on a cache miss.
        cache_file: Path of the pickle cache.

    NOTE: the cache is not keyed on *directory* or *train_ratio*; delete
    *cache_file* after changing the audio or the split.

    Security: ``pickle.load`` can execute code embedded in the file. Only point
    *cache_file* at a file this script wrote itself.
    """
    import pickle

    if os.path.exists(cache_file):
        # The dataset exists in the cache: load it directly.
        print(f"WARN: load dataset from cache: {cache_file}")
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    # Otherwise build and preprocess the dataset, then cache it.
    dataset = load_dataset(directory, train_ratio)
    dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])

    # Write to a temporary file and rename it into place so an interrupted run
    # cannot leave a truncated cache that later fails to unpickle.
    tmp_file = f"{cache_file}.tmp"
    with open(tmp_file, 'wb') as f:
        pickle.dump(dataset, f)
    os.replace(tmp_file, cache_file)

    return dataset

# Main program.
if __name__ == "__main__":
    # Load the model and preprocessing tools. `feature_extractor` and
    # `tokenizer` must stay module-level names: prepare_dataset and
    # compute_metrics read them as globals.
    feature_extractor = load_feature_extractor()
    tokenizer = load_tokenizer()
    whisper_processor = WhisperProcessor.from_pretrained(
        "openai/whisper-tiny", language="zh", task="transcribe"
    )
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # Clear the forced decoder prompt and the suppressed-token list in the config.
    # NOTE(review): with forced_decoder_ids cleared, generation during eval may
    # not be pinned to zh/transcribe; confirm the predictions are in Chinese.
    model_config = whisper_model.config
    model_config.forced_decoder_ids = None
    model_config.suppress_tokens = []

    # Load the prepared splits (from cache when available).
    splits = load_my_dataset_with_cache()

    # Fine-tune; checkpoints and the processor are written to ./whisper-tiny-zh4.
    train_model(
        splits["train"],
        splits["test"],
        whisper_model,
        whisper_processor,
        "./whisper-tiny-zh4",
    )