Update README.md
Browse files
README.md
CHANGED
@@ -11,10 +11,288 @@ metrics:
|
|
11 |
---
|
12 |
# Model Name : maxseats/SungBeom-whisper-small-ko-set0
|
13 |
# Description
|
|
|
14 |
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# 설명
|
18 |
-
- 주요 영역별 회의 음성
|
19 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
|
|
11 |
---
|
12 |
# Model Name : maxseats/SungBeom-whisper-small-ko-set0
|
13 |
# Description
|
14 |
+
- 파인튜닝 데이터셋 : maxseats/aihub-464-preprocessed-680GB-set-1
|
15 |
|
16 |
+
# 설명
|
17 |
+
- AI hub์ ์ฃผ์ ์์ญ๋ณ ํ์ ์์ฑ ๋ฐ์ดํฐ์
์ ํ์ต ์ค์ด์์.
|
18 |
+
- 680GB 중 첫번째 데이터(10GB)를 파인튜닝한 모델을 불러와서, 두번째 데이터를 학습한 모델입니다.
|
19 |
+
- 링크 : https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-0, https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-1
|
20 |
+
|
21 |
+
- 다음 코드를 통해 생성했어요.
|
22 |
+
|
23 |
+
```
|
24 |
+
from datasets import load_dataset
|
25 |
+
import torch
|
26 |
+
from dataclasses import dataclass
|
27 |
+
from typing import Any, Dict, List, Union
|
28 |
+
import evaluate
|
29 |
+
from transformers import WhisperTokenizer, WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
|
30 |
+
import mlflow
|
31 |
+
from mlflow.tracking.client import MlflowClient
|
32 |
+
import subprocess
|
33 |
+
from huggingface_hub import create_repo, Repository
|
34 |
+
import os
|
35 |
+
import shutil
|
36 |
+
import math # ์์ ํ
์คํธ์ฉ
|
37 |
+
model_dir = "./tmpp" # ์์ X
|
38 |
+
|
39 |
+
|
40 |
+
#########################################################################################################################################
|
41 |
+
################################################### ์ฌ์ฉ์ ์ค์ ๋ณ์ #####################################################################
|
42 |
+
#########################################################################################################################################
|
43 |
+
|
44 |
+
model_description = """
|
45 |
+
- ํ์ธํ๋ ๋ฐ์ดํฐ์
: maxseats/aihub-464-preprocessed-680GB-set-1
|
46 |
|
47 |
# ์ค๋ช
|
48 |
+
- AI hub์ ์ฃผ์ ์์ญ๋ณ ํ์ ์์ฑ ๋ฐ์ดํฐ์
์ ํ์ต ์ค์ด์์.
|
49 |
+
- 680GB ์ค ์ฒซ๋ฒ์งธ ๋ฐ์ดํฐ(10GB)๋ฅผ ํ์ธํ๋ํ ๋ชจ๋ธ์ ๋ถ๋ฌ์์, ๋๋ฒ์งธ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ ๋ชจ๋ธ์
๋๋ค.
|
50 |
+
- ๋งํฌ : https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-0, https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-1
|
51 |
+
"""
|
52 |
+
|
53 |
+
# model_name = "openai/whisper-base"
|
54 |
+
model_name = "maxseats/SungBeom-whisper-small-ko-set0" # ๋์ : "SungBeom/whisper-small-ko"
|
55 |
+
# dataset_name = "maxseats/aihub-464-preprocessed-680GB-set-1" # ๋ถ๋ฌ์ฌ ๋ฐ์ดํฐ์
(ํ๊น
ํ์ด์ค ๊ธฐ์ค)
|
56 |
+
dataset_name = "maxseats/aihub-464-preprocessed-680GB-set-1" # ๋ถ๋ฌ์ฌ ๋ฐ์ดํฐ์
(ํ๊น
ํ์ด์ค ๊ธฐ์ค)
|
57 |
+
|
58 |
+
CACHE_DIR = '/mnt/a/maxseats/.finetuning_cache' # ์บ์ ๋๋ ํ ๋ฆฌ ์ง์
|
59 |
+
is_test = False # True: ์๋์ ์ํ ๋ฐ์ดํฐ๋ก ํ
์คํธ, False: ์ค์ ํ์ธํ๋
|
60 |
+
|
61 |
+
token = "hf_" # ํ๊น
ํ์ด์ค ํ ํฐ ์
๋ ฅ
|
62 |
+
|
63 |
+
# Hyper-parameters for the Seq2SeqTrainer, grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # --- output / checkpointing ---
    output_dir=model_dir,  # original note: enter the desired repository name
    save_steps=1000,
    save_total_limit=5,  # maximum number of checkpoints to keep
    push_to_hub=True,
    # --- optimisation ---
    learning_rate=1e-5,
    warmup_steps=500,
    per_device_train_batch_size=16,
    # Original guidance: double this every time the batch size is halved.
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    fp16=True,
    # Set either num_train_epochs or max_steps (e.g. max_steps=2), not both.
    num_train_epochs=1,
    # --- evaluation ---
    evaluation_strategy="steps",
    eval_steps=1000,
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    # --- best-model selection ---
    load_best_model_at_end=True,
    # For Korean, 'cer' is considered a better fit than 'wer'.
    metric_for_best_model="cer",
    greater_is_better=False,
    # --- logging ---
    logging_steps=25,
    report_to=["tensorboard"],
)
|
87 |
+
|
88 |
+
#########################################################################################################################################
|
89 |
+
################################################### ์ฌ์ฉ์ ์ค์ ๋ณ์ #####################################################################
|
90 |
+
#########################################################################################################################################
|
91 |
+
|
92 |
+
|
93 |
+
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech examples into one padded batch.

    Audio inputs and label sequences have different lengths and are padded by
    different components, so the two are split apart, padded separately and
    then recombined into a single batch.
    """

    # Must expose `feature_extractor.pad(...)`, `tokenizer.pad(...)` and
    # `tokenizer.bos_token_id`.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad `features` and return the batch with a "labels" entry added.

        Each element of `features` is read through its "input_features" and
        "labels" keys.
        """
        # Audio side: let the feature extractor build the tensors.
        audio_inputs = [{"input_features": item["input_features"]} for item in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Label side: pad the tokenised transcripts to a common length.
        tokenizer = self.processor.tokenizer
        token_inputs = [{"input_ids": item["labels"]} for item in features]
        padded = tokenizer.pad(token_inputs, return_tensors="pt")

        # Padding positions become -100 so the loss ignores them.
        is_padding = padded.attention_mask.ne(1)
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # If every row starts with the BOS token, drop that column; the token
        # can be added again later.
        every_row_has_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if every_row_has_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
|
119 |
+
|
120 |
+
|
121 |
+
def compute_metrics(pred):
    """Return the character error rate, in percent, as {"cer": value}.

    `pred.predictions` and `pred.label_ids` are decoded with the module-level
    `tokenizer` and scored with the module-level `metric`.
    Note that `pred.label_ids` is modified in place.
    """
    predicted_ids = pred.predictions
    reference_ids = pred.label_ids

    # -100 marks ignored label positions; swap it for the pad token id so the
    # sequences can be decoded.
    reference_ids[reference_ids == -100] = tokenizer.pad_token_id

    # Special tokens are skipped so they do not count towards the metric.
    hypotheses = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    transcripts = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    return {"cer": 100 * metric.compute(predictions=hypotheses, references=transcripts)}
|
135 |
+
|
136 |
+
|
137 |
+
# model_dir, ./repo ์ด๊ธฐํ
|
138 |
+
# Start from empty scratch directories: the training output dir first, then
# the local clone target for the Hub repository.
for scratch_dir in (model_dir, './repo'):
    if os.path.exists(scratch_dir):
        shutil.rmtree(scratch_dir)
    os.makedirs(scratch_dir)
|
145 |
+
|
146 |
+
# ํ์ธํ๋์ ์งํํ๊ณ ์ ํ๋ ๋ชจ๋ธ์ processor, tokenizer, feature extractor, model ๋ก๋
|
147 |
+
# Load the processor, tokenizer, feature extractor and model of the checkpoint
# that is going to be fine-tuned.
processor = WhisperProcessor.from_pretrained(
    model_name, language="Korean", task="transcribe"
)
tokenizer = WhisperTokenizer.from_pretrained(
    model_name, language="Korean", task="transcribe"
)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# Batch collation and the evaluation metric used by compute_metrics.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
metric = evaluate.load('cer')

# Clear the forced decoder ids and the suppressed-token list on the config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
|
156 |
+
|
157 |
+
|
158 |
+
# Hub๋ก๋ถํฐ "๋ชจ๋ ์ ์ฒ๋ฆฌ๊ฐ ์๋ฃ๋" ๋ฐ์ดํฐ์
์ ๋ก๋(์ด๊ฒ ์ง์ง ์ค๋๊ฑธ๋ ค์.)
|
159 |
+
# Load the already fully preprocessed dataset from the Hub (the original
# author notes that this step takes a very long time).
preprocessed_dataset = load_dataset(dataset_name, cache_dir=CACHE_DIR)

# Smoke-test mode: keep only the first 30% of the validation split.
# The size must come from the "valid" split itself; len() of the whole
# dataset dictionary is the number of splits, not the number of rows.
if is_test:
    valid_split = preprocessed_dataset["valid"]
    sample_count = math.ceil(len(valid_split) * 0.3)
    preprocessed_dataset["valid"] = valid_split.select(range(sample_count))
|
164 |
+
|
165 |
+
# training_args ๊ฐ์ฒด๋ฅผ JSON ํ์์ผ๋ก ๋ณํ
|
166 |
+
# Plain-dict view of the training arguments, logged to MLflow later on.
training_args_dict = training_args.to_dict()

# Backing store for the MLflow tracking data.
mlflow.set_tracking_uri("sqlite:////content/drive/MyDrive/STT_test/mlflow.db")

# The experiment is named after the model; reuse it when it already exists,
# otherwise create it.
experiment_name = model_name
existing_experiment = mlflow.get_experiment_by_name(experiment_name)
experiment_id = (
    mlflow.create_experiment(experiment_name)
    if existing_experiment is None
    else existing_experiment.experiment_id
)
|
179 |
+
|
180 |
+
|
181 |
+
# Default model version; overwritten below with the version MLflow assigns
# when the model is registered.
model_version = 1

# Train, evaluate and register the model inside a single MLflow run.
with mlflow.start_run(experiment_id=experiment_id, description=model_description) as run:
    # Record every training argument as a run parameter.
    for arg_name, arg_value in training_args_dict.items():
        mlflow.log_param(arg_name, arg_value)

    # Record which dataset was used.
    mlflow.set_tag("Dataset", dataset_name)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=preprocessed_dataset["train"],
        eval_dataset=preprocessed_dataset["valid"],  # or "test"
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()
    trainer.save_model(model_dir)  # persist the model after training

    # Record the final evaluation metrics.
    metrics = trainer.evaluate()
    for metric_name, metric_value in metrics.items():
        mlflow.log_metric(metric_name, metric_value)

    # Register the model in the MLflow model registry.
    # NOTE(review): nothing in this block logs artifacts under `model_dir`,
    # so confirm that this runs:/ URI actually resolves.
    model_uri = f"runs:/{run.info.run_id}/{model_dir}"

    # The registry name replaces '/' with '-'; the version assigned here is
    # used later to build the Hugging Face repository name.
    model_details = mlflow.register_model(model_uri=model_uri, name=model_name.replace('/', '-'))

    # Attach the description to the registered model version.
    client = MlflowClient()
    client.update_model_version(name=model_details.name, version=model_details.version, description=model_description)
    model_version = model_details.version
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
## ํ๊น
ํ์ด์ค ๋ก๊ทธ์ธ
|
224 |
+
# Log in to the Hugging Face Hub, prompting for a new token until the CLI
# accepts one. Typing "exit" at the prompt cancels the upload.
while token != "exit":
    try:
        # NOTE(review): the token is passed on the command line; confirm this
        # is acceptable on shared machines.
        login_ok = subprocess.run(["huggingface-cli", "login", "--token", token]).returncode == 0
    except (OSError, ValueError):
        # The CLI could not be started, or the token is not a valid argument.
        login_ok = False
    if login_ok:
        break
    token = input("Please enter your Hugging Face API token: ")
else:
    # Reached only when the user typed "exit": stop here rather than carrying
    # the literal string "exit" into the upload steps as a token.
    raise SystemExit("Hugging Face login cancelled; nothing was uploaded.")

os.environ["HUGGINGFACE_HUB_TOKEN"] = token
|
239 |
+
|
240 |
+
# ๋ฆฌํฌ์งํ ๋ฆฌ ์ด๋ฆ ์ค์
|
241 |
+
# Hub repository id: owner prefix, model name with '/' replaced by '-', and
# the registered model version.
repo_name = f"maxseats/{model_name.replace('/', '-')}-{model_version}"

# Create the repository; an already existing one is not an error.
create_repo(repo_name, exist_ok=True, token=token)

# Clone the repository into the local working directory.
repo = Repository(local_dir='./repo', clone_from=repo_name, use_auth_token=token)
|
250 |
+
|
251 |
+
|
252 |
+
# model_dir ํ์ํ ํ์ผ ๋ณต์ฌ
|
253 |
+
# Copy the needed files from model_dir into the cloned repository.
# Only files less than `max_depth` levels below model_dir are copied.
max_depth = 1

for root, dirs, files in os.walk(model_dir):
    # Depth is measured relative to model_dir, so a trailing separator or a
    # different spelling of the path cannot shift the count.
    relative_root = os.path.relpath(root, model_dir)
    depth = 0 if relative_root == os.curdir else relative_root.count(os.sep) + 1

    # Do not descend into directories whose files would be too deep anyway
    # (for example checkpoint folders).
    if depth + 1 >= max_depth:
        dirs[:] = []

    if depth < max_depth:
        for file in files:
            source_file = os.path.join(root, file)
            shutil.copy(source_file, './repo')

# Save the tokenizer files next to the model files.
tokenizer.save_pretrained('./repo')
|
267 |
+
|
268 |
+
|
269 |
+
# Model card: YAML metadata header followed by the model name and description.
readme = f"""
---
language: ko
tags:
- whisper
- speech-recognition
datasets:
- {dataset_name}
metrics:
- cer
---
# Model Name : {model_name}
# Description
{model_description}
"""

# Write the model card into the cloned repository. The encoding is explicit
# because the description contains Korean text and the platform default
# encoding is not guaranteed to be UTF-8.
with open("./repo/README.md", "w", encoding="utf-8") as f:
    f.write(readme)
|
289 |
+
|
290 |
+
# ํ์ผ ์ปค๋ฐ ํธ์
|
291 |
+
# Commit and push everything in the local clone.
repo.push_to_hub(commit_message="Initial commit")

# Remove the scratch directories together with their contents.
for scratch_dir in (model_dir, './repo'):
    shutil.rmtree(scratch_dir)
|
296 |
+
```
|
297 |
+
|
298 |
|