maxseats committed on
Commit
2e9613b
•
1 Parent(s): 6ebb17f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +281 -3
README.md CHANGED
@@ -11,10 +11,288 @@ metrics:
11
  ---
12
  # Model Name : maxseats/SungBeom-whisper-small-ko-set0
13
  # Description
 
14
 
15
- - 파인튜닝 데이터셋 : maxseats/aihub-464-preprocessed-680GB-set-0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # 설명
18
- - 주요 영역별 회의 음성 데이터셋 680GB 중 첫번째 데이터(10GB)를 파인튜닝한 모델입니다.
19
- - 링크 : https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
11
  ---
12
  # Model Name : maxseats/SungBeom-whisper-small-ko-set0
13
  # Description
14
+ - 파인튜닝 데이터셋 : maxseats/aihub-464-preprocessed-680GB-set-1
15
 
16
+ # 설명
17
+ - AI hub์˜ ์ฃผ์š” ์˜์—ญ๋ณ„ ํšŒ์˜ ์Œ์„ฑ ๋ฐ์ดํ„ฐ์…‹์„ ํ•™์Šต ์ค‘์ด์—์š”.
18
+ - 680GB 중 첫번째 데이터(10GB)를 파인튜닝한 모델을 불러와서, 두번째 데이터를 학습한 모델입니다.
19
+ - 링크 : https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-0, https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-1
20
+
21
+ - 다음 코드를 통해 작성했어요.
22
+
23
+ ```
24
+ from datasets import load_dataset
25
+ import torch
26
+ from dataclasses import dataclass
27
+ from typing import Any, Dict, List, Union
28
+ import evaluate
29
+ from transformers import WhisperTokenizer, WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
30
+ import mlflow
31
+ from mlflow.tracking.client import MlflowClient
32
+ import subprocess
33
+ from huggingface_hub import create_repo, Repository
34
+ import os
35
+ import shutil
36
+ import math # ์ž„์‹œ ํ…Œ์ŠคํŠธ์šฉ
37
+ model_dir = "./tmpp" # do not modify
38
+
39
+
40
+ #########################################################################################################################################
41
+ ################################################### User-configurable settings #####################################################################
42
+ #########################################################################################################################################
43
+
44
+ model_description = """
45
+ - ํŒŒ์ธํŠœ๋‹ ๋ฐ์ดํ„ฐ์…‹ : maxseats/aihub-464-preprocessed-680GB-set-1
46
 
47
  # ์„ค๋ช…
48
+ - AI hub์˜ ์ฃผ์š” ์˜์—ญ๋ณ„ ํšŒ์˜ ์Œ์„ฑ ๋ฐ์ดํ„ฐ์…‹์„ ํ•™์Šต ์ค‘์ด์—์š”.
49
+ - 680GB ์ค‘ ์ฒซ๋ฒˆ์งธ ๋ฐ์ดํ„ฐ(10GB)๋ฅผ ํŒŒ์ธํŠœ๋‹ํ•œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์™€์„œ, ๋‘๋ฒˆ์งธ ๋ฐ์ดํ„ฐ๋ฅผ ํ•™์Šตํ•œ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.
50
+ - ๋งํฌ : https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-0, https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-1
51
+ """
52
+
53
+ # model_name = "openai/whisper-base"
54
+ model_name = "maxseats/SungBeom-whisper-small-ko-set0" # alternative: "SungBeom/whisper-small-ko"
55
+ # dataset_name = "maxseats/aihub-464-preprocessed-680GB-set-1" # dataset to load (Hugging Face Hub name)
56
+ dataset_name = "maxseats/aihub-464-preprocessed-680GB-set-1" # dataset to load (Hugging Face Hub name)
57
+
58
+ CACHE_DIR = '/mnt/a/maxseats/.finetuning_cache' # cache directory passed to load_dataset
59
+ is_test = False # True: test with a small sample of the data, False: real fine-tuning
60
+
61
+ token = "hf_" # enter your Hugging Face token here
62
+
63
+ training_args = Seq2SeqTrainingArguments(
64
+ output_dir=model_dir, # enter the desired repository name
65
+ per_device_train_batch_size=16,
66
+ gradient_accumulation_steps=2, # double this every time the batch size is halved
67
+ learning_rate=1e-5,
68
+ warmup_steps=500,
69
+ # max_steps=2, # set this instead of epochs
70
+ num_train_epochs=1, # number of epochs / set only one of max_steps and this
71
+ gradient_checkpointing=True,
72
+ fp16=True,
73
+ evaluation_strategy="steps",
74
+ per_device_eval_batch_size=16,
75
+ predict_with_generate=True,
76
+ generation_max_length=225,
77
+ save_steps=1000,
78
+ eval_steps=1000,
79
+ logging_steps=25,
80
+ report_to=["tensorboard"],
81
+ load_best_model_at_end=True,
82
+ metric_for_best_model="cer", # for Korean, 'cer' is likely a better fit than 'wer'
83
+ greater_is_better=False,
84
+ push_to_hub=True,
85
+ save_total_limit=5, # maximum number of models to keep saved
86
+ )
87
+
88
+ #########################################################################################################################################
89
+ ################################################### User-configurable settings #####################################################################
90
+ #########################################################################################################################################
91
+
92
+
93
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed speech examples into one padded batch.

    Audio inputs and label token ids have different lengths and need
    different padding, so they are split, padded separately and merged again.

    Attributes:
        processor: Object exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` (a ``WhisperProcessor`` in this script).
        decoder_start_token_id: Token id to strip when it leads every label
            sequence in the batch. ``None`` (the default) falls back to
            ``processor.tokenizer.bos_token_id``, which is what this collator
            always compared against before the field existed.
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples and return the batch.

        Args:
            features: Examples, each with an ``"input_features"`` entry and a
                ``"labels"`` entry (token ids).

        Returns:
            The padded feature batch with an added ``"labels"`` tensor in
            which padded positions are set to -100.
        """
        # Pad the audio inputs and get them back as PyTorch tensors.
        audio_inputs = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Pad the tokenized label sequences to the longest one in the batch.
        label_inputs = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")

        # Replace padding with -100 so the loss ignores those positions.
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.bos_token_id

        # Drop the leading start token only when every sequence begins with
        # it; it can be added back later. The width check avoids indexing
        # column 0 of an empty label tensor.
        if (
            start_id is not None
            and labels.shape[1] > 0
            and bool((labels[:, 0] == start_id).all())
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
119
+
120
+
121
def compute_metrics(pred):
    """Return the character error rate (in percent) for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (predicted token ids) and
            ``label_ids`` (reference token ids in which -100 marks ignored
            positions).

    Returns:
        ``{"cer": value}``, where ``value`` is 100 times the score computed
        by the module-level ``metric``.
    """
    predictions, references = pred.predictions, pred.label_ids

    # Put the pad token back wherever -100 was used, so the references can
    # be decoded (this edits ``pred.label_ids`` in place).
    references[references == -100] = tokenizer.pad_token_id

    # Special tokens are dropped from both sides before scoring.
    pred_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(references, skip_special_tokens=True)

    score = metric.compute(predictions=pred_str, references=label_str)
    return {"cer": 100 * score}
135
+
136
+
137
# Start from empty model_dir and ./repo directories: remove any leftovers
# from a previous run, then recreate each directory.
for scratch_dir in (model_dir, './repo'):
    if os.path.exists(scratch_dir):
        shutil.rmtree(scratch_dir)
    os.makedirs(scratch_dir)
145
+
146
# Load the processor, tokenizer, feature extractor and model to fine-tune.
korean_transcribe = {"language": "Korean", "task": "transcribe"}
processor = WhisperProcessor.from_pretrained(model_name, **korean_transcribe)
tokenizer = WhisperTokenizer.from_pretrained(model_name, **korean_transcribe)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# Reset forced_decoder_ids and suppress_tokens on the loaded model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Batch collation and the CER metric used by compute_metrics.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
metric = evaluate.load('cer')
156
+
157
+
158
# Load the fully preprocessed dataset from the Hub (this step is very slow).
preprocessed_dataset = load_dataset(dataset_name, cache_dir=CACHE_DIR)

# Test mode: keep only the first 30% of the "valid" split so that a
# code-path check finishes quickly.
if is_test:
    valid_split = preprocessed_dataset["valid"]
    # Size the subset from the split itself; len() of the whole
    # multi-split dataset object counts splits, not rows.
    subset_size = math.ceil(len(valid_split) * 0.3)
    preprocessed_dataset["valid"] = valid_split.select(range(subset_size))
164
+
165
# Flatten training_args into a plain dict so each entry can be logged.
training_args_dict = training_args.to_dict()

# Database that backs the MLflow tracking UI.
mlflow.set_tracking_uri("sqlite:////content/drive/MyDrive/STT_test/mlflow.db")

# The MLflow experiment is named after the model; reuse it when it exists,
# otherwise create it.
experiment_name = model_name
existing_experiment = mlflow.get_experiment_by_name(experiment_name)
experiment_id = (
    existing_experiment.experiment_id
    if existing_experiment is not None
    else mlflow.create_experiment(experiment_name)
)

# Model version to log; overwritten with the registry-assigned version
# once the model has been registered.
model_version = 1
182
+
183
# Train inside an MLflow run so parameters, metrics and the registered
# model version are recorded together.
with mlflow.start_run(experiment_id=experiment_id, description=model_description):
    # Log every training argument as a run parameter.
    for arg_name, arg_value in training_args_dict.items():
        mlflow.log_param(arg_name, arg_value)

    # Record which dataset this run was trained on.
    mlflow.set_tag("Dataset", dataset_name)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=preprocessed_dataset["train"],
        eval_dataset=preprocessed_dataset["valid"],  # or "test"
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()
    trainer.save_model(model_dir)  # save the model after training

    # Log the final evaluation metrics.
    metrics = trainer.evaluate()
    for metric_name, metric_value in metrics.items():
        mlflow.log_metric(metric_name, metric_value)

    # Register the model in the MLflow registry.
    # NOTE(review): nothing in this block logs files under model_dir as run
    # artifacts -- confirm that this URI resolves to the saved model.
    run_id = mlflow.active_run().info.run_id
    model_uri = f"runs:/{run_id}/{model_dir}"

    # '/' in the model name is replaced by '-' for the registry name; the
    # resulting version is later used to name the Hugging Face repository.
    registered_name = model_name.replace('/', '-')
    model_details = mlflow.register_model(model_uri=model_uri, name=registered_name)

    # Attach the model description to the registered version.
    client = MlflowClient()
    client.update_model_version(
        name=model_details.name,
        version=model_details.version,
        description=model_description,
    )
    model_version = model_details.version  # used when uploading to Hugging Face
220
+
221
+
222
+
223
# Log in to Hugging Face, prompting for a new token until the CLI accepts
# one. Entering "exit" aborts the script rather than continuing with
# "exit" as the token.
while token != "exit":
    try:
        login_ok = subprocess.run(["huggingface-cli", "login", "--token", token]).returncode == 0
    except (OSError, ValueError) as err:
        # The CLI could not be launched at all (for example, it is not
        # installed); report why and fall through to the prompt.
        print(f"Could not run huggingface-cli: {err}")
        login_ok = False

    if login_ok:
        break
    token = input("Please enter your Hugging Face API token: ")
else:
    # Reached only when the loop ends without a successful login,
    # i.e. the token is "exit".
    raise SystemExit("Hugging Face login aborted by user.")

os.environ["HUGGINGFACE_HUB_TOKEN"] = token
239
+
240
# Hugging Face repository name: owner prefix, model name with '/' replaced
# by '-', and the model version.
repo_name = f"maxseats/{model_name.replace('/', '-')}-{model_version}"

# Create the repository (no error if it already exists).
create_repo(repo_name, exist_ok=True, token=token)

# Clone the repository into ./repo.
repo = Repository(local_dir='./repo', clone_from=repo_name, use_auth_token=token)

# Copy the files needed from model_dir into the clone.
max_depth = 1  # maximum directory depth to traverse
base_depth = model_dir.count(os.sep)

for root, dirs, files in os.walk(model_dir):
    depth = root.count(os.sep) - base_depth
    if depth + 1 >= max_depth:
        # Anything below this directory would be at max_depth or deeper and
        # is never copied, so do not descend into it.
        del dirs[:]
    if depth < max_depth:
        for file in files:
            shutil.copy(os.path.join(root, file), './repo')

# Save the tokenizer files into the clone as well.
tokenizer.save_pretrained('./repo')
267
+
268
+
269
# Model card: YAML metadata followed by the model name and description.
readme = f"""
---
language: ko
tags:
- whisper
- speech-recognition
datasets:
- {dataset_name}
metrics:
- cer
---
# Model Name : {model_name}
# Description
{model_description}
"""

# Write the model card. The encoding is explicit because the description
# contains non-ASCII text and the platform default may not be UTF-8.
with open("./repo/README.md", "w", encoding="utf-8") as f:
    f.write(readme)

# Commit and push the files.
repo.push_to_hub(commit_message="Initial commit")

# Delete the working folders and everything in them.
shutil.rmtree(model_dir)
shutil.rmtree('./repo')
296
+ ```
297
+
298