{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],\n", " num_rows: 22263\n", " })\n", " test: Dataset({\n", " features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],\n", " num_rows: 457\n", " })\n", "})\n" ] } ], "source": [ "from datasets import load_dataset, DatasetDict, DownloadMode\n", "\n", "zeroth_korean = DatasetDict()\n", "\n", "zeroth_korean[\"train\"] = load_dataset(\"kresnik/zeroth_korean\", split=\"train\", download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)\n", "zeroth_korean[\"test\"] = load_dataset(\"kresnik/zeroth_korean\", split=\"test\", download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)\n", "\n", "print(zeroth_korean)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['audio', 'text'],\n", " num_rows: 22263\n", " })\n", " test: Dataset({\n", " features: ['audio', 'text'],\n", " num_rows: 457\n", " })\n", "})\n" ] } ], "source": [ "zeroth_korean = zeroth_korean.remove_columns([\"file\", \"speaker_id\", \"chapter_id\", \"id\"])\n", "\n", "print(zeroth_korean)\n", "# audio랑 text 남기고\n", "# file, speaker_id, chapter_id, id 지우기" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from transformers import WhisperFeatureExtractor\n", "\n", "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "from transformers import WhisperTokenizer\n", "\n", "tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-small\", language=\"Korean\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "from transformers import WhisperProcessor\n", "\n", "processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"Korean\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'audio': {'path': '/home/kwon/.cache/huggingface/datasets/downloads/extracted/38285576d48be93072ee02a2cfd2b30b3c4b32eb7022f4c421e5879a9bda428f/train_data_01/003/124/124_003_0009.flac', 'array': array([-3.05175781e-05, -6.10351562e-05, -3.05175781e-05, ...,\n", " -1.22070312e-04, 1.83105469e-04, -1.22070312e-04]), 'sampling_rate': 16000}, 'text': '이 과정에서 아파트 명의 문제 등으로 말다툼이 벌어졌고 감정이 격해진 송씨가 공구를 침대에 펼쳐놓고 흉기를 아내 가슴에 들이대며 죽여버린다고 협박했다'}\n" ] } ], "source": [ "print(zeroth_korean[\"train\"][0])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "\n", "from datasets import Audio\n", "\n", "zeroth_korean = zeroth_korean.cast_column(\"audio\", Audio(sampling_rate=16000))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'audio': {'path': '/home/kwon/.cache/huggingface/datasets/downloads/extracted/38285576d48be93072ee02a2cfd2b30b3c4b32eb7022f4c421e5879a9bda428f/train_data_01/003/124/124_003_0009.flac', 'array': array([-3.05175781e-05, -6.10351562e-05, -3.05175781e-05, ...,\n", " -1.22070312e-04, 1.83105469e-04, -1.22070312e-04]), 'sampling_rate': 16000}, 'text': '이 과정에서 아파트 명의 문제 등으로 말다툼이 벌어졌고 감정이 격해진 송씨가 공구를 침대에 펼쳐놓고 흉기를 아내 가슴에 들이대며 죽여버린다고 협박했다'}\n" ] } ], "source": [ "print(zeroth_korean[\"train\"][0])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n", "\n", "do_lower_case = False\n", "do_remove_punctuation = False\n", "\n", "normalizer = BasicTextNormalizer()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " # load and (possibly) resample audio data to 16kHz\n", " audio = batch[\"audio\"]\n", "\n", " # compute log-Mel input features from input audio array \n", " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " # compute input length of audio sample in seconds\n", " batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", " \n", " # optional pre-processing steps\n", " transcription = batch[\"text\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " \n", " # encode target text to label ids\n", " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c3d689f889f4a71a6ffe0350727d7d7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=2): 0%| | 0/22263 [00:00 Dict[str, torch.Tensor]:\n", " # split inputs and labels since they have to be of different lengths and need different padding methods\n", " # first treat the audio inputs by simply returning torch tensors\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", "\n", " # get the tokenized label sequences\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " # pad the labels to max length\n", " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", "\n", " # replace padding with -100 to ignore loss correctly\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " # if bos token is appended in previous tokenization step,\n", " # cut bos token here as it's append later anyways\n", " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "wer_score = evaluate.load(\"wer\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from evaluate import load\n", "cer_score = evaluate.load(\"cer\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이게 원본 WER 평가 코드\n", "\n", "\n", "# evaluate with the 'normalised' WER\n", "do_normalize_eval = True\n", "\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # we do not want to group tokens when computing the metrics\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# evaluate with the 'normalised' WER\n", "do_normalize_eval = True\n", "\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # we do not want to group tokens when computing the metrics\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " wer = 100 * wer_score.compute(predictions=pred_str, references=label_str)\n", " cer = 100 * cer_score.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer, \"cer\": cer}" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from transformers import WhisperForConditionalGeneration\n", "\n", "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-small\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "model.config.forced_decoder_ids = None\n", "model.config.suppress_tokens = []\n", "model.config.use_cache = False" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainingArguments\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./\",\n", " per_device_train_batch_size=16,\n", " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n", " learning_rate=1e-5,\n", " warmup_steps=500,\n", " max_steps=4000,\n", " gradient_checkpointing=True,\n", " fp16=True,\n", " evaluation_strategy=\"steps\",\n", " per_device_eval_batch_size=8,\n", " predict_with_generate=True,\n", " generation_max_length=225,\n", " save_steps=1000,\n", " eval_steps=1000,\n", " logging_steps=25,\n", " report_to=[\"tensorboard\"],\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"cer\",\n", " greater_is_better=False,\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=zeroth_korean[\"train\"],\n", " eval_dataset=zeroth_korean[\"test\"],\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " tokenizer=processor.feature_extractor,\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "processor.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e65ce3bc72904c3999b034bfeb8a12b6", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/4000 [00:00