multimodalart's picture
Upload 57 files
8a09a62 verified
raw
history blame
3.86 kB
import argparse
from .constants import *
from .modules.models import HUNYUAN_DIT_CONFIG
def get_args(default_args=None):
    """Build the HunYuan-DiT sampling command-line parser and parse arguments.

    Args:
        default_args: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. Dashes in long option names
        become underscores in attribute names (e.g. ``--model-root`` ->
        ``args.model_root``).
    """
    parser = argparse.ArgumentParser()

    # Basic
    # The default prompt is Chinese for "a kitten".
    parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
    parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
    # nargs='+' accepts one or more ints; this function does not normalise a
    # single value to (value, value) — presumably done by the caller, verify.
    parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
                        help='Image size (h, w). If a single value is provided, the image will be treated to '
                             '(value, value).')
    parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
                        help="Inference mode")

    # HunYuan-DiT
    parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2',
                        help="HunYuan-DiT model configuration name.")
    parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
    parser.add_argument("--load-key", type=str, choices=["ema", "module"], default="ema",
                        help="Load model key for HunYuanDiT checkpoint.")
    parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
                        help="Size condition used in sampling. 2 values are required for height and width. "
                             "If a single value is provided, the image will be treated to (value, value).")
    parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")

    # Prompt enhancement
    # Enabled by default (see set_defaults), so --enhance only matters to
    # override an earlier --no-enhance on the same command line.
    parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
    parser.add_argument("--no-enhance", dest="enhance", action="store_false",
                        help="Disable prompt enhancement.")
    parser.set_defaults(enhance=True)

    # Diffusion
    # Enabled by default; --no-learn-sigma turns it off.
    parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
    parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false",
                        help="Do not learn extra channels for sigma.")
    parser.set_defaults(learn_sigma=True)
    parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
                        help="Diffusion predict type")
    parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
                        help="Noise schedule")
    parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
    parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")

    # Text condition
    parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
    parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
    parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of T5 text encoder.")
    parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
    parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")

    # Acceleration
    # "--use-fp16" matches the dash style of every other option; "--use_fp16"
    # is kept as an alias so existing command lines keep working. Both write
    # to args.use_fp16, which defaults to True.
    parser.add_argument("--use-fp16", "--use_fp16", dest="use_fp16", action="store_true",
                        help="Use FP16 precision.")
    parser.add_argument("--no-fp16", dest="use_fp16", action="store_false", help="Disable FP16 precision.")
    parser.set_defaults(use_fp16=True)

    # Sampling
    parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
    parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
    parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
    parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")

    # App
    parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")

    args = parser.parse_args(default_args)
    return args