{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0aWBuYb2UDIO"
   },
   "source": [
    "# SAM: Animation Inference Playground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Uuviq3qQkUFy"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/content')\n",
    "CODE_DIR = 'SAM'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "QQ6XEmlHlXbk",
    "outputId": "2d2af9bb-1bbe-4946-84db-dbeaf62aa226"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JaRUFuVHkzye",
    "outputId": "12fc6dcd-951b-472f-b9d6-0a09f749e931"
   },
   "outputs": [],
   "source": [
    "# -o: overwrite an existing ninja binary without prompting, so this cell can be re-run\n",
    "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
    "!sudo unzip -o ninja-linux.zip -d /usr/local/bin/\n",
    "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "23baccYQlU9E"
   },
   "outputs": [],
   "source": [
    "os.chdir(f'./{CODE_DIR}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lEWIzkaLSsFY"
   },
   "outputs": [],
   "source": [
    "from argparse import Namespace\n",
    "import os\n",
    "import sys\n",
    "import pprint\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# make the cloned repository's packages importable from the notebook\n",
    "sys.path.append(\".\")\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from datasets.augmentations import AgeTransformer\n",
    "from utils.common import tensor2im\n",
    "from models.psp import pSp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9J3NEVlESsFl"
   },
   "outputs": [],
   "source": [
    "EXPERIMENT_TYPE = 'ffhq_aging'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eFjfO9q9SsFm"
   },
   "source": [
    "## Step 1: Download Pretrained Model\n",
    "As part of this repository, we provide our pretrained aging model.\n",
    "We'll download the model for the selected experiment and save it to the folder `../pretrained_models`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "e2L9GRCFSsFm"
   },
   "outputs": [],
   "source": [
    "def get_download_model_command(file_id, file_name):\n",
    "    \"\"\"Build the shell command that downloads a pretrained model from Google Drive.\n",
    "\n",
    "    The target directory `../pretrained_models` (relative to the current working\n",
    "    directory) is created as a side effect if it does not exist yet.\n",
    "\n",
    "    Args:\n",
    "        file_id: Google Drive file id of the checkpoint.\n",
    "        file_name: File name under which the checkpoint is saved.\n",
    "\n",
    "    Returns:\n",
    "        A complete `wget ...` command line as a string. It already starts with\n",
    "        `wget`, so it must be executed as-is (do not prefix another `wget`).\n",
    "    \"\"\"\n",
    "    current_directory = os.getcwd()\n",
    "    save_path = os.path.join(os.path.dirname(current_directory), \"pretrained_models\")\n",
    "    os.makedirs(save_path, exist_ok=True)\n",
    "    # The inner wget fetches the confirmation token (kept in a cookie file) that the outer wget needs.\n",
    "    command = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n",
    "    return command"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yotXNk8PSsFn"
   },
   "outputs": [],
   "source": [
    "MODEL_PATHS = {\n",
    "    \"ffhq_aging\": {\"id\": \"1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC\", \"name\": \"sam_ffhq_aging.pt\"}\n",
    "}\n",
    "\n",
    "path = MODEL_PATHS[EXPERIMENT_TYPE]\n",
    "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tXUqcxd8SsFo",
    "outputId": "0e2426ec-f96f-44cb-d1a3-736807bc4f37"
   },
   "outputs": [],
   "source": [
    "# download_command is already a full `wget ...` invocation, so run it directly\n",
    "!{download_command}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mismxuEvSsFp"
   },
   "source": [
    "## Step 2: Define Inference Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NRzP-untSsFq"
   },
   "source": [
    "Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the\n",
    "image to perform inference on.\n",
    "While we provide default values to run this script, feel free to change as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pHeJRfsnSsFq"
   },
   "outputs": [],
   "source": [
    "EXPERIMENT_DATA_ARGS = {\n",
    "    \"ffhq_aging\": {\n",
    "        \"model_path\": \"../pretrained_models/sam_ffhq_aging.pt\",\n",
    "        \"transform\": transforms.Compose([\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IgzLA96mSsFq"
   },
   "outputs": [],
   "source": [
    "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kqlL9U5uSsFr"
   },
   "source": [
    "## Step 3: Load Pretrained Model\n",
    "We assume that you have downloaded the pretrained aging model and placed it in the path defined above."
] },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8khx-2fcSsFr"
   },
   "outputs": [],
   "source": [
    "model_path = EXPERIMENT_ARGS['model_path']\n",
    "ckpt = torch.load(model_path, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fVfLlsfjSsFr",
    "outputId": "ccc5cc29-e59d-414c-a216-fa967ece4eb9"
   },
   "outputs": [],
   "source": [
    "opts = ckpt['opts']\n",
    "pprint.pprint(opts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vxYxSJk1SsFs"
   },
   "outputs": [],
   "source": [
    "# update the training options\n",
    "opts['checkpoint_path'] = model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Z4LNoYt_SsFs",
    "outputId": "030da871-e4a2-42a9-9c4b-8b96e028fd40"
   },
   "outputs": [],
   "source": [
    "opts = Namespace(**opts)\n",
    "net = pSp(opts)\n",
    "net.eval()\n",
    "net.cuda()\n",
    "print('Model successfully loaded!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3mNH95EJSsFs"
   },
   "source": [
    "### Utils for Generating MP4 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GASwMVmJSsFs"
   },
   "outputs": [],
   "source": [
    "import imageio\n",
    "from tqdm import tqdm\n",
    "import matplotlib\n",
    "from IPython.display import HTML, display\n",
    "from base64 import b64encode\n",
    "\n",
    "matplotlib.use('module://ipykernel.pylab.backend_inline')\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "def generate_mp4(out_name, images, kwargs):\n",
    "    \"\"\"Write `images` as the frames of `<out_name>.mp4`.\n",
    "\n",
    "    Args:\n",
    "        out_name: Output path without the `.mp4` extension.\n",
    "        images: Iterable of frames, appended to the video in order.\n",
    "        kwargs: Dict of keyword arguments forwarded to `imageio.get_writer` (e.g. `{'fps': 40}`).\n",
    "    \"\"\"\n",
    "    writer = imageio.get_writer(out_name + '.mp4', **kwargs)\n",
    "    try:\n",
    "        for image in images:\n",
    "            writer.append_data(image)\n",
    "    finally:\n",
    "        # always close the writer, even if appending a frame fails\n",
    "        writer.close()\n",
    "\n",
    "\n",
    "def run_on_batch_to_vecs(inputs, net):\n",
    "    \"\"\"Run `net` on `inputs` with `return_latents=True` and return the second output, moved to the CPU.\n",
    "\n",
    "    Args:\n",
    "        inputs: Input batch tensor; it is moved to CUDA and cast to float before the forward pass.\n",
    "        net: The loaded model.\n",
    "    \"\"\"\n",
    "    _, result_batch = net(inputs.to(\"cuda\").float(), return_latents=True, randomize_noise=False, resize=False)\n",
    "    return result_batch.cpu()\n",
    "\n",
    "\n",
    "def get_result_from_vecs(vectors_a, vectors_b, alpha):\n",
    "    \"\"\"Decode the linear interpolation between paired latent vectors with the global `net`.\n",
    "\n",
    "    For every index i the blended code is `vectors_b[i] * alpha + vectors_a[i] * (1 - alpha)`,\n",
    "    so `alpha=0` reproduces `vectors_a` and `alpha=1` reproduces `vectors_b`.\n",
    "\n",
    "    Args:\n",
    "        vectors_a: Sequence of start latents.\n",
    "        vectors_b: Sequence of end latents, indexable at least as far as `vectors_a`.\n",
    "        alpha: Interpolation weight.\n",
    "\n",
    "    Returns:\n",
    "        A list holding the first element of each decoded output.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    for i in range(len(vectors_a)):\n",
    "        cur_vec = vectors_b[i] * alpha + vectors_a[i] * (1 - alpha)\n",
    "        res = net(cur_vec.cuda(), randomize_noise=False, input_code=True, input_is_full=True, resize=False)\n",
    "        results.append(res[0])\n",
    "    return results\n",
    "\n",
    "\n",
    "def show_mp4(filename, width=400):\n",
    "    \"\"\"Display `<filename>.mp4` inline as a base64-embedded HTML5 video.\n",
    "\n",
    "    Args:\n",
    "        filename: Path of the video without the `.mp4` extension.\n",
    "        width: Width of the video element.\n",
    "    \"\"\"\n",
    "    with open(filename + '.mp4', 'rb') as video_file:\n",
    "        mp4 = video_file.read()\n",
    "    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
    "    # the template must contain one %d and one %s to match the (width, data_url) tuple\n",
    "    display(HTML(\"\"\"\n",
    "    <video width=%d controls autoplay loop>\n",
    "          <source src=\"%s\" type=\"video/mp4\">\n",
    "    </video>\n",
    "    \"\"\" % (width, data_url)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "y4oqFBfTSsFt",
    "outputId": "6ff2e285-e5e0-4d9c-e16d-af43eee1e8f1"
   },
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "\n",
    "img_transforms = EXPERIMENT_ARGS['transform']\n",
    "n_transition = 25\n",
    "# interpolation weights between two consecutive target ages (same for every pair)\n",
    "alpha_vals = np.linspace(0, 1, n_transition).tolist()\n",
    "kwargs = {'fps': 40}\n",
    "save_path = \"notebooks/animations\"\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "\n",
    "#################################################################\n",
    "# TODO: define your image paths here to be fed into the model\n",
    "#################################################################\n",
    "root_dir = 'notebooks/images'\n",
    "ims = ['866', '1287', '2468']\n",
    "im_paths = [os.path.join(root_dir, im) + '.jpg' for im in ims]\n",
    "\n",
    "# NOTE: Please make sure the images are pre-aligned!\n",
    "\n",
    "target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]\n",
    "age_transformers = [AgeTransformer(target_age=age) for age in target_ages]\n",
    "\n",
    "for image_path in im_paths:\n",
    "    image_name = os.path.basename(image_path)\n",
    "    print(f'Working on image: {image_name}')\n",
    "    original_image = Image.open(image_path).convert(\"RGB\")\n",
    "    input_image = img_transforms(original_image)\n",
    "\n",
    "    # encode the input once per target age; only the latent vectors are needed at this stage\n",
    "    all_vecs = []\n",
    "    for age_transformer in age_transformers:\n",
    "        input_age_batch = [age_transformer(input_image.cpu()).to('cuda')]\n",
    "        input_age_batch = torch.stack(input_age_batch)\n",
    "\n",
    "        # get latent vector for the current target age amount\n",
    "        with torch.no_grad():\n",
    "            result_vec = run_on_batch_to_vecs(input_age_batch, net)\n",
    "        all_vecs.append([result_vec])\n",
    "\n",
    "    # decode the interpolated latents between every pair of consecutive target ages\n",
    "    images = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(1, len(target_ages)):\n",
    "            for alpha in tqdm(alpha_vals):\n",
    "                result_image = get_result_from_vecs(all_vecs[i-1], all_vecs[i], alpha)[0]\n",
    "                output_im = tensor2im(result_image)\n",
    "                images.append(np.array(output_im))\n",
    "\n",
    "    animation_path = os.path.join(save_path, f\"{image_name}_animation\")\n",
    "    generate_mp4(animation_path, images, kwargs)\n",
    "    show_mp4(animation_path)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "inference_playground_mp4.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}