|
|
|
from utils import download_url |
|
import argparse |
|
import numpy as np |
|
import PIL.Image |
|
import dnnlib |
|
import dnnlib.tflib as tflib |
|
import re |
|
import sys |
|
from io import BytesIO |
|
import IPython.display |
|
from math import ceil |
|
from PIL import Image, ImageDraw |
|
import os |
|
import pickle |
|
from utils import log_progress, imshow, create_image_grid, show_animation |
|
import imageio |
|
import glob |
|
import gdown |
|
|
|
class Rasm: |
|
|
|
def __init__(self, mode = 'calligraphy'): |
|
|
|
if mode == 'calligraphy': |
|
url = 'https://drive.google.com/uc?id=138fdURGxdkOwZq7IWvnrGLcfo5VI8O1R' |
|
|
|
else: |
|
url = 'https://drive.google.com/uc?id=13h-alXGI0hbNOJy1qbmeoroXZSPBHEG2' |
|
|
|
output = 'model.pkl' |
|
print('Downloading networks from "%s"...' %url) |
|
gdown.download(url, output, quiet=False) |
|
dnnlib.tflib.init_tf() |
|
with dnnlib.util.open_url(output) as fp: |
|
self._G, self._D, self.Gs = pickle.load(fp) |
|
self.noise_vars = [var for name, var in self.Gs.components.synthesis.vars.items() if name.startswith('noise')] |
|
|
|
|
|
def generate_images_in_w_space(self, dlatents, truncation_psi): |
|
Gs_kwargs = dnnlib.EasyDict() |
|
Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) |
|
Gs_kwargs.randomize_noise = False |
|
Gs_kwargs.truncation_psi = truncation_psi |
|
|
|
|
|
imgs = [] |
|
for _, dlatent in log_progress(enumerate(dlatents), name = "Generating images"): |
|
|
|
|
|
row_images = self.Gs.components.synthesis.run(dlatent, **Gs_kwargs) |
|
imgs.append(PIL.Image.fromarray(row_images[0], 'RGB')) |
|
return imgs |
|
|
|
def generate_images(self, zs, truncation_psi, class_idx = None): |
|
Gs_kwargs = dnnlib.EasyDict() |
|
Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) |
|
Gs_kwargs.randomize_noise = False |
|
if not isinstance(truncation_psi, list): |
|
truncation_psi = [truncation_psi] * len(zs) |
|
|
|
imgs = [] |
|
label = np.zeros([1] + self.Gs.input_shapes[1][1:]) |
|
if class_idx is not None: |
|
label[:, class_idx] = 1 |
|
else: |
|
label = None |
|
for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"): |
|
Gs_kwargs.truncation_psi = truncation_psi[z_idx] |
|
noise_rnd = np.random.RandomState(1) |
|
tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) |
|
images = self.Gs.run(z, label, **Gs_kwargs) |
|
imgs.append(PIL.Image.fromarray(images[0], 'RGB')) |
|
return imgs |
|
|
|
def generate_from_zs(self, zs, truncation_psi = 0.5): |
|
Gs_kwargs = dnnlib.EasyDict() |
|
Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) |
|
Gs_kwargs.randomize_noise = False |
|
if not isinstance(truncation_psi, list): |
|
truncation_psi = [truncation_psi] * len(zs) |
|
|
|
for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"): |
|
Gs_kwargs.truncation_psi = truncation_psi[z_idx] |
|
noise_rnd = np.random.RandomState(1) |
|
tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) |
|
images = self.Gs.run(z, None, **Gs_kwargs) |
|
img = PIL.Image.fromarray(images[0], 'RGB') |
|
imshow(img) |
|
|
|
def generate_random_zs(self, size): |
|
seeds = np.random.randint(2**32, size=size) |
|
zs = [] |
|
for _, seed in enumerate(seeds): |
|
rnd = np.random.RandomState(seed) |
|
z = rnd.randn(1, *self.Gs.input_shape[1:]) |
|
zs.append(z) |
|
return zs |
|
|
|
|
|
def generate_zs_from_seeds(self, seeds): |
|
zs = [] |
|
for _, seed in enumerate(seeds): |
|
rnd = np.random.RandomState(seed) |
|
z = rnd.randn(1, *self.Gs.input_shape[1:]) |
|
zs.append(z) |
|
return zs |
|
|
|
|
|
def generate_images_from_seeds(self, seeds, truncation_psi): |
|
return imshow(self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi)[0]) |
|
|
|
def generate_randomly(self, truncation_psi = 0.5): |
|
return self.generate_images_from_seeds(np.random.randint(4294967295, size=1), truncation_psi=truncation_psi) |
|
|
|
def generate_grid(self, truncation_psi = 0.7): |
|
seeds = np.random.randint((2**32 - 1), size=9) |
|
return create_image_grid(self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi), 0.7 , 3) |
|
|
|
def generate_animation(self, size = 9, steps = 10, trunc_psi = 0.5): |
|
seeds = list(np.random.randint((2**32) - 1, size=size)) |
|
seeds = seeds + [seeds[0]] |
|
zs = self.generate_zs_from_seeds(seeds) |
|
|
|
imgs = self.generate_images(self.interpolate(zs, steps = steps), trunc_psi) |
|
movie_name = 'animation.mp4' |
|
with imageio.get_writer(movie_name, mode='I') as writer: |
|
for image in log_progress(list(imgs), name = "Creating animation"): |
|
writer.append_data(np.array(image)) |
|
return show_animation(movie_name) |
|
|
|
def convertZtoW(self, latent, truncation_psi=0.7, truncation_cutoff=9): |
|
dlatent = self.Gs.components.mapping.run(latent, None) |
|
dlatent_avg = self.Gs.get_var('dlatent_avg') |
|
for i in range(truncation_cutoff): |
|
dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg |
|
|
|
return dlatent |
|
|
|
def interpolate(self, zs, steps = 10): |
|
out = [] |
|
for i in range(len(zs)-1): |
|
for index in range(steps): |
|
fraction = index/float(steps) |
|
out.append(zs[i+1]*fraction + zs[i]*(1-fraction)) |
|
return out |
|
|