yqzhangjx committed on
Commit
674fbd3
1 Parent(s): e9a78de

Upload peft_lora_whisper-large-v2.ipynb

Files changed (1)
  1. peft_lora_whisper-large-v2.ipynb +1109 -0
peft_lora_whisper-large-v2.ipynb ADDED
@@ -0,0 +1,1109 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "c219841f-493c-40f9-a6c9-3700f0c525d0",
6
+ "metadata": {},
7
+ "source": [
8
+ "# PEFT 库 LoRA 实战 - OpenAI Whisper-large-v2\n",
9
+ "\n",
10
+ "本教程使用 LoRA 在`OpenAI Whisper-large-v2`模型上实现`语音识别(ASR)`任务的微调训练。\n",
11
+ "\n",
12
+ "我们还结合了`int8` 量化进一步降低训练过程资源开销,同时保证了精度几乎不受影响。"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "6d0a1e23-ea71-45d6-82d6-453077cf2d29",
18
+ "metadata": {},
19
+ "source": [
20
+ "## 全局参数设置"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 1,
26
+ "id": "19d11aa3-9a73-4ce9-b6c5-a65a2fcb07c3",
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "1"
33
+ ]
34
+ },
35
+ "execution_count": 1,
36
+ "metadata": {},
37
+ "output_type": "execute_result"
38
+ }
39
+ ],
40
+ "source": [
41
+ "import os\n",
42
+ "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" # see issue #152\n",
43
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\"\n",
44
+ "\n",
45
+ "import torch\n",
46
+ "torch.cuda.device_count()"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 2,
52
+ "id": "ccd00402-d821-485e-8703-fb16bcb56a9e",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "model_name_or_path = \"openai/whisper-large-v2\"\n",
57
+ "language = \"Chinese (China)\"\n",
58
+ "language_abbr = \"zh-CN\"\n",
59
+ "task = \"transcribe\"\n",
60
+ "dataset_name = \"mozilla-foundation/common_voice_11_0\"\n",
61
+ "\n",
62
+ "batch_size=64"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "id": "cfffa1df-e51e-4026-9817-1cebddf0061a",
68
+ "metadata": {},
69
+ "source": [
70
+ "## 下载数据集 Common Voice\n",
71
+ "\n",
72
+ "Common Voice 11.0 数据集包含许多不同语言的录音,总时长达数小时。\n",
73
+ "\n",
74
+ "本教程以中文数据为例,展示如何使用 LoRA 在 Whisper-large-v2 上进行微调训练。\n",
75
+ "\n",
76
+ "首先,初始化一个DatasetDict结构,并将训练集(将训练+验证拆分为训练集)和测试集拆分好,按照中文数据集构建配置加载到内存中:"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "id": "21ff42f4-f3ec-46d3-b0c0-dd9ffbf7b50b",
83
+ "metadata": {
84
+ "scrolled": true
85
+ },
86
+ "outputs": [
87
+ {
88
+ "data": {
89
+ "text/plain": [
90
+ "{'client_id': '95368aab163e0387e4fd4991b4f2d8ccfbd4364bf656c860230501fd27dcedf087773e4695a6cf5de9c4f1d406d582283190d065cdfa36b0e2b060cffaca977e',\n",
91
+ " 'path': '/store/jxzhang/.cache/huggingface/datasets/downloads/extracted/edf8cf7fef3457433a3a59929c4c4809972172377467a8f189ac185f3d5e4b53/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
92
+ " 'audio': {'path': '/store/jxzhang/.cache/huggingface/datasets/downloads/extracted/edf8cf7fef3457433a3a59929c4c4809972172377467a8f189ac185f3d5e4b53/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
93
+ " 'array': array([-6.82121026e-13, -2.27373675e-12, -2.27373675e-12, ...,\n",
94
+ " 1.21667399e-05, 3.23003678e-06, -2.43066324e-07]),\n",
95
+ " 'sampling_rate': 48000},\n",
96
+ " 'sentence': '性喜温暖润湿气候且耐寒。',\n",
97
+ " 'up_votes': 2,\n",
98
+ " 'down_votes': 0,\n",
99
+ " 'age': '',\n",
100
+ " 'gender': '',\n",
101
+ " 'accent': '',\n",
102
+ " 'locale': 'zh-CN',\n",
103
+ " 'segment': ''}"
104
+ ]
105
+ },
106
+ "execution_count": 3,
107
+ "metadata": {},
108
+ "output_type": "execute_result"
109
+ }
110
+ ],
111
+ "source": [
112
+ "from datasets import load_dataset\n",
113
+ "from datasets import load_dataset, DatasetDict\n",
114
+ "\n",
115
+ "common_voice = DatasetDict()\n",
116
+ "\n",
117
+ "common_voice[\"train\"] = load_dataset(dataset_name, language_abbr, split=\"train+validation\", trust_remote_code=True)\n",
118
+ "common_voice[\"test\"] = load_dataset(dataset_name, language_abbr, split=\"test\", trust_remote_code=True)\n",
119
+ "common_voice[\"train\"][0]"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "markdown",
124
+ "id": "3c81faa4-d8fe-4cc7-afe6-4c2615b9050f",
125
+ "metadata": {},
126
+ "source": [
127
+ "## 预处理训练数据集\n"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 4,
133
+ "id": "5822025f-7f8e-4141-8bfe-d8822d0da20f",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor\n",
138
+ "\n",
139
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)\n",
140
+ "\n",
141
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
142
+ " model_name_or_path, language=language, task=task)\n",
143
+ "\n",
144
+ "processor = AutoProcessor.from_pretrained(\n",
145
+ " model_name_or_path, language=language, task=task)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "id": "f394e5cd-23b8-413e-8bde-88c3542b84fa",
151
+ "metadata": {},
152
+ "source": [
153
+ "#### 移除数据集中不必要的字段"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 5,
159
+ "id": "1690dc5a-c1f7-4556-9be3-d31ad888e52e",
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "common_voice = common_voice.remove_columns(\n",
164
+ " [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"]\n",
165
+ ")"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 6,
171
+ "id": "309aff16-ea26-4474-af54-7ef244783999",
172
+ "metadata": {},
173
+ "outputs": [
174
+ {
175
+ "data": {
176
+ "text/plain": [
177
+ "{'audio': {'path': '/store/jxzhang/.cache/huggingface/datasets/downloads/extracted/edf8cf7fef3457433a3a59929c4c4809972172377467a8f189ac185f3d5e4b53/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
178
+ " 'array': array([-6.82121026e-13, -2.27373675e-12, -2.27373675e-12, ...,\n",
179
+ " 1.21667399e-05, 3.23003678e-06, -2.43066324e-07]),\n",
180
+ " 'sampling_rate': 48000},\n",
181
+ " 'sentence': '性喜温暖润湿气候且耐寒。'}"
182
+ ]
183
+ },
184
+ "execution_count": 6,
185
+ "metadata": {},
186
+ "output_type": "execute_result"
187
+ }
188
+ ],
189
+ "source": [
190
+ "common_voice[\"train\"][0]"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "881546ab-72e4-4bcf-852f-a8be736164b7",
196
+ "metadata": {},
197
+ "source": [
198
+ "#### 降采样音频数据\n",
199
+ "\n",
200
+ "查看`common_voice` 数据集介绍,你会发现其音频是以48kHz的采样率进行采样的.\n",
201
+ "\n",
202
+ "而`Whisper`模型是在16kHZ的音频输入上预训练的,因此我们需要将音频输入降采样以匹配模型预训练时使用的采样率。\n",
203
+ "\n",
204
+ "通过在音频列上使用`cast_column`方法,并将`sampling_rate`设置为16kHz来对音频进行降采样。\n",
205
+ "\n",
206
+ "下次调用时,音频输入将实时重新取样:"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": 7,
212
+ "id": "5fc451cc-e21e-473c-a702-d7d6ed098f91",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "from datasets import Audio\n",
217
+ "\n",
218
+ "common_voice = common_voice.cast_column(\"audio\", Audio(sampling_rate=16000))"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 8,
224
+ "id": "cc3d7fcc-7c34-41c8-9857-5a6e883f6115",
225
+ "metadata": {},
226
+ "outputs": [
227
+ {
228
+ "data": {
229
+ "text/plain": [
230
+ "{'audio': {'path': '/store/jxzhang/.cache/huggingface/datasets/downloads/extracted/edf8cf7fef3457433a3a59929c4c4809972172377467a8f189ac185f3d5e4b53/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
231
+ " 'array': array([ 5.09317033e-11, -7.27595761e-12, -6.54836185e-11, ...,\n",
232
+ " -5.96661994e-06, 2.71382887e-05, 1.29687978e-05]),\n",
233
+ " 'sampling_rate': 16000},\n",
234
+ " 'sentence': '性喜温暖润湿气候且耐寒。'}"
235
+ ]
236
+ },
237
+ "execution_count": 8,
238
+ "metadata": {},
239
+ "output_type": "execute_result"
240
+ }
241
+ ],
242
+ "source": [
243
+ "common_voice[\"train\"][0]"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "id": "ee55908f-3ea3-4aee-8062-6f8d3a6573b9",
249
+ "metadata": {},
250
+ "source": [
251
+ "### 整合以上数据处理为一个函数\n",
252
+ "\n",
253
+ "该数据预处理函数应该包括:\n",
254
+ "- 通过加载音频列将音频输入重新采样为16kHZ。\n",
255
+ "- 使用特征提取器从音频数组计算输入特征。\n",
256
+ "- 将句子列标记化为输入标签。"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 9,
262
+ "id": "58f42c35-35ba-4d6b-9d15-095963cec67c",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "def prepare_dataset(batch):\n",
267
+ " audio = batch[\"audio\"]\n",
268
+ " batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
269
+ " batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
270
+ " return batch"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 10,
276
+ "id": "392f7856-a720-40a7-af7e-40e185fc315b",
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "common_voice = common_voice.map(\n",
281
+ " prepare_dataset, remove_columns=common_voice.column_names[\"train\"]\n",
282
+ ")"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "id": "84ec184e-d840-40b6-99af-d11392273442",
288
+ "metadata": {},
289
+ "source": [
290
+ "创建一个`DataCollator`类来将每个批次中的`attention_mask`填充到最大长度,并用`-100`替换填充值,以便在损失函数中被忽略。\n",
291
+ "\n",
292
+ "然后初始化数据收集器的实例:"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 11,
298
+ "id": "4c89ffcf-c805-48c2-b7d3-ae01b687178c",
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "import torch\n",
303
+ "\n",
304
+ "from dataclasses import dataclass\n",
305
+ "from typing import Any, Dict, List, Union\n",
306
+ "\n",
307
+ "\n",
308
+ "@dataclass\n",
309
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
310
+ " processor: Any\n",
311
+ "\n",
312
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
313
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
314
+ " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
315
+ "\n",
316
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
317
+ " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
318
+ "\n",
319
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
320
+ "\n",
321
+ " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
322
+ " labels = labels[:, 1:]\n",
323
+ "\n",
324
+ " batch[\"labels\"] = labels\n",
325
+ "\n",
326
+ " return batch\n",
327
+ "\n",
328
+ "\n",
329
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "id": "80ecd4bc-01fd-4286-afe5-fe2639ae15a1",
335
+ "metadata": {},
336
+ "source": [
337
+ "## 训练模型"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 12,
343
+ "id": "f9fcb121-fa5c-4c30-8bdc-9ab08ab75427",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "from transformers import AutoModelForSpeechSeq2Seq\n",
348
+ "\n",
349
+ "model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map=\"auto\")"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": 13,
355
+ "id": "2cb016f1-e6e9-4fd8-9c8b-72fd23be92d3",
356
+ "metadata": {},
357
+ "outputs": [],
358
+ "source": [
359
+ "model.config.forced_decoder_ids = None\n",
360
+ "model.config.suppress_tokens = []"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "markdown",
365
+ "id": "25ba1fa0-ea15-48d9-8c16-70df9f0b60b1",
366
+ "metadata": {},
367
+ "source": [
368
+ "为了准备模型进行int8量化,使用 `prepare_model_for_int8_training` 函数来处理模型:\n",
369
+ "- 将所有非int8模块转换为完全精度(fp32)以保持稳定性\n",
370
+ "- 在输入嵌入层上添加前向钩子,计算输入隐藏状态的梯度\n",
371
+ "- 启用渐变检查点以进行更高效的内存训练"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": 14,
377
+ "id": "1ee34359-fe1b-48f1-827c-6a8ec4a53af7",
378
+ "metadata": {},
379
+ "outputs": [
380
+ {
381
+ "name": "stderr",
382
+ "output_type": "stream",
383
+ "text": [
384
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/peft/utils/other.py:141: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n",
385
+ " warnings.warn(\n"
386
+ ]
387
+ }
388
+ ],
389
+ "source": [
390
+ "from peft import prepare_model_for_int8_training\n",
391
+ "\n",
392
+ "model = prepare_model_for_int8_training(model)"
393
+ ]
394
+ },
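+ {
+ "cell_type": "markdown",
+ "id": "peft-kbit-alt-note",
+ "metadata": {},
+ "source": [
+ "As the warning above notes, `prepare_model_for_int8_training` is deprecated. A minimal sketch of the replacement call, assuming a newer `peft` release that exposes `prepare_model_for_kbit_training`; run it instead of, not in addition to, the cell above:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "peft-kbit-alt-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Alternative to the cell above for newer peft versions (sketch): same preparation,\n",
+ "# non-deprecated name. Do not run both cells.\n",
+ "from peft import prepare_model_for_kbit_training\n",
+ "\n",
+ "model = prepare_model_for_kbit_training(model)"
+ ]
+ },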
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "id": "24b6f8a2-867f-4ed5-bad5-15ca9fd9547c",
399
+ "metadata": {},
400
+ "outputs": [],
401
+ "source": []
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": 15,
406
+ "id": "cdf6bc9c-6d2c-4dbf-b09e-a89cb1041c46",
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": [
410
+ "from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model\n",
411
+ "\n",
412
+ "config = LoraConfig(\n",
413
+ " r=8,\n",
414
+ " lora_alpha=64,\n",
415
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
416
+ " lora_dropout=0.05,\n",
417
+ " bias=\"none\")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": 16,
423
+ "id": "b74c7508-e6f4-42d8-8aaf-fe83c5977c35",
424
+ "metadata": {},
425
+ "outputs": [
426
+ {
427
+ "name": "stdout",
428
+ "output_type": "stream",
429
+ "text": [
430
+ "trainable params: 3,932,160 || all params: 1,547,237,120 || trainable%: 0.25414074863974306\n"
431
+ ]
432
+ }
433
+ ],
434
+ "source": [
435
+ "model = get_peft_model(model, config)\n",
436
+ "model.print_trainable_parameters()"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "markdown",
441
+ "id": "1cc6b26a-3e54-4a46-9b36-a048b40a37d7",
442
+ "metadata": {},
443
+ "source": [
444
+ "### 演示需要,只训练了100 steps。建议同学改为默认的 3个 epochs 完整训练一个中文语音识别模型。"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": 17,
450
+ "id": "11f259c8-dbcf-4a7f-bbb5-821ab104efee",
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "from transformers import Seq2SeqTrainingArguments\n",
455
+ "\n",
456
+ "# 设置序列到序列模型训练的参数\n",
457
+ "training_args = Seq2SeqTrainingArguments(\n",
458
+ " output_dir=\"models/whisper-large-v2-asr-int8\", # 指定模型输出和保存的目录\n",
459
+ " per_device_train_batch_size=batch_size, # 每个设备上的训练批量大小\n",
460
+ " gradient_accumulation_steps=1, # 梯度累积步数,在每次优化器步骤之前累积的更新步数\n",
461
+ " learning_rate=1e-3, # 学习率\n",
462
+ " warmup_steps=50, # 在训练初期增加学习率的步数,有助于稳定训练\n",
463
+ " # max_steps=100, # 训练总步数\n",
464
+ " num_train_epochs=3, # 训练的总轮数\n",
465
+ " evaluation_strategy=\"epoch\", # 设置评估策略,这里是在每个epoch结束时进行评估\n",
466
+ " fp16=True, # 启用混合精度训练,可以提高训练速度,同时减少内存使用\n",
467
+ " per_device_eval_batch_size=batch_size, # 每个设备上的评估批量大小\n",
468
+ " generation_max_length=128, # 生成任务的最大长度\n",
469
+ " logging_steps=25, # 指定日志记录的步骤,用于跟踪训练进度\n",
470
+ " remove_unused_columns=False, # 是否删除不使用的列,以减少数据处理开销\n",
471
+ " label_names=[\"labels\"], # 指定标签列的名称,用于训练过程中\n",
472
+ ")"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "markdown",
477
+ "id": "c57ee183-b16f-4313-97f6-0df6c0f5f467",
478
+ "metadata": {},
479
+ "source": [
480
+ "#### 训练过程保存状态的回调,长时期训练建议使用"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": 18,
486
+ "id": "2ce443d9-f309-4c03-bd74-c6842292b713",
487
+ "metadata": {},
488
+ "outputs": [],
489
+ "source": [
490
+ "import os\n",
491
+ "from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
492
+ "from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, TrainerState, TrainerControl\n",
493
+ "\n",
494
+ "class SavePeftModelCallback(TrainerCallback):\n",
495
+ " def on_save(\n",
496
+ " self,\n",
497
+ " args: Seq2SeqTrainingArguments,\n",
498
+ " state: TrainerState,\n",
499
+ " control: TrainerControl,\n",
500
+ " **kwargs,\n",
501
+ " ):\n",
502
+ " checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n",
503
+ "\n",
504
+ " peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
505
+ " kwargs[\"model\"].save_pretrained(peft_model_path)\n",
506
+ "\n",
507
+ " pytorch_model_path = os.path.join(checkpoint_folder, \"pytorch_model.bin\")\n",
508
+ " if os.path.exists(pytorch_model_path):\n",
509
+ " os.remove(pytorch_model_path)\n",
510
+ " return control"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": 19,
516
+ "id": "f8a52ed7-cae0-4aba-818e-87717430d908",
517
+ "metadata": {},
518
+ "outputs": [],
519
+ "source": [
520
+ "trainer = Seq2SeqTrainer(\n",
521
+ " args=training_args,\n",
522
+ " model=model,\n",
523
+ " train_dataset=common_voice[\"train\"],\n",
524
+ " eval_dataset=common_voice[\"test\"],\n",
525
+ " data_collator=data_collator,\n",
526
+ " tokenizer=processor.feature_extractor,\n",
527
+ " callbacks=[SavePeftModelCallback],\n",
528
+ ")\n",
529
+ "model.config.use_cache = False"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 20,
535
+ "id": "6973bed7-8f53-4d55-966c-f037941e5ef3",
536
+ "metadata": {
537
+ "scrolled": true
538
+ },
539
+ "outputs": [
540
+ {
541
+ "name": "stderr",
542
+ "output_type": "stream",
543
+ "text": [
544
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
545
+ " warnings.warn(\n",
546
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
547
+ " warnings.warn(\n",
548
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
549
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
550
+ ]
551
+ },
552
+ {
553
+ "data": {
554
+ "text/html": [
555
+ "\n",
556
+ " <div>\n",
557
+ " \n",
558
+ " <progress value='1860' max='1860' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
559
+ " [1860/1860 7:48:09, Epoch 3/3]\n",
560
+ " </div>\n",
561
+ " <table border=\"1\" class=\"dataframe\">\n",
562
+ " <thead>\n",
563
+ " <tr style=\"text-align: left;\">\n",
564
+ " <th>Epoch</th>\n",
565
+ " <th>Training Loss</th>\n",
566
+ " <th>Validation Loss</th>\n",
567
+ " </tr>\n",
568
+ " </thead>\n",
569
+ " <tbody>\n",
570
+ " <tr>\n",
571
+ " <td>1</td>\n",
572
+ " <td>0.341900</td>\n",
573
+ " <td>0.264468</td>\n",
574
+ " </tr>\n",
575
+ " <tr>\n",
576
+ " <td>2</td>\n",
577
+ " <td>0.259600</td>\n",
578
+ " <td>0.248249</td>\n",
579
+ " </tr>\n",
580
+ " <tr>\n",
581
+ " <td>3</td>\n",
582
+ " <td>0.214400</td>\n",
583
+ " <td>0.248773</td>\n",
584
+ " </tr>\n",
585
+ " </tbody>\n",
586
+ "</table><p>"
587
+ ],
588
+ "text/plain": [
589
+ "<IPython.core.display.HTML object>"
590
+ ]
591
+ },
592
+ "metadata": {},
593
+ "output_type": "display_data"
594
+ },
595
+ {
596
+ "name": "stderr",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
600
+ " warnings.warn(\n",
601
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
602
+ " warnings.warn(\n",
603
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
604
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
605
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
606
+ " warnings.warn(\n",
607
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
608
+ " warnings.warn(\n",
609
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
610
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
611
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
612
+ " warnings.warn(\n",
613
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
614
+ " warnings.warn(\n",
615
+ "/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
616
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
617
+ ]
618
+ },
619
+ {
620
+ "data": {
621
+ "text/plain": [
622
+ "TrainOutput(global_step=1860, training_loss=0.32258521408163093, metrics={'train_runtime': 28110.1202, 'train_samples_per_second': 4.23, 'train_steps_per_second': 0.066, 'total_flos': 2.531417002463232e+20, 'train_loss': 0.32258521408163093, 'epoch': 3.0})"
623
+ ]
624
+ },
625
+ "execution_count": 20,
626
+ "metadata": {},
627
+ "output_type": "execute_result"
628
+ }
629
+ ],
630
+ "source": [
631
+ "trainer.train()"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "markdown",
636
+ "id": "620992c3-64f5-48f9-8e66-fdc5f6a27427",
637
+ "metadata": {},
638
+ "source": [
639
+ "### 保存 LoRA 模型"
640
+ ]
641
+ },
642
+ {
643
+ "cell_type": "code",
644
+ "execution_count": 21,
645
+ "id": "53310565-7313-46a7-acf1-215970fd4f8e",
646
+ "metadata": {},
647
+ "outputs": [],
648
+ "source": [
649
+ "model.save_pretrained(\"models/whisper-large-v2-asr-int8\")"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "markdown",
654
+ "id": "dcfe9611-eee5-462f-8cb8-fed86eec76e0",
655
+ "metadata": {},
656
+ "source": [
657
+ "### 使用 Pipiline 加载 LoRA 模型,实现自动语音识别任务"
658
+ ]
659
+ },
660
+ {
661
+ "cell_type": "code",
662
+ "execution_count": 3,
663
+ "id": "426d7520-62cb-42bb-a4dd-000aa607b105",
664
+ "metadata": {},
665
+ "outputs": [
666
+ {
667
+ "data": {
668
+ "text/plain": [
669
+ "5.763907432556152"
670
+ ]
671
+ },
672
+ "execution_count": 3,
673
+ "metadata": {},
674
+ "output_type": "execute_result"
675
+ }
676
+ ],
677
+ "source": [
678
+ "from transformers import AutoModelForSpeechSeq2Seq\n",
679
+ "\n",
680
+ "my_model_name_or_path = \"yqzhangjx/whisper-large-v2-asr-int8\"\n",
681
+ "\n",
682
+ "model = AutoModelForSpeechSeq2Seq.from_pretrained(my_model_name_or_path, device_map=\"auto\")\n",
683
+ "\n",
684
+ "model.get_memory_footprint()/1024**3"
685
+ ]
686
+ },
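+ {
+ "cell_type": "markdown",
+ "id": "peft-adapter-load-note",
+ "metadata": {},
+ "source": [
+ "Alternatively, if you only have the LoRA adapter saved locally (the `models/whisper-large-v2-asr-int8` directory written above), a minimal sketch of attaching it to the base model with `PeftModel`, assuming that directory exists and `model_name_or_path` is still defined:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "peft-adapter-load-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import PeftModel\n",
+ "from transformers import AutoModelForSpeechSeq2Seq\n",
+ "\n",
+ "# Load the frozen base model, then attach the locally saved LoRA adapter on top of it.\n",
+ "base_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, device_map=\"auto\")\n",
+ "lora_model = PeftModel.from_pretrained(base_model, \"models/whisper-large-v2-asr-int8\")"
+ ]
+ },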
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 4,
690
+ "id": "7536c488-0526-4c12-baaf-b4f7c7075be6",
691
+ "metadata": {},
692
+ "outputs": [],
693
+ "source": [
694
+ "from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor\n",
695
+ "\n",
696
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)\n",
697
+ "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, language=language, task=task)\n",
698
+ "processor = AutoProcessor.from_pretrained(model_name_or_path, language=language, task=task)"
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": 5,
704
+ "id": "18181692-a143-44ee-b56c-e754d308e0ec",
705
+ "metadata": {},
706
+ "outputs": [],
707
+ "source": [
708
+ "test_audio = \"data/audio/test_zh.flac\""
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": 6,
714
+ "id": "9d494647-082c-4e48-9486-7945618ae679",
715
+ "metadata": {},
716
+ "outputs": [],
717
+ "source": [
718
+ "from transformers import AutomaticSpeechRecognitionPipeline\n",
719
+ "\n",
720
+ "pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
721
+ "\n",
722
+ "forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"chinese\", task=task)"
723
+ ]
724
+ },
725
+ {
726
+ "cell_type": "code",
727
+ "execution_count": 7,
728
+ "id": "c3eac486-169f-41ad-b9c3-69f2c27c3e1f",
729
+ "metadata": {},
730
+ "outputs": [],
731
+ "source": [
732
+ "with torch.cuda.amp.autocast():\n",
733
+ " text = pipeline(test_audio, generate_kwargs={\"forced_decoder_ids\": forced_decoder_ids}, max_new_tokens=255)[\"text\"]"
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "code",
738
+ "execution_count": 8,
739
+ "id": "cc4b24b2-c65b-4e63-8f16-26c72039a38d",
740
+ "metadata": {},
741
+ "outputs": [
742
+ {
743
+ "data": {
744
+ "text/plain": [
745
+ "'这是一段测试用于Whisper Large V2模型的自动语音识别测试。'"
746
+ ]
747
+ },
748
+ "execution_count": 8,
749
+ "metadata": {},
750
+ "output_type": "execute_result"
751
+ }
752
+ ],
753
+ "source": [
754
+ "text"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": null,
760
+ "id": "89f49787-6ab4-4bc1-91b8-a1c104c9feaf",
761
+ "metadata": {},
762
+ "outputs": [],
763
+ "source": []
764
+ },
765
+ {
766
+ "cell_type": "markdown",
767
+ "id": "0285dd19-229e-4241-b680-71e25ab51dde",
768
+ "metadata": {},
769
+ "source": [
770
+ "#### Homework 1: 为中文语料的训练过程增加过程评估,观察 Train Loss 和 Validation Loss 变化;\n",
771
+ "#### Homework 2: LoRA 模型训练完成后,使用测试集进行完整的模型评估"
772
+ ]
773
+ },
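+ {
+ "cell_type": "markdown",
+ "id": "homework1-wer-sketch-note",
+ "metadata": {},
+ "source": [
+ "One possible starting point for Homework 1 (a sketch, not a full solution): give `Seq2SeqTrainer` a `compute_metrics` function so WER is reported at every evaluation. This assumes `predict_with_generate=True` is added to the training arguments so that the predictions passed in are generated token ids."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "homework1-wer-sketch-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import evaluate\n",
+ "\n",
+ "wer_metric = evaluate.load(\"wer\")\n",
+ "\n",
+ "def compute_metrics(pred):\n",
+ "    pred_ids = pred.predictions\n",
+ "    # -100 marks label padding; restore the pad token id before decoding\n",
+ "    label_ids = np.where(pred.label_ids != -100, pred.label_ids, tokenizer.pad_token_id)\n",
+ "    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
+ "    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
+ "    return {\"wer\": 100 * wer_metric.compute(predictions=pred_str, references=label_str)}\n",
+ "\n",
+ "# Pass compute_metrics=compute_metrics when constructing the Seq2SeqTrainer above."
+ ]
+ },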
774
+ {
775
+ "cell_type": "code",
776
+ "execution_count": null,
777
+ "id": "0c90ad4c-70eb-43d1-96ec-cc74c4bae345",
778
+ "metadata": {},
779
+ "outputs": [],
780
+ "source": []
781
+ },
782
+ {
783
+ "cell_type": "markdown",
784
+ "id": "b24fccce-fec3-48a3-b43c-b9077788521d",
785
+ "metadata": {},
786
+ "source": [
787
+ "## 评估模型"
788
+ ]
789
+ },
790
+ {
791
+ "cell_type": "code",
792
+ "execution_count": 26,
793
+ "id": "b021c6a8-645d-44f7-970b-f410180787a6",
794
+ "metadata": {},
795
+ "outputs": [
796
+ {
797
+ "data": {
798
+ "application/vnd.jupyter.widget-view+json": {
799
+ "model_id": "5874cd7f7cba4c1d9d89092923e0b8a5",
800
+ "version_major": 2,
801
+ "version_minor": 0
802
+ },
803
+ "text/plain": [
804
+ "Downloading builder script: 0%| | 0.00/4.49k [00:00<?, ?B/s]"
805
+ ]
806
+ },
807
+ "metadata": {},
808
+ "output_type": "display_data"
809
+ }
810
+ ],
811
+ "source": [
812
+ "import evaluate\n",
813
+ "\n",
814
+ "# 词错误率(WER)是评估ASR模型常用的指标。从 Evaluate加载 WER 指标\n",
815
+ "metric = evaluate.load(\"wer\")"
816
+ ]
817
+ },
818
+ {
819
+ "cell_type": "code",
820
+ "execution_count": 27,
821
+ "id": "5156ce16-0e04-4e41-b308-52c6c9b2a20d",
822
+ "metadata": {
823
+ "scrolled": true
824
+ },
825
+ "outputs": [
826
+ {
827
+ "data": {
828
+ "text/plain": [
829
+ "PeftModel(\n",
830
+ " (base_model): LoraModel(\n",
831
+ " (model): WhisperForConditionalGeneration(\n",
832
+ " (model): WhisperModel(\n",
833
+ " (encoder): WhisperEncoder(\n",
834
+ " (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
835
+ " (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
836
+ " (embed_positions): Embedding(1500, 1280)\n",
837
+ " (layers): ModuleList(\n",
838
+ " (0-31): 32 x WhisperEncoderLayer(\n",
839
+ " (self_attn): WhisperSdpaAttention(\n",
840
+ " (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
841
+ " (v_proj): lora.Linear8bitLt(\n",
842
+ " (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
843
+ " (lora_dropout): ModuleDict(\n",
844
+ " (default): Dropout(p=0.05, inplace=False)\n",
845
+ " )\n",
846
+ " (lora_A): ModuleDict(\n",
847
+ " (default): Linear(in_features=1280, out_features=8, bias=False)\n",
848
+ " )\n",
849
+ " (lora_B): ModuleDict(\n",
850
+ " (default): Linear(in_features=8, out_features=1280, bias=False)\n",
851
+ " )\n",
852
+ " (lora_embedding_A): ParameterDict()\n",
853
+ " (lora_embedding_B): ParameterDict()\n",
854
+ " )\n",
855
+ " (q_proj): lora.Linear8bitLt(\n",
856
+ " (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
857
+ " (lora_dropout): ModuleDict(\n",
858
+ " (default): Dropout(p=0.05, inplace=False)\n",
859
+ " )\n",
860
+ " (lora_A): ModuleDict(\n",
861
+ " (default): Linear(in_features=1280, out_features=8, bias=False)\n",
862
+ " )\n",
863
+ " (lora_B): ModuleDict(\n",
864
+ " (default): Linear(in_features=8, out_features=1280, bias=False)\n",
865
+ " )\n",
866
+ " (lora_embedding_A): ParameterDict()\n",
867
+ " (lora_embedding_B): ParameterDict()\n",
868
+ " )\n",
869
+ " (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
870
+ " )\n",
871
+ " (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
872
+ " (activation_fn): GELUActivation()\n",
873
+ " (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
874
+ " (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
875
+ " (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
876
+ " )\n",
877
+ " )\n",
878
+ " (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
879
+ " )\n",
880
+ " (decoder): WhisperDecoder(\n",
881
+ " (embed_tokens): Embedding(51865, 1280, padding_idx=50257)\n",
882
+ " (embed_positions): WhisperPositionalEmbedding(448, 1280)\n",
883
+ " (layers): ModuleList(\n",
884
+ " (0-31): 32 x WhisperDecoderLayer(\n",
885
+ " (self_attn): WhisperSdpaAttention(\n",
886
+ " (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
887
+ " (v_proj): lora.Linear8bitLt(\n",
888
+ " (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
889
+ " (lora_dropout): ModuleDict(\n",
890
+ " (default): Dropout(p=0.05, inplace=False)\n",
891
+ " )\n",
892
+ " (lora_A): ModuleDict(\n",
893
+ " (default): Linear(in_features=1280, out_features=8, bias=False)\n",
894
+ " )\n",
895
+ " (lora_B): ModuleDict(\n",
896
+ " (default): Linear(in_features=8, out_features=1280, bias=False)\n",
897
+ " )\n",
898
+ " (lora_embedding_A): ParameterDict()\n",
899
+ " (lora_embedding_B): ParameterDict()\n",
900
+ " )\n",
901
+ " (q_proj): lora.Linear8bitLt(\n",
902
+ " (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
903
+ " (lora_dropout): ModuleDict(\n",
904
+ " (default): Dropout(p=0.05, inplace=False)\n",
905
+ " )\n",
906
+ " (lora_A): ModuleDict(\n",
907
+ " (default): Linear(in_features=1280, out_features=8, bias=False)\n",
908
+ " )\n",
909
+ " (lora_B): ModuleDict(\n",
910
+ " (default): Linear(in_features=8, out_features=1280, bias=False)\n",
911
+ " )\n",
912
+ " (lora_embedding_A): ParameterDict()\n",
913
+ " (lora_embedding_B): ParameterDict()\n",
914
+ " )\n",
915
+ " (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
916
+ " )\n",
917
+ " (activation_fn): GELUActivation()\n",
918
+ " (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
919
+ " (encoder_attn): WhisperSdpaAttention(\n",
920
+ " (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
921
+ " (v_proj): lora.Linear8bitLt(\n",
922
+ " (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
923
+ " (lora_dropout): ModuleDict(\n",
924
+ " (default): Dropout(p=0.05, inplace=False)\n",
925
+ " )\n",
926
+ " (lora_A): ModuleDict(\n",
927
+ " (default): Linear(in_features=1280, out_features=8, bias=False)\n",
928
+ " )\n",
929
+ " (lora_B): ModuleDict(\n",
930
+ " (default): Linear(in_features=8, out_features=1280, bias=False)\n",
931
+ " )\n",
932
+ " (lora_embedding_A): ParameterDict()\n",
933
+ " (lora_embedding_B): ParameterDict()\n",
934
+ " )\n",
935
+ " (q_proj): lora.Linear8bitLt(\n",
936
+ " (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
937
+ " (lora_dropout): ModuleDict(\n",
938
+ " (default): Dropout(p=0.05, inplace=False)\n",
939
+ " )\n",
940
+ " (lora_A): ModuleDict(\n",
941
+ " (default): Linear(in_features=1280, out_features=8, bias=False)\n",
942
+ " )\n",
943
+ " (lora_B): ModuleDict(\n",
944
+ " (default): Linear(in_features=8, out_features=1280, bias=False)\n",
945
+ " )\n",
946
+ " (lora_embedding_A): ParameterDict()\n",
947
+ " (lora_embedding_B): ParameterDict()\n",
948
+ " )\n",
949
+ " (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
950
+ " )\n",
951
+ " (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
952
+ " (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
953
+ " (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
954
+ " (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
955
+ " )\n",
956
+ " )\n",
957
+ " (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
958
+ " )\n",
959
+ " )\n",
960
+ " (proj_out): Linear(in_features=1280, out_features=51865, bias=False)\n",
961
+ " )\n",
962
+ " )\n",
963
+ ")"
964
+ ]
965
+ },
966
+ "execution_count": 27,
967
+ "metadata": {},
968
+ "output_type": "execute_result"
969
+ }
970
+ ],
971
+ "source": [
972
+ "from torch.utils.data import DataLoader\n",
973
+ "from tqdm import tqdm\n",
974
+ "import numpy as np\n",
975
+ "import gc\n",
976
+ "\n",
977
+ "eval_dataloader = DataLoader(common_voice[\"test\"], batch_size=8, collate_fn=data_collator)\n",
978
+ "\n",
979
+ "model.eval()"
980
+ ]
981
+ },
982
+ {
983
+ "cell_type": "code",
984
+ "execution_count": 28,
985
+ "id": "9120279e-b10a-44ea-9275-2152ec204fae",
986
+ "metadata": {
987
+ "scrolled": true
988
+ },
989
+ "outputs": [
990
+ {
991
+ "name": "stderr",
992
+ "output_type": "stream",
993
+ "text": [
994
+ " 0%| | 0/1323 [00:00<?, ?it/s]/data/miniconda3/envs/jxzhang/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
995
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
996
+ "100%|██████████| 1323/1323 [3:13:07<00:00, 8.76s/it] \n"
997
+ ]
998
+ }
999
+ ],
1000
+ "source": [
1001
+ "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
1002
+ " with torch.cuda.amp.autocast():\n",
1003
+ " with torch.no_grad():\n",
1004
+ " generated_tokens = (\n",
1005
+ " model.generate(\n",
1006
+ " input_features=batch[\"input_features\"].to(\"cuda\"),\n",
1007
+ " decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),\n",
1008
+ " max_new_tokens=255,\n",
1009
+ " )\n",
1010
+ " .cpu()\n",
1011
+ " .numpy()\n",
1012
+ " )\n",
1013
+ " labels = batch[\"labels\"].cpu().numpy()\n",
1014
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
1015
+ " decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
1016
+ " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
1017
+ " metric.add_batch(\n",
1018
+ " predictions=decoded_preds,\n",
1019
+ " references=decoded_labels,\n",
1020
+ " )\n",
1021
+ " del generated_tokens, labels, batch\n",
1022
+ " gc.collect()"
1023
+ ]
1024
+ },
1025
+ {
1026
+ "cell_type": "code",
1027
+ "execution_count": 29,
1028
+ "id": "aad4d7e8-ee0d-484a-9ae7-490c8b9898a0",
1029
+ "metadata": {},
1030
+ "outputs": [
1031
+ {
1032
+ "name": "stdout",
1033
+ "output_type": "stream",
1034
+ "text": [
1035
+ "wer=54.73445473445473\n"
1036
+ ]
1037
+ }
1038
+ ],
1039
+ "source": [
1040
+ "wer = 100 * metric.compute()\n",
1041
+ "print(f\"{wer=}\")"
1042
+ ]
1043
+ },
1044
+ {
1045
+ "cell_type": "code",
1046
+ "execution_count": 30,
1047
+ "id": "4120d924-c684-44af-8a1d-76aaf506bd08",
1048
+ "metadata": {},
1049
+ "outputs": [
1050
+ {
1051
+ "data": {
1052
+ "application/vnd.jupyter.widget-view+json": {
1053
+ "model_id": "65c2e25ce2434db2a59617adf7439064",
1054
+ "version_major": 2,
1055
+ "version_minor": 0
1056
+ },
1057
+ "text/plain": [
1058
+ "adapter_model.safetensors: 0%| | 0.00/15.8M [00:00<?, ?B/s]"
1059
+ ]
1060
+ },
1061
+ "metadata": {},
1062
+ "output_type": "display_data"
1063
+ },
1064
+ {
1065
+ "data": {
1066
+ "text/plain": [
1067
+ "CommitInfo(commit_url='https://huggingface.co/yqzhangjx/whisper-large-v2-finetune-for-common_voice_11_0/commit/e9a78de42addc1cfb814188de95de3e97f9281c6', commit_message='Upload model', commit_description='', oid='e9a78de42addc1cfb814188de95de3e97f9281c6', pr_url=None, pr_revision=None, pr_num=None)"
1068
+ ]
1069
+ },
1070
+ "execution_count": 30,
1071
+ "metadata": {},
1072
+ "output_type": "execute_result"
1073
+ }
1074
+ ],
1075
+ "source": [
1076
+ "model.push_to_hub('whisper-large-v2-asr-int8')"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "cell_type": "code",
1081
+ "execution_count": null,
1082
+ "id": "b740c690-35ff-4375-b59f-3c6de5155ec0",
1083
+ "metadata": {},
1084
+ "outputs": [],
1085
+ "source": []
1086
+ }
1087
+ ],
1088
+ "metadata": {
1089
+ "kernelspec": {
1090
+ "display_name": "Python 3 (ipykernel)",
1091
+ "language": "python",
1092
+ "name": "python3"
1093
+ },
1094
+ "language_info": {
1095
+ "codemirror_mode": {
1096
+ "name": "ipython",
1097
+ "version": 3
1098
+ },
1099
+ "file_extension": ".py",
1100
+ "mimetype": "text/x-python",
1101
+ "name": "python",
1102
+ "nbconvert_exporter": "python",
1103
+ "pygments_lexer": "ipython3",
1104
+ "version": "3.9.18"
1105
+ }
1106
+ },
1107
+ "nbformat": 4,
1108
+ "nbformat_minor": 5
1109
+ }