# SAM: Inference Playground

In [None]:
# Work from /content (assumes a Colab-style environment where it exists);
# the repository is cloned into /content/CODE_DIR below.
import os
os.chdir('/content')
CODE_DIR = 'SAM'

In [None]:
# Clone the SAM repository into CODE_DIR.
!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR

In [None]:
# Install the ninja build tool into /usr/local/bin and register it as /usr/bin/ninja.
# NOTE(review): presumably required to JIT-compile the repo's custom C++/CUDA ops — confirm.
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

In [None]:
# Run all remaining cells from the repository root so the relative paths
# used below (e.g. notebooks/images/..., ../pretrained_models) resolve.
os.chdir(f'./{CODE_DIR}')

In [None]:
from argparse import Namespace
import os
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

# Make the repository's top-level packages (datasets, utils, models) importable;
# the working directory is the repository root at this point.
sys.path.append(".")
sys.path.append("..")

from datasets.augmentations import AgeTransformer
from utils.common import tensor2im
from models.psp import pSp

In [None]:
# Key used to pick this experiment's entries from MODEL_PATHS and EXPERIMENT_DATA_ARGS.
EXPERIMENT_TYPE = 'ffhq_aging'

## Step 1: Download Pretrained Model
As part of this repository, we provide our pretrained aging model.
We'll download the model for the selected experiment and save it to the folder `../pretrained_models`.

In [None]:
def get_download_model_command(file_id, file_name):
    """Build a shell command that downloads a Google Drive file into ../pretrained_models.

    As a side effect, the ``pretrained_models`` directory is created in the
    parent of the current working directory if it does not already exist.

    Args:
        file_id: Google Drive file id of the checkpoint.
        file_name: File name to save the download under inside ``pretrained_models``.

    Returns:
        A complete shell command string. It already starts with ``wget``, so run
        it as-is (e.g. ``!{command}`` in a notebook), not prefixed with another ``wget``.
    """
    import shlex

    current_directory = os.getcwd()
    save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
    # exist_ok makes re-running the cell safe and avoids a check-then-create race.
    os.makedirs(save_path, exist_ok=True)
    # Quote the output path so directories containing spaces or shell
    # metacharacters still yield a valid command; shlex.quote leaves ordinary
    # paths (letters, digits, _ . / -) unchanged.
    output_path = shlex.quote(f"{save_path}/{file_name}")
    # The inner wget fetches the download page with cookies saved, and sed
    # extracts its `confirm=` token. The outer wget repeats the request with that
    # token and the saved cookies, then the cookie file is removed.
    # NOTE(review): this relies on Google Drive's legacy confirm-token page; verify it still works.
    url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {OUTPUT_PATH} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, OUTPUT_PATH=output_path)
    return url

In [None]:
MODEL_PATHS = {
 "ffhq_aging": {"id": "1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC", "name": "sam_ffhq_aging.pt"}
}

path = MODEL_PATHS[EXPERIMENT_TYPE]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"]) 

In [None]:
!wget {download_command}

## Step 2: Define Inference Parameters

Below is a dictionary defining parameters such as the path to the pretrained model and the path to the
image to perform inference on.
While we provide default values to run this script, feel free to change them as needed.

In [None]:
# Per-experiment inference settings: where the checkpoint lives, which image to
# run on, and the preprocessing applied to the aligned face before inference.
EXPERIMENT_DATA_ARGS = {
 "ffhq_aging": {
 # Same location the download step writes to (../pretrained_models/<name>).
 "model_path": "../pretrained_models/sam_ffhq_aging.pt",
 # Input image path, relative to the repository root (the working directory).
 "image_path": "notebooks/images/866.jpg",
 # Resize to 256x256, convert to a tensor in [0, 1], then map each RGB
 # channel to [-1, 1] via (x - 0.5) / 0.5.
 "transform": transforms.Compose([
 transforms.Resize((256, 256)),
 transforms.ToTensor(),
 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 }
}

In [None]:
# Settings for the experiment selected by EXPERIMENT_TYPE.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]

## Step 3: Load Pretrained Model
We assume that you have downloaded the pretrained aging model and placed it at the path defined above.

In [None]:
# Load the checkpoint onto the CPU; the network is moved to the GPU after construction.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')

In [None]:
# The checkpoint carries the training options; print them for inspection.
opts = ckpt['opts']
pprint.pprint(opts)

In [None]:
# update the training options
# Point checkpoint_path at the local file (presumably pSp loads its weights
# from opts.checkpoint_path — confirm in models/psp.py).
opts['checkpoint_path'] = model_path

In [None]:
# Convert the options dict into an attribute-style Namespace (presumably the
# form pSp's constructor expects), build the network, and put it in eval mode on the GPU.
opts = Namespace(**opts)
net = pSp(opts)
net.eval()
net.cuda()
print('Model successfully loaded!')

## Step 4: Visualize Input

In [None]:
image_path = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]["image_path"]
original_image = Image.open(image_path).convert("RGB")

In [None]:
original_image.resize((256, 256))

## Step 5: Perform Inference

### Align Image

Before running inference, we'll align the face in the input image using dlib's 68-point facial landmark predictor.

In [None]:
!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2

In [None]:
def run_alignment(image_path, predictor_path="shape_predictor_68_face_landmarks.dat"):
    """Align the face in ``image_path`` using a dlib landmark predictor.

    Args:
        image_path: Path of the input image file.
        predictor_path: Path to dlib's 68-point shape predictor model. The
            default is the file extracted into the working directory by the
            download cell above.

    Returns:
        The aligned image produced by ``align_face``. It is assumed to be a PIL
        image: its ``.size`` is printed here and callers ``.resize()`` it.
    """
    # Imported lazily so dlib is only needed when alignment is actually run.
    import dlib
    from scripts.align_all_parallel import align_face

    predictor = dlib.shape_predictor(predictor_path)
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image

In [None]:
# Align the input face; the aligned image is what gets transformed and fed to the network.
aligned_image = run_alignment(image_path)

In [None]:
# Preview the aligned face at 256x256.
aligned_image.resize((256, 256))

### Run Inference

In [None]:
# Apply the experiment's preprocessing (resize + normalize) to the aligned face.
img_transforms = EXPERIMENT_ARGS['transform']
input_image = img_transforms(aligned_image)

In [None]:
# we'll run the image on multiple target ages
# One AgeTransformer per target age. Each is applied to the input tensor in the
# inference loop (presumably encoding the target age into the input — see
# datasets/augmentations.py).
target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

In [None]:
def run_on_batch(inputs, net, device="cuda"):
    """Run ``net`` on a batch of inputs and return its raw output.

    Args:
        inputs: Batched input tensor.
        net: Model, called as ``net(x, randomize_noise=False, resize=False)``.
        device: Device the inputs are moved to before the forward pass. Defaults
            to ``"cuda"`` as before; pass e.g. ``"cpu"`` to run without a GPU.

    Returns:
        Whatever ``net`` returns for the batch.
    """
    # The forward pass is run in float32.
    batch = inputs.to(device).float()
    # NOTE(review): randomize_noise=False presumably keeps generator noise fixed
    # (reproducible outputs), and resize=False keeps the native output resolution — confirm in models/psp.py.
    return net(batch, randomize_noise=False, resize=False)

In [None]:
# For each target age, run the network on the age-transformed input and collect
# the outputs so they can be shown side by side after the aligned input.
# The aligned input is upscaled to 1024x1024 so its height matches the network
# outputs (assumed 1024x1024 because run_on_batch uses resize=False — confirm).
frames = [np.array(aligned_image.resize((1024, 1024)))]
with torch.no_grad():
    for age_transformer in age_transformers:
        print(f"Running on target age: {age_transformer.target_age}")
        # Age-transform the CPU input tensor and add a batch dimension of 1;
        # run_on_batch moves the batch to the GPU.
        input_image_age = torch.stack([age_transformer(input_image.cpu())])
        result_tensor = run_on_batch(input_image_age, net)[0]
        result_image = tensor2im(result_tensor)
        frames.append(result_image)
# Concatenate once along the width. The previous version re-copied a growing
# array on every iteration, which is quadratic in the number of ages.
results = np.concatenate(frames, axis=1)

### Visualize Result

In [None]:
# Convert the side-by-side array back into a PIL image for display and saving.
results = Image.fromarray(results)
results # this is a very large image (12*1024 x 1024: the aligned input plus one panel per each of the 11 target ages) so it may take some time to display!

In [None]:
# save image at full resolution
results.save("notebooks/images/age_transformed_image.jpg")