#!/usr/bin/env python from __future__ import annotations import functools import os import pathlib import shlex import subprocess import sys import tarfile import gradio as gr import huggingface_hub import numpy as np import PIL.Image import torch if os.getenv("SYSTEM") == "spaces": with open("patch") as f: subprocess.run(shlex.split("patch -p1"), cwd="gan-control", stdin=f) sys.path.insert(0, "gan-control/src") from gan_control.inference.controller import Controller TITLE = "GAN-Control" DESCRIPTION = "https://github.com/amazon-research/gan-control" def download_models() -> None: model_dir = pathlib.Path("controller_age015id025exp02hai04ori02gam15") if not model_dir.exists(): path = huggingface_hub.hf_hub_download( "public-data/gan-control", "controller_age015id025exp02hai04ori02gam15.tar.gz" ) with tarfile.open(path) as f: f.extractall() @torch.inference_mode() def run( seed: int, truncation: float, yaw: int, pitch: int, age: int, hair_color_r: float, hair_color_g: float, hair_color_b: float, nrows: int, ncols: int, controller: Controller, device: torch.device, ) -> PIL.Image.Image: seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max)) batch_size = nrows * ncols latent_size = controller.config.model_config["latent_size"] latent = torch.from_numpy(np.random.RandomState(seed).randn(batch_size, latent_size)).float().to(device) initial_image_tensors, initial_latent_z, initial_latent_w = controller.gen_batch( latent=latent, truncation=truncation ) res0 = controller.make_resized_grid_image(initial_image_tensors, nrow=ncols) pose_control = torch.tensor([[yaw, pitch, 0]], dtype=torch.float32) image_tensors, _, modified_latent_w = controller.gen_batch_by_controls( latent=initial_latent_w, input_is_latent=True, orientation=pose_control ) res1 = controller.make_resized_grid_image(image_tensors, nrow=ncols) age_control = torch.tensor([[age]], dtype=torch.float32) image_tensors, _, modified_latent_w = controller.gen_batch_by_controls( latent=initial_latent_w, input_is_latent=True, age=age_control ) res2 = controller.make_resized_grid_image(image_tensors, nrow=ncols) hair_color = torch.tensor([[hair_color_r, hair_color_g, hair_color_b]], dtype=torch.float32) / 255 hair_color = torch.clamp(hair_color, 0, 1) image_tensors, _, modified_latent_w = controller.gen_batch_by_controls( latent=initial_latent_w, input_is_latent=True, hair=hair_color ) res3 = controller.make_resized_grid_image(image_tensors, nrow=ncols) return res0, res1, res2, res3 download_models() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") path = "controller_age015id025exp02hai04ori02gam15/" controller = Controller(path, device) fn = functools.partial(run, controller=controller, device=device) demo = gr.Interface( fn=fn, inputs=[ gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=0), gr.Slider(label="Truncation", minimum=0, maximum=1, step=0.1, value=0.7), gr.Slider(label="Yaw", minimum=-90, maximum=90, step=1, value=30), gr.Slider(label="Pitch", minimum=-90, maximum=90, step=1, value=0), gr.Slider(label="Age", minimum=15, maximum=75, step=1, value=75), gr.Slider(label="Hair Color (R)", minimum=0, maximum=255, step=1, value=186), gr.Slider(label="Hair Color (G)", minimum=0, maximum=255, step=1, value=158), gr.Slider(label="Hair Color (B)", minimum=0, maximum=255, step=1, value=92), gr.Slider(label="Number of Rows", minimum=1, maximum=3, step=1, value=1), gr.Slider(label="Number of Columns", minimum=1, maximum=5, step=1, value=5), ], outputs=[ gr.Image(label="Generated Image"), gr.Image(label="Head Pose Controlled"), gr.Image(label="Age Controlled"), gr.Image(label="Hair Color Controlled"), ], title=TITLE, description=DESCRIPTION, ) if __name__ == "__main__": demo.queue(max_size=10).launch()