maxseats's picture
Update README.md
2e9613b verified
metadata
language: ko
tags:
  - whisper
  - speech-recognition
datasets:
  - maxseats/aihub-464-preprocessed-680GB-set-1
metrics:
  - cer

Model Name : maxseats/SungBeom-whisper-small-ko-set0

Description

  • νŒŒμΈνŠœλ‹ 데이터셋 : maxseats/aihub-464-preprocessed-680GB-set-1

μ„€λͺ…

from datasets import load_dataset
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
from transformers import WhisperTokenizer, WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import mlflow
from mlflow.tracking.client import MlflowClient
import subprocess
from huggingface_hub import create_repo, Repository
import os
import shutil
import math # μž„μ‹œ ν…ŒμŠ€νŠΈμš©
model_dir = "./tmpp" # μˆ˜μ • X


#########################################################################################################################################
################################################### μ‚¬μš©μž μ„€μ • λ³€μˆ˜ #####################################################################
#########################################################################################################################################

model_description = """
- νŒŒμΈνŠœλ‹ 데이터셋 : maxseats/aihub-464-preprocessed-680GB-set-1

# μ„€λͺ…
- AI hub의 μ£Όμš” μ˜μ—­λ³„ 회의 μŒμ„± 데이터셋을 ν•™μŠ΅ μ€‘μ΄μ—μš”.
- 680GB 쀑 첫번째 데이터(10GB)λ₯Ό νŒŒμΈνŠœλ‹ν•œ λͺ¨λΈμ„ λΆˆλŸ¬μ™€μ„œ, λ‘λ²ˆμ§Έ 데이터λ₯Ό ν•™μŠ΅ν•œ λͺ¨λΈμž…λ‹ˆλ‹€.
- 링크 : https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-0, https://huggingface.co/datasets/maxseats/aihub-464-preprocessed-680GB-set-1
"""

# model_name = "openai/whisper-base"
model_name = "maxseats/SungBeom-whisper-small-ko-set0" # λŒ€μ•ˆ : "SungBeom/whisper-small-ko"
# dataset_name = "maxseats/aihub-464-preprocessed-680GB-set-1"  # 뢈러올 데이터셋(ν—ˆκΉ…νŽ˜μ΄μŠ€ κΈ°μ€€)
dataset_name = "maxseats/aihub-464-preprocessed-680GB-set-1"  # 뢈러올 데이터셋(ν—ˆκΉ…νŽ˜μ΄μŠ€ κΈ°μ€€)

CACHE_DIR = '/mnt/a/maxseats/.finetuning_cache'  # μΊμ‹œ 디렉토리 지정
is_test = False  # True: μ†ŒλŸ‰μ˜ μƒ˜ν”Œ λ°μ΄ν„°λ‘œ ν…ŒμŠ€νŠΈ, False: μ‹€μ œ νŒŒμΈνŠœλ‹

token = "hf_" # ν—ˆκΉ…νŽ˜μ΄μŠ€ 토큰 μž…λ ₯

training_args = Seq2SeqTrainingArguments(
    output_dir=model_dir,  # μ›ν•˜λŠ” 리포지토리 이름을 μž…λ ₯ν•œλ‹€.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,  # 배치 크기가 2λ°° κ°μ†Œν•  λ•Œλ§ˆλ‹€ 2λ°°μ”© 증가
    learning_rate=1e-5,
    warmup_steps=500,
    # max_steps=2,  # epoch λŒ€μ‹  μ„€μ •
    num_train_epochs=1,     # epoch 수 μ„€μ • / max_steps와 이것 쀑 ν•˜λ‚˜λ§Œ μ„€μ •
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="cer",  # ν•œκ΅­μ–΄μ˜ 경우 'wer'λ³΄λ‹€λŠ” 'cer'이 더 적합할 것
    greater_is_better=False,
    push_to_hub=True,
    save_total_limit=5,           # μ΅œλŒ€ μ €μž₯ν•  λͺ¨λΈ 수 지정
)

#########################################################################################################################################
################################################### μ‚¬μš©μž μ„€μ • λ³€μˆ˜ #####################################################################
#########################################################################################################################################


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # 인풋 데이터와 라벨 λ°μ΄ν„°μ˜ 길이가 λ‹€λ₯΄λ©°, λ”°λΌμ„œ μ„œλ‘œ λ‹€λ₯Έ νŒ¨λ”© 방법이 μ μš©λ˜μ–΄μ•Ό ν•œλ‹€. κ·ΈλŸ¬λ―€λ‘œ 두 데이터λ₯Ό 뢄리해야 ν•œλ‹€.
        # λ¨Όμ € μ˜€λ””μ˜€ 인풋 데이터λ₯Ό κ°„λ‹¨νžˆ ν† μΉ˜ ν…μ„œλ‘œ λ°˜ν™˜ν•˜λŠ” μž‘μ—…μ„ μˆ˜ν–‰ν•œλ‹€.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Tokenize된 λ ˆμ΄λΈ” μ‹œν€€μŠ€λ₯Ό κ°€μ Έμ˜¨λ‹€.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # λ ˆμ΄λΈ” μ‹œν€€μŠ€μ— λŒ€ν•΄ μ΅œλŒ€ 길이만큼 νŒ¨λ”© μž‘μ—…μ„ μ‹€μ‹œν•œλ‹€.
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # νŒ¨λ”© 토큰을 -100으둜 μΉ˜ν™˜ν•˜μ—¬ loss 계산 κ³Όμ •μ—μ„œ λ¬΄μ‹œλ˜λ„λ‘ ν•œλ‹€.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # 이전 ν† ν¬λ‚˜μ΄μ¦ˆ κ³Όμ •μ—μ„œ bos 토큰이 μΆ”κ°€λ˜μ—ˆλ‹€λ©΄ bos 토큰을 μž˜λΌλ‚Έλ‹€.
        # ν•΄λ‹Ή 토큰은 이후 μ–Έμ œλ“  μΆ”κ°€ν•  수 μžˆλ‹€.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # pad_token을 -100으둜 μΉ˜ν™˜
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # metrics 계산 μ‹œ special token듀을 λΉΌκ³  κ³„μ‚°ν•˜λ„λ‘ μ„€μ •
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    cer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}


# model_dir, ./repo μ΄ˆκΈ°ν™”
if os.path.exists(model_dir):
    shutil.rmtree(model_dir)
os.makedirs(model_dir)

if os.path.exists('./repo'):
    shutil.rmtree('./repo')
os.makedirs('./repo')

# νŒŒμΈνŠœλ‹μ„ μ§„ν–‰ν•˜κ³ μž ν•˜λŠ” λͺ¨λΈμ˜ processor, tokenizer, feature extractor, model λ‘œλ“œ
processor = WhisperProcessor.from_pretrained(model_name, language="Korean", task="transcribe")
tokenizer = WhisperTokenizer.from_pretrained(model_name, language="Korean", task="transcribe")
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
metric = evaluate.load('cer')
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []


# Hubλ‘œλΆ€ν„° "λͺ¨λ“  μ „μ²˜λ¦¬κ°€ μ™„λ£Œλœ" 데이터셋을 λ‘œλ“œ(이게 μ§„μ§œ μ˜€λž˜κ±Έλ €μš”.)
preprocessed_dataset = load_dataset(dataset_name, cache_dir=CACHE_DIR)

# 30%κΉŒμ§€μ˜ valid 데이터셋 선택(μ½”λ“œ μž‘λ™ ν…ŒμŠ€νŠΈλ₯Ό μœ„ν•¨)
if is_test:
    preprocessed_dataset["valid"] = preprocessed_dataset["valid"].select(range(math.ceil(len(preprocessed_dataset) * 0.3)))

# training_args 객체λ₯Ό JSON ν˜•μ‹μœΌλ‘œ λ³€ν™˜
training_args_dict = training_args.to_dict()

# MLflow UI 관리 폴더 지정
mlflow.set_tracking_uri("sqlite:////content/drive/MyDrive/STT_test/mlflow.db")

# MLflow μ‹€ν—˜ 이름을 λͺ¨λΈ μ΄λ¦„μœΌλ‘œ μ„€μ •
experiment_name = model_name
existing_experiment = mlflow.get_experiment_by_name(experiment_name)

if existing_experiment is not None:
    experiment_id = existing_experiment.experiment_id
else:
    experiment_id = mlflow.create_experiment(experiment_name)


model_version = 1  # λ‘œκΉ… ν•˜λ €λŠ” λͺ¨λΈ 버전(이미 μ‘΄μž¬ν•˜λ©΄, μžλ™ ν• λ‹Ή)

# MLflow λ‘œκΉ…
with mlflow.start_run(experiment_id=experiment_id, description=model_description):
    # training_args λ‘œκΉ…
    for key, value in training_args_dict.items():
        mlflow.log_param(key, value)


    mlflow.set_tag("Dataset", dataset_name) # 데이터셋 λ‘œκΉ…

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=preprocessed_dataset["train"],
        eval_dataset=preprocessed_dataset["valid"],  # or "test"
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()
    trainer.save_model(model_dir)  # ν•™μŠ΅ ν›„ λͺ¨λΈ μ €μž₯

    # Metric λ‘œκΉ…
    metrics = trainer.evaluate()
    for metric_name, metric_value in metrics.items():
        mlflow.log_metric(metric_name, metric_value)

    # MLflow λͺ¨λΈ λ ˆμ§€μŠ€ν„°
    model_uri = "runs:/{run_id}/{artifact_path}".format(run_id=mlflow.active_run().info.run_id, artifact_path=model_dir)

    # 이 κ°’ μ΄μš©ν•΄μ„œ ν—ˆκΉ…νŽ˜μ΄μŠ€ λͺ¨λΈ 이름 μ„€μ • μ˜ˆμ •
    model_details = mlflow.register_model(model_uri=model_uri, name=model_name.replace('/', '-'))   # λͺ¨λΈ 이름에 '/'λ₯Ό '-'둜 λŒ€μ²΄

    # λͺ¨λΈ Description
    client = MlflowClient()
    client.update_model_version(name=model_details.name, version=model_details.version, description=model_description)
    model_version = model_details.version   # 버전 정보 ν—ˆκΉ…νŽ˜μ΄μŠ€ μ—…λ‘œλ“œ μ‹œ μ‚¬μš©



## ν—ˆκΉ…νŽ˜μ΄μŠ€ 둜그인
while True:

    if token =="exit":
        break

    try:
        result = subprocess.run(["huggingface-cli", "login", "--token", token])
        if result.returncode != 0:
            raise Exception()
        break
    except Exception as e:
        token = input("Please enter your Hugging Face API token: ")


os.environ["HUGGINGFACE_HUB_TOKEN"] = token

# 리포지토리 이름 μ„€μ •
repo_name = "maxseats/" + model_name.replace('/', '-') + '-' + str(model_version)  # ν—ˆκΉ…νŽ˜μ΄μŠ€ λ ˆν¬μ§€ν† λ¦¬ 이름 μ„€μ •

# 리포지토리 생성
create_repo(repo_name, exist_ok=True, token=token)



# 리포지토리 클둠
repo = Repository(local_dir='./repo', clone_from=f"{repo_name}", use_auth_token=token)


# model_dir ν•„μš”ν•œ 파일 볡사
max_depth = 1  # μˆœνšŒν•  μ΅œλŒ€ 깊이

for root, dirs, files in os.walk(model_dir):
    depth = root.count(os.sep) - model_dir.count(os.sep)
    if depth < max_depth:
        for file in files:
            # 파일 경둜 생성
            source_file = os.path.join(root, file)
            # λŒ€μƒ 폴더에 볡사
            shutil.copy(source_file, './repo')


# ν† ν¬λ‚˜μ΄μ € λ‹€μš΄λ‘œλ“œ 및 둜컬 디렉토리에 μ €μž₯
tokenizer.save_pretrained('./repo')


readme = f"""
---
language: ko
tags:
- whisper
- speech-recognition
datasets:
- {dataset_name}
metrics:
- cer
---
# Model Name : {model_name}
# Description
{model_description}
"""


# λͺ¨λΈ μΉ΄λ“œ 및 기타 메타데이터 파일 μž‘μ„±
with open("./repo/README.md", "w") as f:
    f.write(readme)

# 파일 컀밋 ν‘Έμ‹œ
repo.push_to_hub(commit_message="Initial commit")

# 폴더와 ν•˜μœ„ λ‚΄μš© μ‚­μ œ
shutil.rmtree(model_dir)
shutil.rmtree('./repo')