inversion_testing / gradio_wrapper /gradio_options.py
ethanNeuralImage's picture
fix device to be cpu if cuda isn't available
5e2052f
raw
history blame
No virus
3.34 kB
import sys
import os
# Extend the module search path with the current working directory and its
# parent (relative entries are resolved against the cwd at import time).
# NOTE(review): presumably so project packages resolve when this is launched
# from the repo root or from a subdirectory — confirm against the launcher.
sys.path.append(".")
sys.path.append("..")
from argparse import ArgumentParser
class GradioTestOptions:
    """Container for the command-line options declared in ``initialize``.

    Instantiating the class creates an ``ArgumentParser`` (exposed as
    ``self.parser``) and registers every option on it; ``parse`` then turns
    an argument list into a namespace.
    """

    def __init__(self):
        """Create the underlying parser and register all options on it."""
        self.parser = ArgumentParser()
        self.initialize()

    def initialize(self):
        """Register every supported option on ``self.parser``.

        Options are added in a fixed order, which is also the order in which
        argparse lists them in its generated help output.
        """
        add = self.parser.add_argument

        # arguments for inference script
        add(
            "--checkpoint_path",
            default=None,
            type=str,
            help="Path to model checkpoint",
        )
        add("--device", default=None, type=str, help="device to use")
        add(
            "--mapper_type",
            default="LevelsMapper",
            type=str,
            help="Which mapper to use",
        )
        add("--no_coarse_mapper", default=False, action="store_true")
        add("--no_medium_mapper", default=False, action="store_true")
        add("--no_fine_mapper", default=False, action="store_true")
        add("--use_weight_delta_mapper", default=False, action="store_true")
        add("--stylegan_size", default=1024, type=int)
        add(
            "--alpha",
            default=4.1,
            type=float,
            help="Alpha to use for weight delta",
        )
        add(
            "--beta",
            default=0.14,
            type=float,
            help="Beta to use for weight delta",
        )
        add(
            "--edit_weight_delta",
            default=False,
            action="store_true",
            help="Edit the Weight Delta in addition",
        )
        add(
            "--weight_delta_alpha",
            default=4.1,
            type=float,
            help="Alpha to use for weight delta",
        )
        add(
            "--weight_delta_beta",
            default=0.14,
            type=float,
            help="Beta to use for weight delta",
        )
        add(
            "--delta_i_c",
            type=str,
            default="./hyperstyle_global_directions/global_directions/ffhq/fs3.npy",
            help="path to file containing delta_i_c",
        )
        add(
            "--s_statistics",
            type=str,
            default="./hyperstyle_global_directions/global_directions/ffhq/S_mean_std",
            help="path to file containing s statistics",
        )
        add(
            "--text_prompt_templates",
            default="./hyperstyle_global_directions/global_directions/templates.txt",
        )
        add("--neutral_text", type=str, default="A face with hair")
        add("--target_text", type=str, default=None)

        # arguments for hyperstyle
        add(
            "--hyperstyle_checkpoint_path",
            default="./pretrained_models/hyperstyle/hyperstyle_ffhq.pt",
            type=str,
            help="Path to HyperStyle model checkpoint",
        )
        add(
            "--resize_outputs",
            action="store_true",
            help="Whether to resize outputs to 256x256 or keep at original output resolution",
        )

        # arguments for loading pre-trained encoder
        add(
            "--load_w_encoder",
            action="store_true",
            help="Whether to load the w e4e encoder.",
        )
        add(
            "--w_encoder_checkpoint_path",
            default="./pretrained_models/hyperstyle/faces_w_encoder.pt",
            type=str,
            help="Path to pre-trained W-encoder.",
        )
        add(
            "--w_encoder_type",
            default="WEncoder",
            help="Encoder type for the encoder used to get the initial inversion",
        )

        # arguments for iterative inference
        add(
            "--n_iters_per_batch",
            default=5,
            type=int,
            help="Number of forward passes per batch during training.",
        )

        # arguments to test dataset
        add("--work_in_stylespace", default=False, action="store_true")

    def parse(self, args=None):
        """Parse ``args`` and return the resulting namespace.

        Args:
            args: Sequence of argument strings. When ``None``, argparse
                falls back to the process command line.

        Returns:
            The ``argparse.Namespace`` produced by ``parse_args``.
        """
        return self.parser.parse_args(args)