{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "5OqEP1SlGeVZ" }, "source": [ "# SAM: Inference Playground" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dE2hzjSNQs0p" }, "outputs": [], "source": [ "import os\n", "os.chdir('/content')\n", "CODE_DIR = 'SAM'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bbaMZ40hQxT0", "outputId": "f7fac42a-77e7-4b79-ab87-b8805a4b8f39" }, "outputs": [], "source": [ "!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "43F-3KfeQ08S", "outputId": "f1def785-f7aa-4016-c6f7-afc2463d6b06" }, "outputs": [], "source": [ "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n", "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n", "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "av0207x4Q2iL" }, "outputs": [], "source": [ "os.chdir(f'./{CODE_DIR}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zvwx9NsiQq9t" }, "outputs": [], "source": [ "from argparse import Namespace\n", "import os\n", "import sys\n", "import pprint\n", "import numpy as np\n", "from PIL import Image\n", "import torch\n", "import torchvision.transforms as transforms\n", "\n", "sys.path.append(\".\")\n", "sys.path.append(\"..\")\n", "\n", "from datasets.augmentations import AgeTransformer\n", "from utils.common import tensor2im\n", "from models.psp import pSp" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uj3dJjQsQq9y" }, "outputs": [], "source": [ "EXPERIMENT_TYPE = 'ffhq_aging'" ] }, { "cell_type": "markdown", "metadata": { "id": "mStxrAtuQq9y" }, "source": [ "## Step 1: Download Pretrained Model\n", "As part of this repository, we provide our pretrained aging model.\n", "We'll download the model for the selected experiment and save it to the folder `../pretrained_models`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_pC38oLGQq9z" }, "outputs": [], "source": [ "def get_download_model_command(file_id, file_name):\n", " \"\"\" Get wget download command for downloading the desired model and save to directory ../pretrained_models. 
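    Note: the returned string uses Google Drive's cookie-and-confirm-token download flow (the confirm value is parsed out with sed), so the checkpoint can be fetched non-interactively from a notebook.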
\"\"\"\n", " current_directory = os.getcwd()\n", " save_path = os.path.join(os.path.dirname(current_directory), \"pretrained_models\")\n", " if not os.path.exists(save_path):\n", " os.makedirs(save_path)\n", " url = r\"\"\"wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={FILE_ID}\" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt\"\"\".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)\n", " return url " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rOQ2Vz2kQq9z" }, "outputs": [], "source": [ "MODEL_PATHS = {\n", " \"ffhq_aging\": {\"id\": \"1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC\", \"name\": \"sam_ffhq_aging.pt\"}\n", "}\n", "\n", "path = MODEL_PATHS[EXPERIMENT_TYPE]\n", "download_command = get_download_model_command(file_id=path[\"id\"], file_name=path[\"name\"]) " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "K0nHPvo5Qq9z", "outputId": "3ac7ce05-077a-4d81-b6ca-0e5b2dc61753" }, "outputs": [], "source": [ "!{download_command}" ] }, { "cell_type": "markdown", "metadata": { "id": "WvRDiRrMQq90" }, "source": [ "## Step 2: Define Inference Parameters" ] }, { "cell_type": "markdown", "metadata": { "id": "GNaSSzZsQq90" }, "source": [ "Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the\n", "image to perform inference on.\n", "While we provide default values to run this script, feel free to change them as needed." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yaGqalwuQq90" }, "outputs": [], "source": [ "EXPERIMENT_DATA_ARGS = {\n", " \"ffhq_aging\": {\n", " \"model_path\": \"../pretrained_models/sam_ffhq_aging.pt\",\n", " \"image_path\": \"notebooks/images/866.jpg\",\n", " \"transform\": transforms.Compose([\n", " transforms.Resize((256, 256)),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wjkLqLkDQq90" }, "outputs": [], "source": [ "EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]" ] }, { "cell_type": "markdown", "metadata": { "id": "YkfqoKJwQq91" }, "source": [ "## Step 3: Load Pretrained Model\n", "We assume that you have downloaded the pretrained aging model and placed it in the path defined above." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cZuho98JQq91" }, "outputs": [], "source": [ "model_path = EXPERIMENT_ARGS['model_path']\n", "ckpt = torch.load(model_path, map_location='cpu')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f6NOxONxQq91", "outputId": "7eecdad5-0678-45d4-d416-898e3fce250d" }, "outputs": [], "source": [ "opts = ckpt['opts']\n", "pprint.pprint(opts)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J6c93qE9Qq91" }, "outputs": [], "source": [ "# update the training options\n", "opts['checkpoint_path'] = model_path" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JRTfKFrkQq91", "outputId": "1ebe3ebb-d33f-4764-d88c-8ba0a66ce0a8" }, "outputs": [], "source": [ "opts = 
Namespace(**opts)\n", "net = pSp(opts)\n", "net.eval()\n", "net.cuda()\n", "print('Model successfully loaded!')" ] }, { "cell_type": "markdown", "metadata": { "id": "z6BegCirQq92" }, "source": [ "## Step 4: Visualize Input" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kc4Sr31TQq92" }, "outputs": [], "source": [ "image_path = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE][\"image_path\"]\n", "original_image = Image.open(image_path).convert(\"RGB\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 273 }, "id": "bKA9BO9_Qq92", "outputId": "51152c46-8c8d-4020-f343-dc01dd523084" }, "outputs": [], "source": [ "original_image.resize((256, 256))" ] }, { "cell_type": "markdown", "metadata": { "id": "u3a50tAcQq92" }, "source": [ "## Step 5: Perform Inference" ] }, { "cell_type": "markdown", "metadata": { "id": "o6oqf8JwzK0K" }, "source": [ "### Align Image\n", "\n", "Before running inference we'll run alignment on the input image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "y244_ejy9Drx", "outputId": "bb583763-1aa1-4745-95f5-4b7bb2f96715" }, "outputs": [], "source": [ "!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\n", "!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hJ9Ce1aYzmFF" }, "outputs": [], "source": [ "def run_alignment(image_path):\n", " import dlib\n", " from scripts.align_all_parallel import align_face\n", " predictor = dlib.shape_predictor(\"shape_predictor_68_face_landmarks.dat\")\n", " aligned_image = align_face(filepath=image_path, predictor=predictor) \n", " print(\"Aligned image has shape: {}\".format(aligned_image.size))\n", " return aligned_image " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aTZcKMdK8y77", "outputId": "18d7a5da-9e98-4373-c296-727216406dd5" }, "outputs": [], "source": [ "aligned_image = run_alignment(image_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 273 }, "id": "hUBAfodh5PaM", "outputId": "81545ff1-4184-4a3a-d887-52ad9f71e24a" }, "outputs": [], "source": [ "aligned_image.resize((256, 256))" ] }, { "cell_type": "markdown", "metadata": { "id": "gMyoh4X1HYAS" }, "source": [ "### Run Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XkzQpi1aQq92" }, "outputs": [], "source": [ "img_transforms = EXPERIMENT_ARGS['transform']\n", "input_image = img_transforms(aligned_image)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lI7yWNPDQq92" }, "outputs": [], "source": [ "# we'll run the image on multiple target ages \n", "target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]\n", "age_transformers = [AgeTransformer(target_age=age) for age in target_ages]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kLP4pF-2Qq93" }, "outputs": [], "source": [ "def run_on_batch(inputs, net):\n", " result_batch = net(inputs.to(\"cuda\").float(), randomize_noise=False, resize=False)\n", " return result_batch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nfrY_gEEQq93", "outputId": "b3b8f881-424a-4eac-9e06-1598442f9c62" }, "outputs": [], "source": [ "# for each age transformed 
image, we'll concatenate the results to display them side-by-side\n", "results = np.array(aligned_image.resize((1024, 1024)))\n", "for age_transformer in age_transformers:\n", " print(f\"Running on target age: {age_transformer.target_age}\")\n", " with torch.no_grad():\n", " input_image_age = [age_transformer(input_image.cpu()).to('cuda')]\n", " input_image_age = torch.stack(input_image_age)\n", " result_tensor = run_on_batch(input_image_age, net)[0]\n", " result_image = tensor2im(result_tensor)\n", " results = np.concatenate([results, result_image], axis=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "IFgwfLTKQq93" }, "source": [ "### Visualize Result" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wpwyrv0iQq93" }, "outputs": [], "source": [ "results = Image.fromarray(results)\n", "results # this is a very large image (12*1024 x 1024: the original plus 11 target ages) so it may take some time to display!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4sL7fHp9Qq93", "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# save image at full resolution\n", "results.save(\"notebooks/images/age_transformed_image.jpg\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "inference_playground.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 1 }