jasper9w committed on
Commit
5fbd01c
1 Parent(s): 7a97ad4

add train.py

Browse files
Files changed (1) hide show
  1. train.py +181 -0
train.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import soundfile as sf
3
+ import torch
4
+ import evaluate
5
+ import numpy as np
6
+
7
+ from dataclasses import dataclass
8
+ from datasets import Dataset, DatasetDict
9
+ import evaluate
10
+ from transformers import (
11
+ WhisperFeatureExtractor,
12
+ WhisperTokenizer,
13
+ WhisperProcessor,
14
+ WhisperForConditionalGeneration,
15
+ Seq2SeqTrainer,
16
+ Seq2SeqTrainingArguments
17
+ )
18
+ from typing import Any, Dict, List, Union
19
+
20
# Load the feature extractor
def load_feature_extractor(model_name="openai/whisper-tiny"):
    """Return the pretrained Whisper feature extractor for a checkpoint.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained``.
            Defaults to the checkpoint this script has always used, so
            existing zero-argument calls behave as before.

    Returns:
        The ``WhisperFeatureExtractor`` loaded from ``model_name``.
    """
    return WhisperFeatureExtractor.from_pretrained(model_name)
23
+
24
def load_dataset(directory, train_ratio, seed=None):
    """Build a shuffled train/test ``DatasetDict`` from ``.wav`` files.

    Every ``.wav`` file directly inside ``directory`` becomes one example
    with two columns: ``audio`` (a dict holding the file ``path``, the
    decoded ``array`` and the ``sampling_rate`` returned by
    ``soundfile.read``) and ``sentence`` (the part of the file name before
    the first underscore, with the extension removed).

    Args:
        directory: Folder to scan (non-recursively) for ``.wav`` files.
        train_ratio: Fraction of examples placed in the train split; must
            lie in ``[0, 1]``.
        seed: Optional seed for the shuffle. When ``None`` the global NumPy
            random state is used, exactly as before.

    Returns:
        A ``DatasetDict`` with ``'train'`` and ``'test'`` splits.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")

    def load_audio_data(audio_dir):
        """Read every .wav file in ``audio_dir`` into column-oriented lists."""
        data_dict = {'audio': [], 'sentence': []}
        # os.listdir order is arbitrary; sorting keeps runs reproducible.
        for filename in sorted(os.listdir(audio_dir)):
            if not filename.endswith('.wav'):
                continue
            path = os.path.join(audio_dir, filename)
            data, samplerate = sf.read(path)
            data_dict['audio'].append(
                {'path': path, 'array': data, 'sampling_rate': samplerate}
            )
            # Drop the extension first so a name without an underscore does
            # not leak ".wav" into the transcription label.
            stem = os.path.splitext(filename)[0]
            data_dict['sentence'].append(stem.split('_')[0])
        return data_dict

    def split_dataset(data_dict, ratio):
        """Shuffle the example indices and cut them into train/test parts."""
        total_size = len(data_dict['audio'])
        train_size = int(total_size * ratio)
        indices = np.arange(total_size)
        # A private generator is used only when a seed is requested, so
        # callers relying on np.random.seed() keep their behaviour.
        rng = np.random if seed is None else np.random.RandomState(seed)
        rng.shuffle(indices)
        train_indices = indices[:train_size].tolist()
        test_indices = indices[train_size:].tolist()

        train_dict = {key: [value[i] for i in train_indices] for key, value in data_dict.items()}
        test_dict = {key: [value[i] for i in test_indices] for key, value in data_dict.items()}
        return Dataset.from_dict(train_dict), Dataset.from_dict(test_dict)

    train_dataset, test_dataset = split_dataset(load_audio_data(directory), train_ratio)

    return DatasetDict({
        'train': train_dataset,
        'test': test_dataset,
    })
56
+
57
# Load the tokenizer
def load_tokenizer(model_name="openai/whisper-tiny", language="zh", task="transcribe"):
    """Return the pretrained Whisper tokenizer for a checkpoint.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained``.
        language: Language setting forwarded to the tokenizer.
        task: Task setting forwarded to the tokenizer.

    All defaults reproduce the values this script has always used, so
    existing zero-argument calls behave as before.

    Returns:
        The ``WhisperTokenizer`` loaded with the given settings.
    """
    return WhisperTokenizer.from_pretrained(model_name, language=language, task=task)
60
+
61
# Prepare one dataset example
def prepare_dataset(batch):
    """Add ``input_features`` and ``labels`` to one example and return it.

    Uses the module-level ``feature_extractor`` and ``tokenizer`` objects,
    which the script's entry point assigns before the dataset is mapped.

    Args:
        batch: Example with an ``audio`` dict (``array``, ``sampling_rate``)
            and a ``sentence`` string.

    Returns:
        The same ``batch`` object with the two new keys set.
    """
    audio = batch["audio"]
    waveform, rate = audio["array"], audio["sampling_rate"]

    # The extractor returns a batch of one; keep only its single entry.
    extracted = feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids
    return batch
67
+
68
# Data collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad audio features and label ids separately into one tensor batch.

    Calling the collator with a list of examples (each holding
    ``input_features`` and ``labels``) returns the padded feature batch
    with a ``labels`` tensor in which padding positions are set to -100.
    """

    # Processor providing ``feature_extractor.pad`` and ``tokenizer.pad``.
    processor: Any
    # Token id to strip when every label sequence starts with it. When left
    # as None it is looked up from the tokenizer as "<|startoftranscript|>".
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the audio features.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the labels and mask the padding so it can be excluded from the loss.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The previous check compared against tokenizer.bos_token_id, which
        # for Whisper is not the token the label sequences begin with, so
        # the leading start-of-transcript token was never removed and ended
        # up duplicated once the model prepends its decoder start token.
        # NOTE(review): relies on the Whisper tokenizer/model convention
        # described in the Hugging Face fine-tuning guide - confirm.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == start_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
87
+
88
+
89
# Character error rate metric, loaded once and reused by every evaluation.
metric = evaluate.load("cer")


# Compute evaluation metrics
def compute_metrics(pred):
    """Decode predictions and references and return the CER as a percentage.

    Uses the module-level ``tokenizer`` and ``metric`` objects.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference ids with -100 at ignored positions).

    Returns:
        ``{"cer": value}`` where ``value`` is 100 times the metric result.
    """
    pred_ids = pred.predictions
    # -100 is not a decodable id, so it is replaced by the pad token.
    # np.where builds a new array, leaving the caller's label_ids untouched
    # (the previous in-place assignment mutated the prediction object).
    label_ids = np.where(pred.label_ids == -100, tokenizer.pad_token_id, pred.label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    print("pred_str", pred_str)
    print("label_str", label_str)
    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
103
+
104
+
105
# Train the model
def train_model(train_dataset, eval_dataset, model, processor, output_dir, push_to_hub=True):
    """Fine-tune ``model`` with ``Seq2SeqTrainer`` and return the trainer.

    Args:
        train_dataset: Dataset used for training.
        eval_dataset: Dataset evaluated every ``eval_steps`` steps.
        model: Model to fine-tune.
        processor: Processor whose files are saved to ``output_dir`` and
            whose feature extractor is handed to the trainer.
        output_dir: Directory for checkpoints and the saved processor.
        push_to_hub: Forwarded to the training arguments; defaults to the
            previously hard-coded ``True``.

    Returns:
        The ``Seq2SeqTrainer`` after ``train()`` has completed.
    """
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        warmup_steps=5,
        max_steps=50,
        gradient_checkpointing=True,
        # Mixed precision was requested unconditionally, which fails on
        # hosts without a CUDA device; enable it only when one is present.
        fp16=torch.cuda.is_available(),
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=10,
        eval_steps=10,
        logging_steps=5,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="cer",
        # CER is an error rate, so lower values are better.
        greater_is_better=False,
        push_to_hub=push_to_hub,
    )

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=processor),
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )

    # Store the processor files in the output directory before training.
    processor.save_pretrained(training_args.output_dir)
    trainer.train()
    return trainer
142
+
143
+
144
def load_my_dataset_with_cache(cache_file='dataset_cache.pkl', data_dir='audios', train_ratio=0.8):
    """Return the prepared dataset, reusing a pickle cache when one exists.

    If ``cache_file`` exists it is unpickled and returned as-is; in that
    case ``data_dir`` and ``train_ratio`` are not consulted. Otherwise the
    dataset is built with ``load_dataset``, mapped through
    ``prepare_dataset``, written to ``cache_file`` and returned.

    Args:
        cache_file: Path of the pickle cache.
        data_dir: Directory of ``.wav`` files used when there is no cache.
        train_ratio: Train fraction used when there is no cache.

    Returns:
        The cached or freshly prepared dataset object.
    """
    import pickle

    if os.path.exists(cache_file):
        # Cache hit: load the dataset straight from the cache file.
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # use cache files that this script produced itself.
        print(f"WARN: load dataset from cache: {cache_file}")
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    # Cache miss: build and preprocess the dataset, then store it.
    dataset = load_dataset(data_dir, train_ratio)
    dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])

    with open(cache_file, 'wb') as f:
        pickle.dump(dataset, f)

    return dataset
165
+
166
# Main program
if __name__ == "__main__":
    # Model, with the decoder constraints from its config cleared.
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # Processor used by the data collator and saved with the checkpoints.
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="zh", task="transcribe")

    # prepare_dataset() and compute_metrics() read these two module globals,
    # so they must be assigned before the dataset is loaded.
    feature_extractor = load_feature_extractor()
    tokenizer = load_tokenizer()

    # Load the (possibly cached) dataset splits.
    splits = load_my_dataset_with_cache()

    # Train the model.
    train_model(
        splits["train"],
        splits["test"],
        model,
        processor,
        "./whisper-tiny-zh4",
    )