Spaces:
Runtime error
Runtime error
import argparse | |
import os.path as osp | |
import yaml | |
import random | |
from easydict import EasyDict as edict | |
import numpy.random as npr | |
import torch | |
from utils import ( | |
edict_2_dict, | |
check_and_create_dir, | |
update) | |
import wandb | |
import warnings | |
warnings.filterwarnings("ignore") | |
from glob import glob | |
def parse_args():
    """Parse command-line options into an ``edict`` of run settings.

    Besides copying CLI values, this:
      * reads ``cfg.token`` from a file named ``TOKEN`` in the current
        working directory (all whitespace removed);
      * when ``--font`` is "none", picks the alphabetically first ``.ttf``
        under ``code/data/fonts/<script>``;
      * builds ``cfg.caption`` from the semantic concept and prompt suffix;
      * derives ``cfg.letter`` and ``cfg.target`` (the init path under
        ``code/data/init/``) from the font and the optimized letter.

    Returns:
        edict: the CLI-derived settings.

    Raises:
        SystemExit: if ``--semantic_concept`` is not given (``parser.error``).
        FileNotFoundError: if ``TOKEN`` is missing, or no ``.ttf`` font exists
            for the script when ``--font`` is not given.
        ValueError: if ``--optimized_letter`` is not a substring of the word.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="code/config/base.yaml")
    parser.add_argument("--experiment", type=str, default="conformal_0.5_dist_pixel_100_kernel201")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--log_dir', metavar='DIR', default="output")
    parser.add_argument('--font', type=str, default="none", help="font name")
    parser.add_argument('--semantic_concept', type=str, help="the semantic concept to insert")
    parser.add_argument('--word', type=str, default="none", help="the text to work on")
    parser.add_argument('--script', type=str, default="arabic", help="script")
    parser.add_argument('--prompt_suffix', type=str,
                        default="minimal flat 2d vector. lineal color. trending on artstation")
    parser.add_argument('--optimized_letter', type=str, default="none",
                        help="the letter in the word to optimize")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--use_wandb', type=int, default=0)
    parser.add_argument('--wandb_user', type=str, default="none")
    args = parser.parse_args()

    # The concept feeds the word, the caption and the experiment signature;
    # without it the code below would die with an opaque TypeError.
    if args.semantic_concept is None:
        parser.error("--semantic_concept is required")

    # Drop every whitespace character so CRLF line endings or trailing
    # spaces in the file do not corrupt the token.
    with open('TOKEN', 'r', encoding='utf-8') as f:
        args.token = "".join(f.read().split())

    cfg = edict()
    cfg.config = args.config
    cfg.experiment = args.experiment
    cfg.seed = args.seed
    cfg.font = args.font
    cfg.semantic_concept = args.semantic_concept
    # When no explicit word is given, the concept itself is the word.
    cfg.word = cfg.semantic_concept if args.word == "none" else args.word
    cfg.script = args.script

    if cfg.font == "none":
        script_path = f"code/data/fonts/{cfg.script}"
        # glob() order depends on the filesystem; sort so the default font is
        # the same on every machine.
        fonts = sorted(glob(f"{script_path}/*.ttf"))
        if not fonts:
            raise FileNotFoundError(f"no .ttf font found in {script_path}")
        cfg.font = osp.splitext(osp.basename(fonts[0]))[0]

    # A concept containing "jpeg" is used verbatim as the caption (presumably
    # an image reference — confirm with the consumer of cfg.caption).
    if "jpeg" in args.semantic_concept:
        cfg.caption = args.semantic_concept
    else:
        cfg.caption = f"a {args.semantic_concept}. {args.prompt_suffix}"
    cfg.log_dir = f"{args.log_dir}/{cfg.script}"

    if args.optimized_letter not in cfg.word:
        raise ValueError('letter should be in word')
    cfg.optimized_letter = args.optimized_letter

    cfg.batch_size = args.batch_size
    cfg.token = args.token
    cfg.use_wandb = args.use_wandb
    cfg.wandb_user = args.wandb_user
    cfg.letter = f"{cfg.font}_{args.optimized_letter}_scaled"
    # Spaces are replaced by underscores in the target path only;
    # cfg.letter keeps them.
    cfg.target = f"code/data/init/{cfg.letter}".replace(' ', '_')
    return cfg
def _collect_config_chain(cfg_full, experiment):
    """Return the YAML configs on the ``parent_config`` chain, child first.

    Starts at ``cfg_full[experiment]`` and follows each entry's
    ``parent_config`` key. A missing key falls back to ``'baseline'``, and the
    walk stops at the first falsy parent name (e.g. ``''`` or null).

    Raises:
        KeyError: if a config name on the chain is not present in *cfg_full*.
        ValueError: if the chain revisits a name. Without this check the
            walk would loop forever, e.g. on a ``baseline`` entry that has no
            ``parent_config`` key.
    """
    chain = []
    visited = []
    key = experiment
    while key:
        if key in visited:
            loop = " -> ".join(str(k) for k in visited + [key])
            raise ValueError(f"cyclic parent_config chain: {loop}")
        visited.append(key)
        try:
            node = cfg_full[key]
        except KeyError:
            raise KeyError(f"config {key!r} not found in the YAML config file") from None
        chain.append(node)
        key = node.get('parent_config', 'baseline')
    return chain


def set_config():
    """Build the full run config: YAML experiment chain overlaid with CLI args.

    Steps:
      1. parse the command line (``parse_args``);
      2. load the YAML file named by ``--config`` and collect the
         ``parent_config`` chain starting at ``--experiment``;
      3. merge ancestor-first, so children override their parents and the
         CLI options (applied last) override everything;
      4. create ``<log_dir>/<word>_<semantic_concept>_<seed>``, save the merged
         config there as ``config.yaml``, optionally start a wandb run, and
         seed ``random``, ``numpy`` and ``torch``.

    ``cfg.token`` is kept in the returned config but is omitted from the
    printed config, the saved ``config.yaml`` and the wandb run config, so the
    secret is not leaked to logs, disk or a third-party service.

    Returns:
        edict: the merged configuration, with ``experiment_dir`` added.

    Raises:
        KeyError, ValueError: for a missing or cyclic ``parent_config`` chain.
        ValueError: if ``cfg.seed`` is None.
    """
    cfg_arg = parse_args()
    with open(cfg_arg.config, 'r') as f:
        cfg_full = yaml.load(f, Loader=yaml.FullLoader)

    # CLI args first in the child-first list -> applied last after reversal.
    cfgs = [cfg_arg] + _collect_config_chain(cfg_full, cfg_arg.experiment)

    # Allow children configs (and finally the CLI) to override their parents.
    cfg = edict()
    for options in reversed(cfgs):
        update(cfg, options)
    del cfgs

    signature = f"{cfg.word}_{cfg.semantic_concept}_{cfg.seed}"
    cfg.experiment_dir = osp.join(cfg.log_dir, signature)
    configfile = osp.join(cfg.experiment_dir, 'config.yaml')

    # A fresh dict without the secret; cfg itself is not mutated.
    public_cfg = {k: v for k, v in edict_2_dict(cfg).items() if k != 'token'}
    print('Config:', public_cfg)

    # Create the experiment dir (via the utils helper) and save the config.
    check_and_create_dir(configfile)
    with open(configfile, 'w') as f:
        yaml.dump(public_cfg, f)

    if cfg.use_wandb:
        wandb.init(project="Font-To-Image", entity=cfg.wandb_user,
                   config=public_cfg, name=signature, id=wandb.util.generate_id())

    # An `assert` would be stripped under `python -O`; fail explicitly instead.
    if cfg.seed is None:
        raise ValueError("cfg.seed must be set")
    random.seed(cfg.seed)
    npr.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.benchmark = False
    return cfg