{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tulasiram58827/TTS_TFLite/blob/main/Parallel_WaveGAN_TFLite.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "qu_1y5_ZDxpU" }, "source": [ "This notebook converts the TensorFlow ParallelWaveGAN generator to TFLite." ] }, { "cell_type": "markdown", "metadata": { "id": "1KQie-EQDzEL" }, "source": [ "## Acknowledgments" ] }, { "cell_type": "markdown", "metadata": { "id": "h-qWgadcDzCW" }, "source": [ "- Pretrained model (in PyTorch) downloaded from the [Parallel WaveGAN repository](https://github.com/kan-bayashi/ParallelWaveGAN#results).\n", "\n", "- PyTorch weights converted to a TensorFlow-compatible format using the [TensorFlowTTS repository](https://github.com/TensorSpeech/TensorFlowTTS) and this [notebook](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/examples/parallel_wavegan/convert_pwgan_from_pytorch_to_tensorflow.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "id": "pBE0GfYwEwoT" }, "source": [ "## Setup and Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PpGT8_mm8vs-" }, "outputs": [], "source": [ "!git clone https://github.com/TensorSpeech/TensorFlowTTS.git\n", "!cd TensorFlowTTS\n", "!pip install /content/TensorFlowTTS/" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2tI8NSz_886Z" }, "outputs": [], "source": [ "!pip install parallel_wavegan" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iQq9Gkn9MYT" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import torch\n", "import sys\n", "sys.path.append('/content/TensorFlowTTS')\n", "from tensorflow_tts.models import TFParallelWaveGANGenerator\n", "from tensorflow_tts.configs import ParallelWaveGANGeneratorConfig\n", "\n", "from parallel_wavegan.models import ParallelWaveGANGenerator\n", "import numpy as np\n", "\n", "from 
IPython.display import Audio" ] }, { "cell_type": "markdown", "metadata": { "id": "BIr9zN74E3PU" }, "source": [ "## Initialize Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "UoFriagU9NBx" }, "outputs": [], "source": [ "tf_model = TFParallelWaveGANGenerator(config=ParallelWaveGANGeneratorConfig(), name=\"parallel_wavegan_generator\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "mayxwoLp9fiR" }, "outputs": [], "source": [ "tf_model._build()" ] }, { "cell_type": "markdown", "metadata": { "id": "P__OyD23E8jN" }, "source": [ "## Load PyTorch Checkpoints" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gZq9ibuzHbI9", "outputId": "660ebfd7-6ed9-49e2-b3b1-9c26894694f4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading...\n", "From: https://drive.google.com/uc?id=1wPwO9K-0Yq-GYcXbHseaqt8kUpa_ojJf\n", "To: /content/checkpoint-400000steps.pkl\n", "\r", "0.00B [00:00, ?B/s]\r", "17.5MB [00:00, 154MB/s]\n" ] } ], "source": [ "!gdown --id 1wPwO9K-0Yq-GYcXbHseaqt8kUpa_ojJf -O checkpoint-400000steps.pkl" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "GoeX-YLQ9kaf" }, "outputs": [], "source": [ "torch_checkpoints = torch.load(\"checkpoint-400000steps.pkl\", map_location=torch.device('cpu'))\n", "torch_generator_weights = torch_checkpoints[\"model\"][\"generator\"]\n", "torch_model = ParallelWaveGANGenerator()\n", "torch_model.load_state_dict(torch_checkpoints[\"model\"][\"generator\"])\n", "torch_model.remove_weight_norm()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3NSfX33w99WW", "outputId": "436460ea-2969-4f89-ba4a-801f0b60abff" }, "outputs": [ { "data": { "text/plain": [ "1334309" ] }, "execution_count": 9, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "model_parameters = 
filter(lambda p: p.requires_grad, torch_model.parameters())\n", "params = sum([np.prod(p.size()) for p in model_parameters])\n", "params" ] }, { "cell_type": "markdown", "metadata": { "id": "x7t7hPgiE_pR" }, "source": [ "## Convert PyTorch weights to TensorFlow" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "Y4vOfByl-ASZ" }, "outputs": [], "source": [ "# in pytorch, in convolution layer, the order is bias -> weight, in tf it is weight -> bias. We need re-order.\n", "\n", "def convert_weights_pytorch_to_tensorflow(weights_pytorch):\n", " \"\"\"\n", " Convert pytorch Conv1d weight variable to tensorflow Conv2D weights.\n", " 1D: Pytorch (f_output, f_input, kernel_size) -> TF (kernel_size, f_input, 1, f_output)\n", " 2D: Pytorch (f_output, f_input, kernel_size_h, kernel_size_w) -> TF (kernel_size_w, kernel_size_h, f_input, 1, f_output)\n", " \"\"\"\n", " if len(weights_pytorch.shape) == 3: # conv1d-kernel\n", " weights_tensorflow = np.transpose(weights_pytorch, (0,2,1)) # [f_output, kernel_size, f_input]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2)) # [kernel-size, f_output, f_input]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1)) # [kernel-size, f_input, f_output]\n", " return weights_tensorflow\n", " elif len(weights_pytorch.shape) == 1: # conv1d-bias\n", " return weights_pytorch\n", " elif len(weights_pytorch.shape) == 4: # conv2d-kernel\n", " weights_tensorflow = np.transpose(weights_pytorch, (0,2,1,3)) # [f_output, kernel_size_h, f_input, kernel_size_w]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2,3)) # [kernel-size_h, f_output, f_input, kernel-size-w]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1,3)) # [kernel_size_h, f_input, f_output, kernel-size-w]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (0,1,3,2)) # [kernel_size_h, f_input, kernel-size-w, f_output]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1,3)) # 
[kernel_size_h, kernel-size-w, f_input, f_output]\n", " weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2,3)) # [kernel-size_w, kernel_size_h, f_input, f_output]\n", " return weights_tensorflow\n", "\n", "torch_weights = []\n", "all_keys = list(torch_model.state_dict().keys())\n", "all_values = list(torch_model.state_dict().values())\n", "\n", "idx_already_append = []\n", "\n", "for i in range(len(all_keys) -1):\n", " if i not in idx_already_append:\n", " if all_keys[i].split(\".\")[0:-1] == all_keys[i + 1].split(\".\")[0:-1]:\n", " if all_keys[i].split(\".\")[-1] == \"bias\" and all_keys[i + 1].split(\".\")[-1] == \"weight\":\n", " torch_weights.append(convert_weights_pytorch_to_tensorflow(all_values[i + 1].cpu().detach().numpy()))\n", " torch_weights.append(convert_weights_pytorch_to_tensorflow(all_values[i].cpu().detach().numpy()))\n", " idx_already_append.append(i)\n", " idx_already_append.append(i + 1)\n", " else:\n", " if i not in idx_already_append:\n", " torch_weights.append(convert_weights_pytorch_to_tensorflow(all_values[i].cpu().detach().numpy()))\n", " idx_already_append.append(i)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "168kydzc-SxJ" }, "outputs": [], "source": [ "tf_var = tf_model.trainable_variables\n", "for i, var in enumerate(tf_var):\n", " tf.keras.backend.set_value(var, torch_weights[i])" ] }, { "cell_type": "markdown", "metadata": { "id": "p8D70bCeFOAA" }, "source": [ "## Convert to TFLite" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "mTnGufeuH3io" }, "outputs": [], "source": [ "def convert_to_tflite(quantization):\n", " pwg_concrete_function = tf_model.inference.get_concrete_function()\n", " converter = tf.lite.TFLiteConverter.from_concrete_functions([pwg_concrete_function])\n", " converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", " converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS]\n", " if quantization == 'float16':\n", " 
converter.target_spec.supported_types = [tf.float16]\n", " tf_lite_model = converter.convert()\n", " model_name = f'parallel_wavegan_{quantization}.tflite'\n", " with open(model_name, 'wb') as f:\n", " f.write(tf_lite_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "zSaic3flIJX7" }, "source": [ "#### Dynamic Range Quantization" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "6STZKNqg-vxS" }, "outputs": [], "source": [ "quantization = 'dr' #@param [\"dr\", \"float16\"]\n", "convert_to_tflite(quantization)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VB5H4bUmIUFR", "outputId": "e53b77e9-d680-424a-a1e7-5a6abb723866" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5.7M\tparallel_wavegan_dr.tflite\n" ] } ], "source": [ "!du -sh parallel_wavegan_dr.tflite" ] }, { "cell_type": "markdown", "metadata": { "id": "tb_FF8fNINWr" }, "source": [ "#### Float16 Quantization" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "19kqBUnQ_KG3", "outputId": "4ab46a0f-98cb-44ea-8b8c-a337b5fc8d35" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.2M\tparallel_wavegan_float16.tflite\n" ] } ], "source": [ "quantization = 'float16'\n", "convert_to_tflite(quantization)\n", "!du -sh parallel_wavegan_float16.tflite" ] }, { "cell_type": "markdown", "metadata": { "id": "7Kab76pmFifJ" }, "source": [ "## Download Sample Output of Tacotron2" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IwuNQ_Z1Fm0d", "outputId": "a5e18dc0-573d-468e-f215-87d5bc86e67e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading...\n", "From: https://drive.google.com/uc?id=1LmU3j8yedwBzXKVDo9tCvozLM4iwkRnP\n", "To: /content/tac_output.npy\n", "\r", " 0% 0.00/36.0k [00:00\n", " \n", " 
Your browser does not support the audio element.\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 31, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "output = output[0, :, 0]\n", "\n", "Audio(output, rate=22050)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Tensorflow_TTS_PWGAN.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 1 }