# CharacterGAN / app.py
# Author: mfrashad
# Commit: "Remove ipywidgets" (b4a5ea4)
import nltk

# Fetch the WordNet corpus at startup.
# NOTE(review): presumably needed by a dependency imported below (e.g. class
# name lookup in the model code) — TODO confirm which one.
nltk.download('wordnet')

#@title Load Model
# Output class (i.e. which pretrained StyleGAN2 checkpoint) to load.
selected_model = 'character'
# Load model
import torch
import PIL
import numpy as np
from PIL import Image
from models import get_instrumented_model
from decomposition import get_or_compute
from config import Config
import gradio as gr
import numpy as np
# Speed up computation: this app only runs inference, so autograd tracking is
# disabled globally, and cuDNN is allowed to benchmark convolution algorithms.
torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Specify model to use: StyleGAN2, decomposed at its 'style' layer into 80
# components computed in W space.
config = Config(
    model='StyleGAN2',
    layer='style',
    output_class=selected_model,
    components=80,
    use_w=True,
    batch_size=5_000,  # style layer quite small
)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

inst = get_instrumented_model(
    config.model,
    config.output_class,
    config.layer,
    device,
    use_w=config.use_w,
)

# get_or_compute returns a path to the saved decomposition archive.
path_to_components = get_or_compute(config, inst)
model = inst.model
comps = np.load(path_to_components)
# Pull the per-component directions and their standard deviations out of the
# decomposition archive. ``load_activations`` picks the activation-space arrays
# ('act_comp' / 'act_stdev') instead of the latent-space ones
# ('lat_comp' / 'lat_stdev').
lst = comps.files
load_activations = False
if load_activations:
    _comp_key, _stdev_key = 'act_comp', 'act_stdev'
else:
    _comp_key, _stdev_key = 'lat_comp', 'lat_stdev'

# Index the archive once per key. Every ``comps[key]`` access on an NpzFile
# reads the array from disk again, and the previous loop did that once per
# row. Archive keys are unique, so a membership test replaces the scan.
# Each list entry is one row along the array's first axis, as before.
latent_dirs = list(comps[_comp_key]) if _comp_key in lst else []
latent_stdevs = list(comps[_stdev_key]) if _stdev_key in lst else []
def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):
    """Render one image from a latent, optionally shifted along directions.

    Args:
        seed: Passed to ``model.sample_latent`` when ``w`` is None. It is also
            used in the output filename when ``save`` is given.
        truncation: Assigned to ``model.truncation`` before sampling.
        directions: Sequence of direction vectors to add to the latent.
        distances: Per-direction multipliers, parallel to ``directions``.
        scale: Global multiplier applied to every ``direction * distance``.
        start, end: Half-open range ``[start, end)`` of per-layer latents that
            receive the edit. The range is clamped to the existing layers, so
            an out-of-range value typed into the UI cannot raise IndexError
            or wrap around to the last layers.
        w: Optional per-layer latents to edit instead of sampling new ones.
        disp: When False and ``save`` is given, the save index is printed.
        save: Optional integer index. When given, the image is written to
            ``out/{seed}_{save:05}.png`` and the directory is created if
            needed.
        noise_spec: Unused; kept for signature compatibility.

    Returns:
        The rendered image as a 500x500 ``PIL.Image``.
    """
    model.truncation = truncation
    if w is None:
        w = model.sample_latent(1, seed=seed).detach().cpu().numpy()
        # One entry per layer. Sharing the same array object is safe because
        # the edit below rebinds w[layer] instead of modifying it in place.
        w = [w] * model.get_max_latents()
    else:
        w = [np.expand_dims(x, 0) for x in w]

    # Clamp the layer range to the layers that actually exist.
    start = max(int(start), 0)
    end = min(int(end), len(w))

    # The offsets do not depend on the layer, so compute them once. They are
    # still added one at a time, in the original order, so the floating-point
    # result is the same as before.
    steps = [direction * distance * scale
             for direction, distance in zip(directions, distances)]
    for layer in range(start, end):
        for step in steps:
            w[layer] = w[layer] + step

    torch.cuda.empty_cache()

    # Save image and display.
    out = model.sample_np(w)
    # Clip before the uint8 cast so values outside [0, 1] saturate instead of
    # wrapping around (a no-op if sample_np already returns clipped values).
    pixels = np.clip(out * 255, 0, 255).astype(np.uint8)
    final_im = Image.fromarray(pixels).resize((500, 500), Image.LANCZOS)

    if save is not None:
        if not disp:
            print(save)
        import os
        os.makedirs('out', exist_ok=True)
        final_im.save(f'out/{seed}_{save:05}.png')
    return final_im
#@title Demo UI
def generate_image(seed, truncation,
monster, female, skimpy, light, bodysuit, bulky, human_head,
start_layer, end_layer):
seed = hash(seed) % 1000000000
scale = 1
params = {'monster': monster,
'female': female,
'skimpy': skimpy,
'light': light,
'bodysuit': bodysuit,
'bulky': bulky,
'human_head': human_head}
param_indexes = {'monster': 0,
'female': 1,
'skimpy': 2,
'light': 4,
'bodysuit': 5,
'bulky': 6,
'human_head': 8}
directions = []
distances = []
for k, v in params.items():
directions.append(latent_dirs[param_indexes[k]])
distances.append(v)
style = {'description_width': 'initial'}
return display_sample_pytorch(int(seed), truncation, directions, distances, scale, int(start_layer), int(end_layer), disp=False)
# Global controls.
truncation = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label="Truncation")
start_layer = gr.inputs.Number(default=0, label="Start Layer")
end_layer = gr.inputs.Number(default=14, label="End Layer")
seed = gr.inputs.Textbox(default="0", label="Seed")

# Shared range and step for every semantic-direction slider.
slider_max_val = 20
slider_min_val = -20
slider_step = 1


def _direction_slider(label):
    """Build one semantic-direction slider using the shared range and step."""
    return gr.inputs.Slider(label=label, minimum=slider_min_val, maximum=slider_max_val,
                            step=slider_step, default=0)


monster = _direction_slider("Monsterfication")
female = _direction_slider("Gender")
skimpy = _direction_slider("Amount of Clothing")
light = _direction_slider("Brightness")
bodysuit = _direction_slider("Bodysuit")
bulky = _direction_slider("Bulkiness")
human_head = _direction_slider("Head")

# NOTE(review): unused here; generate_image uses its own local scale.
scale = 1

# Order must match generate_image's positional parameters.
inputs = [seed, truncation, monster, female, skimpy, light, bodysuit, bulky, human_head, start_layer, end_layer]
description = "Change the seed number to generate different character design. Made by <a href='https://www.mfrashad.com/' target='_blank'>@mfrashad</a>. For more details on how to build this, visit the <a href='https://github.com/mfrashad/gancreate-saai' target='_blank'>repo</a>. Please give a star if you find it useful :)"
demo = gr.Interface(generate_image, inputs, ["image"], description=description, live=True, title="CharacterGAN")
demo.launch()