FashionGAN / app.py
fiesty-bear's picture
changed rl to ril
6d11e12 verified
raw
history blame
4.4 kB
selected_model = 'lookbook' #@param {type:"string"}
# --- Load model --------------------------------------------------------------
import torch
import numpy as np
from PIL import Image
from models import get_instrumented_model
from decomposition import get_or_compute
from config import Config

# This script only runs inference: turn autograd off globally and let cuDNN
# autotune its convolution algorithms.
torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Decomposition settings: 20 components of StyleGAN2's 'style' layer for the
# class named by `selected_model`, with use_w enabled.
config = Config(
    model='StyleGAN2',
    layer='style',
    output_class=selected_model,
    components=20,
    use_w=True,
    batch_size=5_000,  # style layer quite small
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

inst = get_instrumented_model(
    config.model,
    config.output_class,
    config.layer,
    device,
    use_w=config.use_w,
)
# Archive of saved decomposition arrays; its `.files` list is read below.
path_to_components = get_or_compute(config, inst)
model = inst.model
comps = np.load(path_to_components)
# Names of the arrays stored in the decomposition archive.
lst = comps.files
latent_dirs = []
latent_stdevs = []
# When True, use the activation-space arrays ('act_comp' / 'act_stdev')
# instead of the latent-space ones ('lat_comp' / 'lat_stdev').
load_activations = False
if load_activations:
    _dirs_key, _stdevs_key = 'act_comp', 'act_stdev'
else:
    _dirs_key, _stdevs_key = 'lat_comp', 'lat_stdev'
# `comps` exposes `.files`, i.e. it is a numpy NpzFile, which loads a member
# from the archive on every `comps[key]` access. Read each array once and
# split it into its rows (one entry per component along axis 0) rather than
# re-reading the whole array for every row.
if _dirs_key in lst:
    latent_dirs.extend(comps[_dirs_key])
if _stdevs_key in lst:
    latent_stdevs.extend(comps[_stdevs_key])
#@title Define functions
def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):
    """Render one image after shifting per-layer latents along given directions.

    Args:
        seed: Seed passed to ``model.sample_latent`` when ``w`` is None; also
            used in the saved file name.
        truncation: Value assigned to ``model.truncation`` before rendering.
        directions: Sequence of direction arrays added to the latents.
        distances: Per-direction multipliers, paired with ``directions`` by
            position (extra entries in the longer sequence are ignored).
        scale: Global factor applied to every ``direction * distance`` offset.
        start: First per-layer latent index to edit (inclusive).
        end: Last per-layer latent index to edit (exclusive).
        w: Optional per-layer latents; each entry gets a leading axis via
            ``np.expand_dims(x, 0)``. When None, one latent is sampled from
            ``seed`` and repeated ``model.get_max_latents()`` times.
        disp: If True, show the image with ``display`` when one is available
            (IPython/Colab); otherwise nothing is shown.
        save: Optional value (presumably an integer frame index) formatted
            into ``out/{seed}_{save:05}.png``; when given the image is saved
            there, creating ``out/`` if needed.
        noise_spec: Unused; kept for call compatibility.

    Returns:
        The rendered image as a 500x500 PIL image.
    """
    import builtins
    import os

    # blockPrint()
    model.truncation = truncation
    if w is None:
        w = model.sample_latent(1, seed=seed).detach().cpu().numpy()
        w = [w] * model.get_max_latents()  # one per layer
    else:
        w = [np.expand_dims(x, 0) for x in w]

    # The scaled offsets do not depend on the layer, so compute them once.
    # "+" returns new arrays, so layers that start out sharing the same sampled
    # array are never modified through an alias.
    offsets = [direction * distance * scale for direction, distance in zip(directions, distances)]
    for layer in range(start, end):
        for offset in offsets:
            w[layer] = w[layer] + offset

    torch.cuda.empty_cache()
    # Save image and display
    out = model.sample_np(w)
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    pixels = np.clip(out * 255, 0, 255).astype(np.uint8)
    final_im = Image.fromarray(pixels).resize((500, 500), Image.LANCZOS)

    if save is not None:
        if not disp:
            print(save)
        os.makedirs('out', exist_ok=True)
        final_im.save(f'out/{seed}_{save:05}.png')
    if disp:
        # `display` is neither defined nor imported in this module; IPython
        # provides it as a builtin. Outside a notebook (e.g. under Gradio),
        # skip displaying instead of raising NameError.
        show = globals().get('display', getattr(builtins, 'display', None))
        if show is not None:
            show(final_im)
    return final_im
#@title Demo UI
import gradio as gr
import numpy as np
# Note: a Gradio theme has no effect unless it is passed to gr.Interface or
# gr.Blocks via `theme=`, so no standalone theme object is constructed here.
def generate_image(seed=0, c0=0, c1=0, c2=0, c3=0, c4=0, c5=0, c6=0):
    """Render an image for ``seed`` with slider-controlled component offsets.

    Each slider value ``cN`` scales one fixed entry of ``latent_dirs``; ``c0``
    is applied with its sign flipped. The offsets are applied to latent
    layers ``range(0, 14)`` with truncation 0.5, and nothing is displayed.

    Returns:
        The image returned by ``display_sample_pytorch``.
    """
    seed = int(seed)
    # (slider value, index into latent_dirs), in slider order c0..c6.
    slider_components = (
        (-c0, 12),
        (c1, 6),
        (c2, 7),
        (c3, 2),
        (c4, 11),
        (c5, 9),
        (c6, 10),
    )
    distances = [value for value, _ in slider_components]
    directions = [latent_dirs[index] for _, index in slider_components]

    # Fixed generation settings.
    truncation = 0.5
    first_layer, last_layer = 0, 14
    return display_sample_pytorch(
        seed, truncation, directions, distances, 1, first_layer, last_layer, disp=False
    )
# Create a number input for seed
seed = gr.Number(value=6, label="Seed 1")
slider_max_val = 5
slider_min_val = -5
slider_step = 0.1


def _component_slider(label):
    """Build one component slider with the shared range, step, and zero default."""
    return gr.Slider(
        label=label,
        minimum=slider_min_val,
        maximum=slider_max_val,
        step=slider_step,
        value=0,
    )


# Create the sliders input
c0 = _component_slider("Design Pattern")
c1 = _component_slider("Traditional")
c2 = _component_slider("Darker Tone")
c3 = _component_slider("Neck Line")
c4 = _component_slider("Graphics")
c5 = _component_slider("Darker Tone")
c6 = _component_slider("Greenish")
# NOTE(review): c4-c6 are built but not wired into the interface, so only the
# first four component sliders are shown and generate_image receives its
# default of 0 for the rest. Confirm this is intentional (c5 also repeats
# c2's "Darker Tone" label).
inputs = [seed, c0, c1, c2, c3]
# Launch demo
gr.Interface(
    generate_image,
    inputs,
    ["image"],
    live=True,
    title="Fashion GAN",
    description="StyleGan2+SpaceGan to generate parameter controlled images. With ❤ by TCS Rapid Innovation Labs",
    theme=gr.themes.Glass(),
).launch(debug=True, share=True)