{
"cells": [
{
"cell_type": "markdown",
"id": "52580f1b",
"metadata": {},
"source": [
"# Example: Training Esperanto ASR model using Mozilla Common Voice Dataset\n",
"\n",
"\n",
"Training an ASR model for a new language can be challenging, especially for low-resource languages (see [example](https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/examples/kinyarwanda_asr.rst) for Kinyarwanda CommonVoice ASR model).\n",
"\n",
"This example describes all basic steps required to build ASR model for Esperanto:\n",
"\n",
"* Data preparation\n",
"* Tokenization\n",
"* Training hyper-parameters\n",
"* Training from scratch\n",
"* Fine-tuning from pretrained models on other languages (English, Spanish, Italian).\n",
"* Fine-tuning from pretrained English SSL ([Self-supervised learning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/ssl/intro.html?highlight=self%20supervised)) model\n",
"* Model evaluation"
]
},
{
"cell_type": "markdown",
"id": "ff84f5fc",
"metadata": {},
"source": [
"## Data Preparation\n",
"\n",
"Mozilla Common Voice provides 1400 hours of validated Esperanto speech (see [here](https://arxiv.org/abs/1912.0667)). However, the final training dataset consists only of 250 hours because the train, test, and development sets are bucketed such that any given speaker may appear in only one. This ensures that contributors seen at train time are not seen at test time, which would skew results. Additionally, repetitions of text sentences are removed from the train, test, and development sets of the corpus”. \n",
"\n",
"### Downloading the Data\n",
"\n",
"You can use the NeMo script to download MCV dataset from Hugging Face and get NeMo data manifests for Esperanto:\n",
"\n",
"# Setup\n",
"After installation of huggingface datasets (`pip install datasets`), CommonVoice requires authentication.\n",
"\n",
"Website steps:\n",
"- Visit https://huggingface.co/settings/profile\n",
"- Visit \"Access Tokens\" on list of items.\n",
"- Create new token - provide a name for the token and \"read\" access is sufficient.\n",
" - PRESERVE THAT TOKEN API KEY. You can copy that key for next step.\n",
"- Visit the HuggingFace Dataset page for Mozilla Common Voice\n",
" - There should be a section that asks you for your approval.\n",
" - Make sure you are logged in and then read that agreement.\n",
" - If and only if you agree to the text, then accept the terms.\n",
"\n",
"Code steps:\n",
"- Now on your machine, run `huggingface-cli login`\n",
"- Paste your preserved HF TOKEN API KEY (from above).\n",
"\n",
"Once the above is complete, to download the data:\n",
"\n",
"```bash\n",
"python ${NEMO_ROOT}/scripts/speech_recognition/convert_hf_dataset_to_nemo.py \\\n",
" output_dir=${OUTPUT_DIR} \\\n",
" path=\"mozilla-foundation/common_voice_11_0\" \\\n",
" name=\"eo\" \\\n",
" ensure_ascii=False \\\n",
" use_auth_token=True\n",
"```\n",
"You will get the next data structure:\n",
"\n",
"```bash\n",
" .\n",
" └── mozilla-foundation\n",
" └── common_voice_11_0\n",
" └── eo\n",
" ├── test\n",
" ├── train\n",
" └── validation\n",
"```\n",
"\n",
"### Dataset Preprocessing\n",
"\n",
"Next, we must clear the text data from punctuation and various “garbage” characters. In addition to deleting a standard set of elements (as in Kinyarwanda), you can compute the frequency of characters in the train set and add the rarest (occurring less than ten times) to the list for deletion. \n",
"\n",
"\n",
"```python\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"dev_manifest = f\"{YOUR_DATA_ROOT}/validation/validation_mozilla-foundation_common_voice_11_0_manifest.json\"\n",
"test_manifest = f\"{YOUR_DATA_ROOT}/test/test_mozilla-foundation_common_voice_11_0_manifest.json\"\n",
"train_manifest = f\"{YOUR_DATA_ROOT}/train/train_mozilla-foundation_common_voice_11_0_manifest.json\"\n",
"\n",
"def compute_char_counts(manifest):\n",
" char_counts = {}\n",
" with open(manifest, 'r') as fn_in:\n",
" for line in tqdm(fn_in, desc=\"Compute counts..\"):\n",
" line = line.replace(\"\\n\", \"\")\n",
" data = json.loads(line)\n",
" text = data[\"text\"]\n",
" for word in text.split():\n",
" for char in word:\n",
" if char not in char_counts:\n",
" char_counts[char] = 1\n",
" else:\n",
" char_counts[char] += 1\n",
" return char_counts\n",
"\n",
"char_counts = compute_char_counts(train_manifest)\n",
"\n",
"threshold = 10\n",
"trash_char_list = []\n",
"\n",
"for char in char_counts:\n",
" if char_counts[char] <= threshold:\n",
" trash_char_list.append(char)\n",
"```\n",
"\n",
"Let's check:\n",
"\n",
"```python\n",
"print(trash_char_list)\n",
"['é', 'ǔ', 'á', '¨', 'Ŭ', 'fi', '=', 'y', '`', 'q', 'ü', '♫', '‑', 'x', '¸', 'ʼ', '‹', '›', 'ñ']\n",
"```\n",
"\n",
"Next we will check the data for anomalies in audio file (for example, audio file with noise only). For this end, we check character rate (number of chars per second). For example, If the char rate is too high (more than 15 chars per second), then something is wrong with the audio file. It is better to filter such data from the training dataset in advance. Other problematic files can be filtered out after receiving the first trained model. We will consider this method at the end of our example.\n",
"\n",
"```python\n",
"import re\n",
"import json\n",
"from tqdm import tqdm\n",
"\n",
"def clear_data_set(manifest, char_rate_threshold=None):\n",
"\n",
" chars_to_ignore_regex = \"[\\.\\,\\?\\:\\-!;()«»…\\]\\[/\\*–‽+&_\\\\½√>€™$•¼}{~—=“\\\"”″‟„]\"\n",
" addition_ignore_regex = f\"[{''.join(trash_char_list)}]\"\n",
"\n",
" manifest_clean = manifest + '.clean'\n",
" war_count = 0\n",
" with open(manifest, 'r') as fn_in, \\\n",
" open(manifest_clean, 'w', encoding='utf-8') as fn_out:\n",
" for line in tqdm(fn_in, desc=\"Cleaning manifest data\"):\n",
" line = line.replace(\"\\n\", \"\")\n",
" data = json.loads(line)\n",
" text = data[\"text\"]\n",
" if char_rate_threshold and len(text.replace(' ', '')) / float(data['duration']) > char_rate_threshold:\n",
" print(f\"[WARNING]: {data['audio_filepath']} has char rate > 15 per sec: {len(text)} chars, {data['duration']} duration\")\n",
" war_count += 1\n",
" continue\n",
" text = re.sub(chars_to_ignore_regex, \"\", text)\n",
" text = re.sub(addition_ignore_regex, \"\", text)\n",
" data[\"text\"] = text.lower()\n",
" data = json.dumps(data, ensure_ascii=False)\n",
" fn_out.write(f\"{data}\\n\")\n",
" print(f\"[INFO]: {war_count} files were removed from manifest\")\n",
"\n",
"clear_data_set(dev_manifest)\n",
"clear_data_set(test_manifest)\n",
"clear_data_set(train_manifest, char_rate_threshold=15)\n",
"```\n",
"\n",
"### Creating the Tarred Training Dataset\n",
"\n",
"The tarred dataset allows storing the dataset as large *.tar files instead of small separate audio files. It may speed up the training and minimizes the load when data is moved from storage to GPU nodes.\n",
"\n",
"The NeMo toolkit provides a [script]( https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py) to get tarred dataset.\n",
"\n",
"```bash\n",
"\n",
"TRAIN_MANIFEST=${YOUR_DATA_ROOT}/train/train_mozilla-foundation_common_voice_11_0_manifest.json.clean\n",
"\n",
"python ${NEMO_ROOT}/scripts/speech_recognition/convert_to_tarred_audio_dataset.py \\\n",
" --manifest_path=${TRAIN_MANIFEST} \\\n",
" --target_dir=${YOUR_DATA_ROOT}/train_tarred_1bk \\\n",
" --num_shards=1024 \\\n",
" --max_duration=15.0 \\\n",
" --min_duration=1.0 \\\n",
" --shuffle \\\n",
" --shuffle_seed=1 \\\n",
" --sort_in_shards \\\n",
" --workers=-1\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "c15bc180",
"metadata": {},
"source": [
"## Text Tokenization\n",
"\n",
"We use the standard [Byte-pair](https://en.wikipedia.org/wiki/Byte_pair_encoding) encoding algorithm with 128, 512, and 1024 vocabulary size. We found that 128 works best for relatively small Esperanto dataset (~250 hours). For larger datasets, one can get better results with larger vocabulary size (512…1024 BPE tokens).\n",
"\n",
"```\n",
"VOCAB_SIZE=128\n",
"\n",
"python ${NEMO_ROOT}/scripts/tokenizers/process_asr_text_tokenizer.py \\\n",
" --manifest=${TRAIN_MANIFEST} \\\n",
" --vocab_size=${VOCAB_SIZE} \\\n",
" --data_root=${YOUR_DATA_ROOT}/esperanto/tokenizers \\\n",
" --tokenizer=\"spe\" \\\n",
" --spe_type=bpe \n",
"```"
]
},
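{
"cell_type": "markdown",
"id": "a7c3f2e9",
"metadata": {},
"source": [
"To sanity-check the resulting tokenizer, you can load the generated SentencePiece model and encode a sample sentence. This is a minimal sketch; the tokenizer directory name (`tokenizer_spe_bpe_v128`) follows the naming convention of the script above and matches the path used in the training section below:\n",
"\n",
"```python\n",
"import sentencepiece as spm\n",
"\n",
"tokenizer_model = f\"{YOUR_DATA_ROOT}/esperanto/tokenizers/tokenizer_spe_bpe_v128/tokenizer.model\"\n",
"\n",
"sp = spm.SentencePieceProcessor(model_file=tokenizer_model)\n",
"print(sp.encode(\"bonan matenon al ĉiuj\", out_type=str))  # BPE pieces\n",
"print(sp.encode(\"bonan matenon al ĉiuj\"))                # token ids\n",
"```"
]
},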
{
"cell_type": "markdown",
"id": "106ec1dc",
"metadata": {},
"source": [
"## Training hyper-parameters\n",
"\n",
"The training parameters are defined in the [config file](https://github.com/NVIDIA/NeMo/blob/main/examples/asr/conf/conformer/conformer_ctc_bpe.yaml) (general description of the [ASR configuration file](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/configs.html)). As an encoder, the [Conformer model](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-ctc) is used here, the training parameters for which are already well configured based on the training English models. However, the set of optimal parameters may differ for a new language. In this section, we will look at the set of simple parameters that can improve recognition quality for a new language without digging into the details of the Conformer model too much.\n",
"\n",
"### Select Training Batch Size\n",
"\n",
"We trained model on server with 16 V100 GPUs with 32 GB. We use a local batch size = 32 per GPU V100), so global batch size is 32x16=512. In general, we observed, that global batch between 512 and 2048 works well for Conformer-CTC-Large model. One can use the [accumulate_grad_batches](https://github.com/NVIDIA/NeMo/blob/main/examples/asr/conf/conformer/conformer_ctc_bpe.yaml#L173) parameter to increase the size of the global batch, which is equal to *local_batch * num_gpu * accumulate_grad_batches*.\n",
"\n",
"### Selecting Optimizer and Learning Rate Scheduler\n",
"\n",
"The model was trained with AdamW optimizer and CosineAnealing Learning Rate (LR) scheduler. We use Learning Rate warmup when LR goes from 0 to maximum LR to stabilize initial phase of training. The number of warmup steps determines how quickly the scheduler will reach the peak learning rate during model training. The recommended number of steps is approximately 10-20% of total training duration. We used 8,000-10,000 warmup steps.\n",
"\n",
"Now we can plot our learning rate for CosineAnnealing schedule:\n",
"\n",
"```python\n",
"import nemo\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# params:\n",
"train_files_num = 144000 # number of training audio_files\n",
"global_batch_size = 1024 # local_batch * gpu_num * accum_gradient\n",
"num_epoch = 300\n",
"warmup_steps = 10000\n",
"config_learning_rate = 1e-3\n",
"\n",
"steps_num = int(train_files_num / global_batch_size * num_epoch)\n",
"print(f\"steps number is: {steps_num}\")\n",
"\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=config_learning_rate)\n",
"scheduler = nemo.core.optim.lr_scheduler.CosineAnnealing(optimizer,\n",
" max_steps=steps_num,\n",
" warmup_steps=warmup_steps,\n",
" min_lr=1e-6)\n",
"lrs = []\n",
"\n",
"for i in range(steps_num):\n",
" optimizer.step()\n",
" lr = optimizer.param_groups[0][\"lr\"]\n",
" lrs.append(lr)\n",
" scheduler.step()\n",
"\n",
"plt.plot(lrs)\n",
"```\n",
"\n",
"<img src=\"./images/CosineAnnealing_scheduler.png\" alt=\"NeMo CosineAnnealing scheduler\" style=\"width: 400px;\"/>\n",
"\n",
"### Numerical Precision\n",
"\n",
"By default, it is recommended to use half-precision float (FP16 for V100 and BF16 for A100 GPU) to speed up the training process. However, training with half-precision may affect the convergence of the model, for example training loss can explode. In this case, we recommend to decrease LR or switch to float32. "
]
},
{
"cell_type": "markdown",
"id": "124aefb8",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We use three main scenarios to train Espearnto ASR model:\n",
"\n",
"* Training from scratch.\n",
"* Fine-tuning from ASR models for other languages (English, Spanish, Italian).\n",
"* Fine-tuning from an English SSL ([Self-supervised learning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/ssl/intro.html?highlight=self%20supervised)) model.\n",
"\n",
"For the training of the [Conformer-CTC](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-ctc) model, we use [speech_to_text_ctc_bpe.py](https://github.com/NVIDIA/NeMo/tree/stable/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py) with the default config [conformer_ctc_bpe.yaml](https://github.com/NVIDIA/NeMo/tree/stable/examples/asr/conf/conformer/conformer_ctc_bpe.yaml). Here you can see the example of how to run this training:\n",
"\n",
"```bash\n",
"TOKENIZER=${YOUR_DATA_ROOT}/esperanto/tokenizers/tokenizer_spe_bpe_v128\n",
"TRAIN_MANIFEST=${YOUR_DATA_ROOT}/train_tarred_1bk/tarred_audio_manifest.json\n",
"TARRED_AUDIO_FILEPATHS=${YOUR_DATA_ROOT}/train_tarred_1bk/audio__OP_0..1023_CL_.tar # \"_OP_0..1023_CL_\" is the range for the batch of files audio_0.tar, audio_1.tar, ..., audio_1023.tar\n",
"DEV_MANIFEST=${YOUR_DATA_ROOT}/validation/validation_mozilla-foundation_common_voice_11_0_manifest.json.clean\n",
"TEST_MANIFEST=${YOUR_DATA_ROOT}/test/test_mozilla-foundation_common_voice_11_0_manifest.json.clean\n",
"\n",
"python ${NEMO_ROOT}/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py \\\n",
" --config-path=../conf/conformer/ \\\n",
" --config-name=conformer_ctc_bpe \\\n",
" exp_manager.name=\"Name of our experiment\" \\\n",
" exp_manager.resume_if_exists=true \\\n",
" exp_manager.resume_ignore_no_checkpoint=true \\\n",
" exp_manager.exp_dir=results/ \\\n",
" ++model.encoder.conv_norm_type=layer_norm \\\n",
" model.tokenizer.dir=$TOKENIZER \\\n",
" model.train_ds.is_tarred=true \\\n",
" model.train_ds.tarred_audio_filepaths=$TARRED_AUDIO_FILEPATHS \\\n",
" model.train_ds.manifest_filepath=$TRAIN_MANIFEST \\\n",
" model.validation_ds.manifest_filepath=$DEV_MANIFEST \\\n",
" model.test_ds.manifest_filepath=$TEST_MANIFEST\n",
"```\n",
"\n",
"Main training parameters:\n",
"\n",
"* Tokenization: BPE 128/512/1024\n",
"* Model: Conformer-CTC-large with Layer Normalization\n",
"* Optimizer: AdamW, weight_decay 1e-3, LR 1e-3\n",
"* Scheduler: CosineAnnealing, warmup_steps 10000, min_lr 1e-6\n",
"* Batch: 32 local, 1024 global (2 grad accumulation)\n",
"* Precision: FP16\n",
"* GPUs: 16 V100\n",
"\n",
"The following table provides the results for training Esperanto Conformer-CTC-large model from scratch with different BPE vocabulary size.\n",
"\n",
"BPE Size | DEV WER (%) | TEST WER (%)\n",
"-----|-----|----- \n",
"128|**3.96**|**6.48**\n",
"512|4.62|7.31\n",
"1024|5.81|8.56\n",
"\n",
"BPE vocabulary with 128 size provides the lowest WER since our training dataset is l (~250 hours) is insufficient to small to train models with larger BPE vocabulary sizes.\n",
"\n",
"For fine-tuning from already trained ASR models, we use three different models:\n",
"\n",
"* English [stt_en_conformer_ctc_large](https://huggingface.co/nvidia/stt_en_conformer_ctc_large) (several thousand hours of English speech).\n",
"* Spanish [stt_es_conformer_ctc_large](https://huggingface.co/nvidia/stt_es_conformer_ctc_large) (1340 hours of Spanish speech).\n",
"* Italian [stt_it_conformer_ctc_large](https://huggingface.co/nvidia/stt_it_conformer_ctc_large) (487 hours of Italian speech).\n",
"\n",
"To finetune a model with the same vocabulary size, just set the desired model via the *init_from_pretrained_model* parameter:\n",
"\n",
"```yaml\n",
"+init_from_pretrained_model=${PRETRAINED_MODEL_NAME}\n",
"```\n",
"\n",
"If the size of the vocabulary differs from the one presented in the pretrained model, you need to change the vocabulary manually as done in the [finetuning tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb).\n",
"\n",
"```python\n",
"model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(f\"nvidia/{PRETRAINED_MODEL_NAME}\", map_location='cpu')\n",
"model.change_vocabulary(new_tokenizer_dir=TOKENIZER, new_tokenizer_type=\"bpe\")\n",
"model.encoder.unfreeze()\n",
"model.save_to(f\"{save_path}\")\n",
"```\n",
"\n",
"There is no need to change anything for the SSL model, it will replace the vocabulary itself. However, you will need to first download this model and set it through another parameter *init_from_nemo_model*:\n",
"\n",
"```yaml\n",
"++init_from_nemo_model=${PRETRAINED_MODEL} \\\n",
"```\n",
"\n",
"As the SSL model, we use [ssl_en_conformer_large](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/ssl_en_conformer_large) which is trained using LibriLight corpus (~56k hrs of unlabeled English speech).\n",
"All models for fine-tuning are available on [Nvidia Hugging Face](https://huggingface.co/nvidia) or [NGC](https://catalog.ngc.nvidia.com/models) repo.\n",
"\n",
"The following table shows all results for fine-tuning from pretrained models for the Conformer-CTC-large model and compares them with the model that was obtained by training from scratch (here we use BPE size 128 for all the models because it gives the best results).\n",
"\n",
"\n",
"Training Mode | DEV WER (%) | TEST WER (%)\n",
"-----|-----|----- \n",
"From scratch|3.96|6.48\n",
"Finetuning (English)|3.45|5.45\n",
"Finetuning (Spanish)|3.40|5.52\n",
"Finetuning (Italian)|3.29|5.36\n",
"Finetuning (SSL English)|**2.90**|**4.76**\n",
"\n",
"We can also monitor test WER behavior during training process using wandb plots (X - global step, Y - test WER):\n",
"\n",
"\n",
"\n",
"As you can see, the best way to get the Esperanto ASR model (the model can be found on [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_eo_conformer_ctc_large) and [Hugging Face](https://huggingface.co/nvidia/stt_eo_conformer_ctc_large) is finetuning from the pretrained SSL model for English."
]
},
{
"cell_type": "markdown",
"id": "1e3abb4d",
"metadata": {},
"source": [
"## Decoding\n",
"\n",
"At the end of the training, several checkpoints (usually 5) and the best model (not always from the latest epoch) are stored in the model folder. Checkpoint averaging (script) can help to improve the final decoding accuracy. In our case, this did not improve the CTC models. However, it was possible to get an improvement in the range of 0.1-0.2% WER for some RNNT models. To make averaging, use the following command:\n",
"\n",
"```bash\n",
"python ${NEMO_ROOT}/scripts/checkpoint_averaging/checkpoint_averaging.py <your_trained_model.nemo>\n",
"```\n",
"\n",
"For decoding you can use:\n",
"\n",
"```bash\n",
"python ${NEMO_ROOT}/examples/asr/speech_to_text_eval.py \\\n",
" model_path=${MODEL} \\\n",
" pretrained_name=null \\\n",
" dataset_manifest=${TEST_MANIFEST} \\\n",
" batch_size=${BATCH_SIZE} \\\n",
" output_filename=${OUTPUT_MANIFEST} \\\n",
" amp=False \\\n",
" use_cer=False)\n",
"```\n",
"\n",
"You can use the Speech Data Explorer to analyze recognition errors, similar to the Kinyarwanda example.\n",
"We listened to files with an anomaly high WER (>50%) and found many problematic files. They have wrong transcriptions and cut or empty audio files in the dev and test sets.\n",
"\n",
"```bash\n",
"python ${NEMO_ROOT}/tools/speech_data_explorer/data_explorer.py <your_decoded_manifest_file>\n",
"```"
]
},
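{
"cell_type": "markdown",
"id": "d4f1b6c8",
"metadata": {},
"source": [
"You can also compute per-utterance WER from the decoded manifest to find such problematic files automatically. This is a minimal sketch; it assumes the decoded manifest stores the hypothesis in a `pred_text` field (as written by the evaluation script above) and uses NeMo's `word_error_rate` helper:\n",
"\n",
"```python\n",
"import json\n",
"from nemo.collections.asr.metrics.wer import word_error_rate\n",
"\n",
"def find_high_wer(decoded_manifest, wer_threshold=0.5):\n",
"    \"\"\"Return entries of a decoded manifest whose per-utterance WER exceeds the threshold.\"\"\"\n",
"    suspicious = []\n",
"    with open(decoded_manifest, 'r', encoding='utf-8') as fn_in:\n",
"        for line in fn_in:\n",
"            data = json.loads(line)\n",
"            wer = word_error_rate(hypotheses=[data[\"pred_text\"]], references=[data[\"text\"]])\n",
"            if wer > wer_threshold:\n",
"                suspicious.append(data)\n",
"    return suspicious\n",
"\n",
"suspicious = find_high_wer(\"<your_decoded_manifest_file>\", wer_threshold=0.5)\n",
"print(f\"{len(suspicious)} files with WER > 50%\")\n",
"```\n",
"\n",
"The same check can be applied to a decoded training manifest, as discussed in the next section."
]
},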
{
"cell_type": "markdown",
"id": "48e37bf3",
"metadata": {},
"source": [
"## Training data analysis\n",
"\n",
"For an additional analysis of the training dataset, you can decode it using an already trained model. Train examples with a high error rate (WER > 50%) are likely to be problematic files. Removing them from the training set is preferred because a model can train text even for almost empty audio. We do not want this behavior from the ASR model."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}