import argparse |
import copy |
import warnings |
import tensorflow as tf |
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) |
warnings.filterwarnings('ignore', category=FutureWarning) |
warnings.filterwarnings('ignore', category=DeprecationWarning) |
import sys, getopt, os |
import numpy as np |
import dnnlib |
from dnnlib import EasyDict |
import dnnlib.tflib as tflib |
from dnnlib.tflib import tfutil |
from dnnlib.tflib.autosummary import autosummary |
from training import misc |
import pickle |
def create_model(config_id = 'config-f', gamma = None, height = 512, width = 512, cond = None, label_size = 0):
    """Build StyleGAN2 G/D option dicts and submit a diagnostic run that creates an initial pickle.

    The run function is ``training.diagnostic.create_initial_pkl``; it is launched
    through ``dnnlib.submit_diagnostic`` with the generator/discriminator
    options (``G_args``/``D_args``), the TF config, ``config_id``, the output
    resolution and ``label_size`` as keyword arguments.

    Args:
        config_id: Preset name.
            - Anything other than ``'config-f'`` shrinks ``fmap_base`` to ``8 << 10``.
            - ``'config-e*'`` sets ``D_loss.gamma = 100`` and lets substrings such
              as ``'Gorig'``/``'Dresnet'`` choose the G/D architecture.
            - ``'config-a'`` .. ``'config-d'`` set schedule entries and switch D to
              ``D_stylegan``.
            - ``'config-a'`` replaces G and D entirely.
        gamma: If not None, overrides ``D_loss.gamma``.
        height: Stored as ``resolution_h`` in G_args, D_args and the run kwargs.
        width: Stored as ``resolution_w`` in G_args, D_args and the run kwargs.
        cond: If truthy, appends ``'-cond'`` to the run description.
        label_size: Forwarded unchanged to the run function.

    Returns:
        The file name ``network-initial-config-f-{height}x{width}-{label_size}.pkl``.
        NOTE(review): presumably this is the file the diagnostic run writes.
        It hardcodes ``config-f`` even when ``config_id`` differs — confirm
        against ``create_initial_pkl``.
    """
    # Option dicts for the run function, networks, loss, schedule and submission.
    train = EasyDict(run_func_name='training.diagnostic.create_initial_pkl')
    G = EasyDict(func_name='training.networks_stylegan2.G_main')
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1')
    sched = EasyDict()
    sc = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}
    sched.minibatch_size_base = 192
    sched.minibatch_gpu_base = 3
    D_loss.gamma = 10
    desc = 'stylegan2'
    dataset_args = EasyDict()
    if cond:
        desc += '-cond'; dataset_args.max_label_size = 'full'
    desc += '-' + config_id
    # Every preset except config-f uses a smaller feature-map base.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10
    # config-e variants: larger R1 gamma; G/D architecture selected by substrings of config_id.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'
    # config-a..d: per-resolution schedule settings, revised StyleGAN synthesis, D_stylegan discriminator.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'
    # NOTE(review): G_loss is assigned here but never used or forwarded below.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False
    # config-a: G and D are rebuilt from scratch, so every setting applied to them above is discarded.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')
    # NOTE(review): D_loss, sched and dataset_args are never passed into kwargs,
    # so gamma/loss/schedule settings do not reach the submitted run as written — confirm intent.
    if gamma is not None:
        D_loss.gamma = gamma
    G.update(resolution_h=height)
    G.update(resolution_w=width)
    D.update(resolution_h=height)
    D.update(resolution_w=width)
    sc.submit_target = dnnlib.SubmitTarget.DIAGNOSTIC
    sc.local.do_not_copy_source_files = True
    # Run kwargs start from the `train` options and add network/resolution settings.
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, tf_config=tf_config, config_id=config_id,
                  resolution_h=height, resolution_w=width, label_size = label_size)
    # Deep copy so the submitted config is independent of the local `sc` object.
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_diagnostic(**kwargs)
    return f'network-initial-config-f-{height}x{width}-{label_size}.pkl'
def _str_to_bool(v):
    """Interpret ``v`` as a boolean; suitable as an ``argparse`` ``type=`` callable.

    Bools are returned unchanged. Strings are matched case-insensitively against
    yes/true/t/y/1 (True) and no/false/f/n/0 (False).

    Raises:
        argparse.ArgumentTypeError: if a string matches neither set.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def _parse_comma_sep(s):
    """Split a comma-separated string into a list of items.

    ``None``, the empty string, and ``'none'`` (any case) all yield ``[]``.
    Items are not stripped, so ``'a, b'`` becomes ``['a', ' b']``.
    """
    if s is None:
        return []
    if s.lower() in ('none', ''):
        return []
    return s.split(',')
def copy_weights(source_pkl, target_pkl, output_pkl):
    """Copy compatible trainable weights from one (G, D, Gs) pickle into another.

    Loads a ``(G, D, Gs)`` tuple from each of ``source_pkl`` and ``target_pkl``.
    Each target network then receives the source's trainables via
    ``copy_compatible_trainables_from``, and the updated target tuple is
    pickled to ``os.path.join('./', output_pkl)``.

    Args:
        source_pkl: Path of the pickle that provides the weights.
        target_pkl: Path of the pickle whose networks receive the weights.
        output_pkl: Output file name (relative paths are resolved under ``./``).

    Security: ``pickle.load`` can execute arbitrary code. Only pass network
    pickles from trusted sources.
    """
    tflib.init_tf()
    with tf.Session():
        with tf.device('/gpu:0'):
            # Context managers close the input handles promptly instead of
            # leaving them open until garbage collection.
            with open(source_pkl, 'rb') as source_file:
                sourceG, sourceD, sourceGs = pickle.load(source_file)
            with open(target_pkl, 'rb') as target_file:
                targetG, targetD, targetGs = pickle.load(target_file)
            targetG.copy_compatible_trainables_from(sourceG)
            targetD.copy_compatible_trainables_from(sourceD)
            targetGs.copy_compatible_trainables_from(sourceGs)
            # Kept inside the session scope, as in the original; presumably
            # serializing the networks needs the session above — confirm.
            with open(os.path.join('./', output_pkl), 'wb') as out_file:
                pickle.dump((targetG, targetD, targetGs), out_file, protocol=pickle.HIGHEST_PROTOCOL)