inversion_testing / gradio_wrapper /gradio_options.py
ethanNeuralImage's picture
fix device to be cpu if cuda isn't available
5e2052f
raw
history blame contribute delete
No virus
3.34 kB
import os
import sys
from argparse import ArgumentParser

# Let sibling packages resolve when launched from this directory or one level
# below it. Both entries go at the END of sys.path, after the standard-library
# locations, so they never shadow stdlib modules such as argparse.
sys.path.extend([".", ".."])
class GradioTestOptions:
    """Command-line options for the Gradio inversion/editing wrapper.

    Holds an ``argparse.ArgumentParser`` that is populated with every supported
    option when the object is constructed. Call :meth:`parse` to get the
    resulting namespace.
    """

    def __init__(self):
        self.parser = ArgumentParser()
        self.initialize()

    def initialize(self):
        """Register every supported argument on ``self.parser``."""
        # arguments for inference script
        self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint')
        # Defaults to None; this class does not resolve a concrete device itself.
        self.parser.add_argument('--device', default=None, type=str, help='device to use')
        self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
        # Opt-in flags (all default False); going by their names they disable
        # individual mapper levels -- confirm against the mapper implementation.
        self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
        self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
        self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
        self.parser.add_argument('--use_weight_delta_mapper', default=False, action="store_true")
        self.parser.add_argument('--stylegan_size', default=1024, type=int)
        # --alpha/--beta belong to the main edit. The weight-delta edit has its
        # own dedicated pair (--weight_delta_alpha/--weight_delta_beta), so the
        # help text must not describe these two as weight-delta settings.
        self.parser.add_argument('--alpha', default=4.1, type=float, help='Alpha to use for the main edit (the weight delta uses --weight_delta_alpha)')
        self.parser.add_argument('--beta', default=0.14, type=float, help='Beta to use for the main edit (the weight delta uses --weight_delta_beta)')
        self.parser.add_argument('--edit_weight_delta', default=False, action='store_true', help='Edit the Weight Delta in addition')
        self.parser.add_argument('--weight_delta_alpha', default=4.1, type=float, help='Alpha to use for weight delta')
        self.parser.add_argument('--weight_delta_beta', default=0.14, type=float, help='Beta to use for weight delta')
        # Data files and prompt text for the global-direction edit. The default
        # paths are relative to the current working directory.
        self.parser.add_argument("--delta_i_c", type=str, default='./hyperstyle_global_directions/global_directions/ffhq/fs3.npy', help="path to file containing delta_i_c")
        self.parser.add_argument("--s_statistics", type=str, default='./hyperstyle_global_directions/global_directions/ffhq/S_mean_std', help="path to file containing s statistics")
        self.parser.add_argument("--text_prompt_templates", type=str, default='./hyperstyle_global_directions/global_directions/templates.txt')
        self.parser.add_argument("--neutral_text", type=str, default="A face with hair")
        self.parser.add_argument("--target_text", type=str, default=None)
        # arguments for hyperstyle
        self.parser.add_argument('--hyperstyle_checkpoint_path', default='./pretrained_models/hyperstyle/hyperstyle_ffhq.pt', type=str, help='Path to HyperStyle model checkpoint')
        self.parser.add_argument('--resize_outputs', action='store_true', help='Whether to resize outputs to 256x256 or keep at original output resolution')
        # arguments for loading pre-trained encoder
        self.parser.add_argument('--load_w_encoder', action='store_true', help='Whether to load the w e4e encoder.')
        self.parser.add_argument('--w_encoder_checkpoint_path', default='./pretrained_models/hyperstyle/faces_w_encoder.pt', type=str, help='Path to pre-trained W-encoder.')
        self.parser.add_argument('--w_encoder_type', default='WEncoder', type=str, help='Encoder type for the encoder used to get the initial inversion')
        # arguments for iterative inference
        self.parser.add_argument('--n_iters_per_batch', default=5, type=int, help='Number of forward passes per batch during inference.')
        # arguments for the test dataset
        self.parser.add_argument('--work_in_stylespace', default=False, action='store_true')

    def parse(self, args=None):
        """Parse ``args`` into an options namespace.

        Args:
            args: Optional list of argument strings. When None, argparse reads
                ``sys.argv[1:]``.

        Returns:
            argparse.Namespace with one attribute per registered option.

        Raises:
            SystemExit: raised by argparse on unknown or malformed arguments.
        """
        opts = self.parser.parse_args(args)
        return opts